貝氏定理的公式如下,其中 x 代表模型輸入(也就是 features),y 代表模型輸出:
P(y|x) = P(x|y)P(y) / p(x)
而在這之中,只要特徵固定了,p(x) 就是固定的常數,因此可以將公式變更為以下:
P(y|x) ∝ P(x|y)P(y)
而如果你的特徵有多個維度,並假設它們之間是獨立的,則又可以改寫為如下的樣子:
P(y|x) ∝ P(y) Πi P(xi|y)
因此,模型的預測結果,也就是能讓 P(y|x) 有最大值的 y,可以寫成:
ŷ = argmaxy P(y) Πi P(xi|y)
而高斯貝氏分類器,就是將公式中的 P(xi|y) 以高斯分布來表示的分類器。我們使用 scikit-learn 的 naive_bayes.GaussianNB 來幫我們執行高斯貝氏分類演算法,使用的資料集是 Wine Data Set,該資料集在 scikit-learn 有內建一版,你可以直接透過 import 相關函式來使用,不必自己下載:
import numpy as np from sklearn.datasets import load_wine from sklearn.model_selection import train_test_split from sklearn.naive_bayes import GaussianNB dataset = load_wine() print('Data shapes:', dataset.data.shape, dataset.target.shape) X_train, X_test, y_train, y_test = train_test_split(dataset.data, dataset.target) classes = np.unique(y_train) print('Training data info:', X_train.shape, y_train.shape, classes) print('Test data shapes:', X_test.shape, y_test.shape) model = GaussianNB() model.fit(X_train, y_train) pred = model.predict(X_test) print('Accuracy (sklearn): {:.2f}%'.format(100*np.mean(pred==y_test))) n_class = classes.size mean_train = np.zeros((n_class, X_train.shape[1])) sigma_train = np.zeros((n_class, X_train.shape[1])) log_class_probs = np.zeros(n_class) for c_idx in range(n_class): idx = np.where(y_train == c_idx)[0] mean_train[c_idx] = np.mean(X_train[idx], axis=0) sigma_train[c_idx] = np.var(X_train[idx], axis=0) log_class_probs[c_idx] = np.mean(y_train == c_idx) predictions = np.zeros((y_test.shape[0], n_class)) for i, p in enumerate(X_test): for c_idx in range(n_class): log_p_x_given_y = np.sum( -0.5 * np.log(2 * 3.14 * sigma_train[c_idx]) - \ (p - mean_train[c_idx]) ** 2 / (2 * sigma_train[c_idx]) ) predictions[i][c_idx] = log_class_probs[c_idx] + log_p_x_given_y predictions = np.argmax(predictions, axis=1) print('Accuracy (self-implemented): {:.2f}%'.format(100*np.mean(predictions==y_test)))在上述範例中:
- 訓練完畢後,model.theta_ 的內容會是每個類別的每個特徵的平均,model.var_ 或 model.sigma_ 則是對應的 variance。
- 由於範例中的資料是亂數切分,所以每次執行所得到的結果會不太一樣。
- 除了使用 sklearn 以外,也同時示範了如何依照公式來實作,以便讓各位更容易了解內部的演算方式。