線性回歸(linear regression)的目標是找出一條直線,使得所有點到直線的距離(y 軸方向)之平方和為最小。我們可以使用 scikit-learn 的 linear_model.LinearRegression 來幫我們進行線性回歸,以下範例將會隨機的產生一些資料點來示範:
import matplotlib.pyplot as plt import numpy as np from sklearn.linear_model import LinearRegression x_train = np.random.rand(1000) * 10 y_train = 3 * x_train + 2 + np.random.randn(1000) * 0.1 x_test = np.random.rand(100) * 10 y_test = 3 * x_test + 2 + np.random.randn(100) * 0.1 model = LinearRegression() model.fit(x_train[:, np.newaxis], y_train) print(model.coef_[0], model.intercept_) y_pred = model.predict(x_test[:, np.newaxis]) y_pred_2 = model.coef_ * x_test + model.intercept_ print('MAE (by model.predict):', np.mean(np.abs(y_test - y_pred))) print('MAE (by model.coef_ * x_test + model.intercept_):', np.mean(np.abs(y_test - y_pred_2))) x_train_2 = np.hstack([x_train[:, np.newaxis], np.ones((x_train.size, 1))]) (coef, intercept), _, _, _ = np.linalg.lstsq(x_train_2, y_train, rcond=None) print(coef, intercept) y_pred_2 = coef * x_test + intercept print('MAE (by np.linalg.lstsq):', np.mean(np.abs(y_test - y_pred_2)))在上述範例中:
- 進行訓練時,x 的 shape 必須是「(資料點數, 每筆輸入資料的維度)」,因此需要新增一個軸。身為訓練目標的 y,則因為此例中剛好只有單變數,所以可以不必新增一個軸。
- 訓練完畢後,model.coef_ 的內容會是每個輸入維度的係數,model.intercept_ 則是常數項。
- 由於範例中的資料是亂數產生,所以每次執行所得到的係數可能不太一樣。
- 除了使用 sklearn 以外,也同時示範了使用 np.linalg.lstsq 來求取結果,以便讓各位更容易了解內部的演算方式。
如果要利用 LinearRegression 來擬合多項式,那麼你可以建立一個包含 x 的各個次方的矩陣,然後用該矩陣來對 y 做線性回歸。以下範例是一個擬合三次多項式的範例:
import matplotlib.pyplot as plt import numpy as np from sklearn.linear_model import LinearRegression x = np.linspace(-2, 2, 1000) y_raw = x ** 3 - 2 * x ** 2 + 3 * x - 4 y = y_raw + np.random.randn(1000) x2 = np.hstack([x[:, np.newaxis] ** 3, x[:, np.newaxis] ** 2, x[:, np.newaxis]]) model = LinearRegression() model.fit(x2, y[:, np.newaxis]) print(model.coef_, model.intercept_) y_fit = model.predict(x2) plt.plot(x, y, '.', label='Random data') plt.plot(x, y_raw, '-', label='Raw line') plt.plot(x, y_fit, '-', label='Fitted line') plt.legend() plt.show()在上述範例中:
- 只展示了訓練部分。
- 因為 LinearRegression 會自動幫你加入常數項,因此你不必像使用 np.linalg.lstsq 時一樣去自己加入。
- 由於範例中的資料依然是亂數產生,所以每次執行所得到的係數可能不太一樣。