你可能記得,在高斯貝氏分類當中,我們有個範例是二維平面上的第一和第三象限的點是一個類別,第二和第四象限的點是另外一個類別,每個類別各自分成了兩群,導致套用高斯貝氏分類的效果很差。而如果我們可以把每個類別,用多個高斯分布來表示呢?這個就是 sklearn.mixture 的 GaussianMixture(高斯混合模型,GMM)可以辦到的事情,如下:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture

X = np.random.randn(1000, 2)
y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0).astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

n_classes = np.unique(y_train)
gmms = {}
for c in n_classes:
	gmms[c] = GaussianMixture(n_components=2)
	gmms[c].fit(X_train[y_train==c])

logps = np.column_stack([gmms[c].score_samples(X_test) for c in n_classes])
pred = n_classes[np.argmax(logps, axis=1)]

print(f'Accuracy (GMM): {100*np.mean(pred==y_test):.2f}%')

在上述範例中: