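Understanding the Confusion Matrix

Summary: confusion matrices in machine learning, model evaluation

Example: handwritten digit recognition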

"""
# @Time : 2020/9/7
# @Author : Jimou Chen
"""
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler # 减去平均值再除以方差
from sklearn.metrics import classification_report, confusion_matrix

if __name__ == '__main__':

digits_data = load_digits()
x_data = digits_data.data
y_data = digits_data.target

# 对数据进行标准化
sc = StandardScaler()
x_data = sc.fit_transform(x_data)
# 切分数据
x_train, x_test, y_train, y_test = train_test_split(x_data, y_data)
# 建模
model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=100)
model.fit(x_train, y_train)

# 预测
prediction = model.predict(x_test)
# 评估
print(classification_report(prediction, y_test))
print(confusion_matrix(y_test, prediction))
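Output of one run (train_test_split and the MLP initialization are random, so the numbers change from run to run). The support column below equals the column sums of the confusion matrix, i.e. the number of predictions per class, which is what classification_report produces when the predictions are passed first; precision and recall are therefore swapped relative to the usual (y_true, y_pred) order.

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        44
           1       1.00      0.96      0.98        50
           2       0.98      0.98      0.98        48
           3       0.97      0.97      0.97        35
           4       0.98      0.98      0.98        42
           5       0.96      1.00      0.98        48
           6       1.00      1.00      1.00        43
           7       1.00      0.98      0.99        45
           8       0.96      1.00      0.98        49
           9       0.98      0.96      0.97        46

    accuracy                           0.98       450
   macro avg       0.98      0.98      0.98       450
weighted avg       0.98      0.98      0.98       450

[[44  0  0  0  0  0  0  0  0  0]
 [ 0 48  0  0  0  0  0  0  0  0]
 [ 0  1 47  0  0  0  0  0  0  0]
 [ 0  0  0 34  0  0  0  1  0  0]
 [ 0  0  0  0 41  0  0  0  0  1]
 [ 0  0  0  0  1 48  0  0  0  1]
 [ 0  0  0  0  0  0 43  0  0  0]
 [ 0  0  0  0  0  0  0 44  0  0]
 [ 0  1  1  0  0  0  0  0 49  0]
 [ 0  0  0  1  0  0  0  0  0 44]]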
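
The report and the matrix carry the same information, so it can help to recompute one from the other. The sketch below is plain Python, not scikit-learn's implementation, and assumes the rows of the matrix are true digits and the columns are predicted digits, as with confusion_matrix(y_test, prediction). The names `per_class_metrics` and `accuracy` are made up for this illustration. Recall is the diagonal entry divided by its row sum, precision is the diagonal entry divided by its column sum, and support is the row sum. These numbers follow the (y_true, y_pred) order, so precision and recall may look swapped next to a report computed with the arguments reversed.

```python
# Confusion matrix copied from the run above.
# Assumption: rows are true digits, columns are predicted digits.
cm = [
    [44, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 48, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 47, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 34, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 41, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1, 48, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 43, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 44, 0, 0],
    [0, 1, 1, 0, 0, 0, 0, 0, 49, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 44],
]


def per_class_metrics(cm):
    """Return a list of (precision, recall, support) tuples, one per class."""
    n = len(cm)
    row_sums = [sum(row) for row in cm]                              # true count per class
    col_sums = [sum(cm[i][j] for i in range(n)) for j in range(n)]  # predicted count per class
    metrics = []
    for k in range(n):
        tp = cm[k][k]  # correctly classified samples of class k
        precision = tp / col_sums[k] if col_sums[k] else 0.0
        recall = tp / row_sums[k] if row_sums[k] else 0.0
        metrics.append((precision, recall, row_sums[k]))
    return metrics


def accuracy(cm):
    """Fraction of all samples that lie on the diagonal."""
    correct = sum(cm[k][k] for k in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total


metrics = per_class_metrics(cm)
for digit, (p, r, s) in enumerate(metrics):
    print(f"{digit}  precision={p:.2f}  recall={r:.2f}  support={s}")
print(f"accuracy={accuracy(cm):.2f}")
```

For digit 1, for example, the diagonal entry is 48, the row sum is 48 and the column sum is 50, so recall is 48/48 and precision is 48/50.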


Understanding confusion_matrix

  • The matrix from the run above (rows are true labels, columns are predicted labels):
    [[44  0  0  0  0  0  0  0  0  0]
     [ 0 48  0  0  0  0  0  0  0  0]
     [ 0  1 47  0  0  0  0  0  0  0]
     [ 0  0  0 34  0  0  0  1  0  0]
     [ 0  0  0  0 41  0  0  0  0  1]
     [ 0  0  0  0  1 48  0  0  0  1]
     [ 0  0  0  0  0  0 43  0  0  0]
     [ 0  0  0  0  0  0  0 44  0  0]
     [ 0  1  1  0  0  0  0  0 49  0]
     [ 0  0  0  1  0  0  0  0  0 44]]
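  • The larger the diagonal entries the better; in the ideal case only the diagonal is nonzero
  • A nonzero value anywhere else means samples of that row's class were recognized as the column's class
  • For example, rows 0 and 1 above are perfect: every 0 and every 1 was recognized correctly
  • Row 2, however, has a 1 just left of the diagonal (in column 1), meaning one image that was really a 2 was recognized as a 1
  • The other rows read the same way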