[英]performance of a multi-class classifier using the 3 highest probabilities
I have a pandas dataframe as below 我有一个pandas数据帧如下
predictions.head()
Out[22]:
A B C D E G H L N \
0 0.718363 0.5 0.403466 0.5 0.5 0.458989 0.5 0.850190 0.620878
1 0.677776 0.5 0.366128 0.5 0.5 0.042405 0.5 0.894200 0.510644
2 0.682019 0.5 0.074347 0.5 0.5 0.562217 0.5 0.417786 0.539949
3 0.482981 0.5 0.065436 0.5 0.5 0.112383 0.5 0.743659 0.604382
4 0.700207 0.5 0.515825 0.5 0.5 0.078089 0.5 0.437839 0.249892
P R S U V LABEL
0 0.182169 0.483631 0.432915 0.328495 0.5 A
1 0.015789 0.523462 0.547838 0.691239 0.5 L
2 0.799223 0.603212 0.620806 0.335204 0.5 G
3 0.246766 0.399070 0.341081 0.229407 0.5 P
4 0.064734 0.822834 0.769277 0.512239 0.5 U
Each row is a the prediction probability of the different classes (columns). 每行是不同类(列)的预测概率。 The last column is the label (correct class).
最后一列是标签(正确的类)。
I would like to evaluate the performances of the classifiers allowing 2 errors. 我想评估分类器的性能,允许2个错误。 What I mean is that if one of the highest 3 probabilities is the correct label I consider the prediction correct.
我的意思是,如果最高3个概率之一是正确的标签,我认为预测是正确的。 Is there a smart way to do it in scikit-learn?
在scikit-learn中有一种聪明的方法吗?
Try this approach: 试试这种方法:
In [57]: x = df.drop('LABEL',1).T.apply(lambda x: x.nlargest(3).index).T
In [58]: x
Out[58]:
0 1 2
0 L A N
1 L U A
2 P A S
3 L N B
4 R S A
In [59]: x.eq(df.LABEL, axis=0).any(1)
Out[59]:
0 True
1 True
2 False
3 False
4 False
dtype: bool
similar solution, which uses one transpose
less: 类似的解决方案,它使用一个
transpose
:
In [66]: x = df.drop('LABEL',1).T.apply(lambda x: x.nlargest(3).index)
In [67]: x
Out[67]:
0 1 2 3 4
0 L L P L R
1 A U A N S
2 N A S B A
In [68]: x.eq(df.LABEL).any()
Out[68]:
0 True
1 True
2 False
3 False
4 False
dtype: bool
Source DF: 来源DF:
In [70]: df
Out[70]:
A B C D E G H L N P R S U V LABEL
0 0.718363 0.5 0.403466 0.5 0.5 0.458989 0.5 0.850190 0.620878 0.182169 0.483631 0.432915 0.328495 0.5 A
1 0.677776 0.5 0.366128 0.5 0.5 0.042405 0.5 0.894200 0.510644 0.015789 0.523462 0.547838 0.691239 0.5 L
2 0.682019 0.5 0.074347 0.5 0.5 0.562217 0.5 0.417786 0.539949 0.799223 0.603212 0.620806 0.335204 0.5 G
3 0.482981 0.5 0.065436 0.5 0.5 0.112383 0.5 0.743659 0.604382 0.246766 0.399070 0.341081 0.229407 0.5 P
4 0.700207 0.5 0.515825 0.5 0.5 0.078089 0.5 0.437839 0.249892 0.064734 0.822834 0.769277 0.512239 0.5 U
UPDATE: trying to reproduce the error (from comments): 更新:尝试重现错误(来自评论):
In [81]: df
Out[81]:
a b c d e LABEL
0 1 2 3 4 5 c
1 3 4 5 6 7 d
In [82]: x = df.drop('LABEL',1).T.apply(lambda x: x.nlargest(3).index)
In [83]: x
Out[83]:
0 1
0 e e
1 d d
2 c c
In [84]: x.eq(df.LABEL).any()
Out[84]:
0 True
1 True
dtype: bool
PS I'm using Pandas 0.23.0 PS我正在使用Pandas 0.23.0
If performance is important use numpy.argsort
with remove last column by iloc
: 如果性能是重要的用途
numpy.argsort
通过删除最后一列iloc
:
print (np.argsort(-df.iloc[:, :-1].values, axis=1)[:,:3])
[[ 7 0 8]
[ 7 12 0]
[ 9 0 11]
[ 7 8 1]
[10 11 0]]
v = df.columns[np.argsort(-df.iloc[:, :-1].values, axis=1)[:,:3]]
print (v)
Index([['L', 'A', 'N'], ['L', 'U', 'A'], ['P', 'A', 'S'], ['L', 'N', 'B'],
['R', 'S', 'A']],
dtype='object')
a = pd.DataFrame(v).eq(df['LABEL'], axis=0).any(axis=1)
print (a)
0 True
1 True
2 False
3 False
4 False
dtype: bool
Thanks, @Maxu for another similar solution with numpy.argpartition
: 谢谢@Maxu与
numpy.argpartition
另一个类似的解决方案:
v = df.columns[np.argpartition(-df.iloc[:, :-1].values, 3, axis=1)[:,:3]]
Sample data: 样本数据:
df = pd.DataFrame({'A': [0.718363, 0.677776, 0.6820189999999999, 0.48298100000000005, 0.700207], 'B': [0.5, 0.5, 0.5, 0.5, 0.5], 'C': [0.403466, 0.366128, 0.074347, 0.06543600000000001, 0.515825], 'D': [0.5, 0.5, 0.5, 0.5, 0.5], 'E': [0.5, 0.5, 0.5, 0.5, 0.5], 'G': [0.45898900000000004, 0.042405, 0.562217, 0.112383, 0.07808899999999999], 'H': [0.5, 0.5, 0.5, 0.5, 0.5], 'L': [0.85019, 0.8942, 0.417786, 0.7436590000000001, 0.43783900000000003], 'N': [0.6208779999999999, 0.510644, 0.539949, 0.604382, 0.249892], 'P': [0.182169, 0.015788999999999997, 0.7992229999999999, 0.24676599999999999, 0.064734], 'R': [0.48363100000000003, 0.523462, 0.603212, 0.39907, 0.8228340000000001], 'S': [0.43291499999999994, 0.547838, 0.6208060000000001, 0.34108099999999997, 0.769277], 'U': [0.328495, 0.691239, 0.335204, 0.22940700000000003, 0.512239], 'V': [0.5, 0.5, 0.5, 0.5, 0.5], 'LABEL': ['A', 'L', 'G', 'P', 'U']})
print (df)
A B C D E G H L N \
0 0.718363 0.5 0.403466 0.5 0.5 0.458989 0.5 0.850190 0.620878
1 0.677776 0.5 0.366128 0.5 0.5 0.042405 0.5 0.894200 0.510644
2 0.682019 0.5 0.074347 0.5 0.5 0.562217 0.5 0.417786 0.539949
3 0.482981 0.5 0.065436 0.5 0.5 0.112383 0.5 0.743659 0.604382
4 0.700207 0.5 0.515825 0.5 0.5 0.078089 0.5 0.437839 0.249892
P R S U V LABEL
0 0.182169 0.483631 0.432915 0.328495 0.5 A
1 0.015789 0.523462 0.547838 0.691239 0.5 L
2 0.799223 0.603212 0.620806 0.335204 0.5 G
3 0.246766 0.399070 0.341081 0.229407 0.5 P
4 0.064734 0.822834 0.769277 0.512239 0.5 U
I can't think of a solution in sklearn
so here's one in pandas 我想不出
sklearn
中的解决方案,所以这里是熊猫中的一个
# Data
predictions
Out[]:
A B C D E G H L N P R S U V LABEL
0 0.718363 0.5 0.403466 0.5 0.5 0.458989 0.5 0.850190 0.620878 0.182169 0.483631 0.432915 0.328495 0.5 A
1 0.677776 0.5 0.366128 0.5 0.5 0.042405 0.5 0.894200 0.510644 0.015789 0.523462 0.547838 0.691239 0.5 L
2 0.682019 0.5 0.074347 0.5 0.5 0.562217 0.5 0.417786 0.539949 0.799223 0.603212 0.620806 0.335204 0.5 G
3 0.482981 0.5 0.065436 0.5 0.5 0.112383 0.5 0.743659 0.604382 0.246766 0.399070 0.341081 0.229407 0.5 P
4 0.700207 0.5 0.515825 0.5 0.5 0.078089 0.5 0.437839 0.249892 0.064734 0.822834 0.769277 0.512239 0.5 U
# Check if the label is in the top 3 (one line solution)
predictions.apply(lambda row: row['LABEL'] in list(row.drop('LABEL').sort_values().tail(3).index), axis=1)
Out[]:
0 True
1 True
2 False
3 False
4 False
Here is what is happening: 以下是发生的事情:
# List the top 3 results:
predictions.apply(lambda row: list(row.drop('LABEL').sort_values().tail(3).index), axis=1)
Out[]:
0 [N, A, L]
1 [A, U, L]
2 [S, A, P]
3 [V, N, L]
4 [A, S, R]
# Then check if the 'LABEL' is inside this list
You could ask this question on Cross Validated as they will use sklearn extensively 您可以在Cross Validated上提出这个问题,因为他们会广泛使用sklearn
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.