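In practice, data is rarely perfect, and missing features are one common problem. For example, a model that predicts health status may need blood pressure measured on all four limbs, but someone with a fractured limb in a cast cannot have blood pressure measured on that limb. Which handling method fits best depends on which columns have missing values and how the missing rates are distributed. Options include, but are not limited to: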

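Vague terms above such as "very low" and "almost" may need to be turned into concrete thresholds based on the characteristics of the data at hand.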
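A natural first step toward concrete criteria is to measure how much data is missing. The sketch below computes the missing rate per column, per row, and overall using NumPy. The threshold constants COL_MOSTLY_MISSING and OVERALL_LOW_RATE, along with the small array X, are hypothetical values chosen only for illustration. They are not recommendations from this article.

```python
import numpy as np

# Hypothetical thresholds (assumptions for illustration only; tune them to your data)
COL_MOSTLY_MISSING = 0.9   # a column above this rate counts as "almost entirely missing"
OVERALL_LOW_RATE = 0.05    # an overall rate below this counts as "very low"


def missing_rates(X):
    """Return (per-column rate, per-row rate, overall rate) of NaN entries in X."""
    mask = np.isnan(X)  # True where a value is missing
    return mask.mean(axis=0), mask.mean(axis=1), mask.mean()


# Hypothetical toy data: column 1 is entirely missing, column 2 is partly missing
X = np.array([
    [1.0, np.nan, 3.0],
    [4.0, np.nan, 6.0],
    [7.0, np.nan, np.nan],
    [1.5, np.nan, 2.0],
])

col_rate, row_rate, overall = missing_rates(X)
print("per-column missing rate:", col_rate)
print("per-row missing rate:", row_rate)
print("overall missing rate:", overall)
print("mostly-missing columns:", np.where(col_rate > COL_MOSTLY_MISSING)[0])
print("overall rate counts as very low:", overall < OVERALL_LOW_RATE)
```

On this toy array, column 1 is flagged as almost entirely missing because all four of its values are NaN. The overall missing rate is 5/12, so under the assumed 5% threshold it would not count as "very low".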

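The following example compares the model's classification accuracy as the rate of randomly missing values rises from 1% to 10%. It tests two strategies: filling missing entries with a fixed value, and imputing them with scikit-learn's KNNImputer.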

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_wine
from sklearn.impute import KNNImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

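# Fix the random seed so the randomly generated missing-value masks are reproducible
np.random.seed(0)

# Load the wine dataset and hold out 20% of the samples as the test set
dataset = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
	dataset.data, dataset.target, test_size=0.2, random_state=42
)
print('Training data shapes:', X_train.shape, y_train.shape)
print('Test data shapes:', X_test.shape, y_test.shape)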

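acc_fill = []  # test accuracy (%) when filling with a fixed value
acc_knn = []   # test accuracy (%) when imputing with KNNImputer
# Missing rates of 1%, 2%, ..., 10%; linspace avoids the floating-point
# endpoint ambiguity of np.arange with a non-integer step
missing_rates = np.linspace(0.01, 0.10, 10)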
for rate in missing_rates:

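	# Copy the clean data, then mark roughly `rate` of all entries as missing,
	# completely at random, independently for the training and test sets
	X_train_miss = X_train.copy()
	X_test_miss = X_test.copy()

	mask_train = np.random.rand(*X_train_miss.shape) < rate
	mask_test = np.random.rand(*X_test_miss.shape) < rate

	X_train_miss[mask_train] = np.nan
	X_test_miss[mask_test] = np.nan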

	# Option 1: fill every missing entry with a fixed value (0)
	X_train_fill = np.nan_to_num(X_train_miss, nan=0)
	X_test_fill = np.nan_to_num(X_test_miss, nan=0)

	# Option 2 (alternative): fill with each column's mean, computed on the
	# training set only and reused for the test set
	# col_mean = np.nanmean(X_train_miss, axis=0)
	# X_train_fill = np.where(np.isnan(X_train_miss), col_mean, X_train_miss)
	# X_test_fill = np.where(np.isnan(X_test_miss), col_mean, X_test_miss)

	model = LogisticRegression(max_iter=500)
	model.fit(X_train_fill, y_train)
	pred = model.predict(X_test_fill)

	acc_fill.append(accuracy_score(y_test, pred) * 100)

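	# Impute with KNN: each missing entry is replaced by the mean of that feature
	# over the 5 nearest training samples (distances use the non-missing features).
	# The features are not standardized here, so large-range features dominate
	# the distance. Fit on the training set only, then apply to the test set.
	imputer = KNNImputer(n_neighbors=5)

	X_train_knn = imputer.fit_transform(X_train_miss)
	X_test_knn = imputer.transform(X_test_miss)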

	model = LogisticRegression(max_iter=500)
	model.fit(X_train_knn, y_train)
	pred = model.predict(X_test_knn)

	acc_knn.append(accuracy_score(y_test, pred) * 100)

# Plot test accuracy against missing rate for both strategies
plt.plot(missing_rates * 100, acc_fill, '.-', label='Fill with fixed value (0)')
plt.plot(missing_rates * 100, acc_knn, '.-', label='KNN imputation')
plt.xlabel("Missing rate (%)")
plt.ylabel("Accuracy (%)")
plt.legend()

plt.show()

In the example above: