[英]Tensorflow argmax() for TFLearn?
我用MNIST数据集的tflearn
预测手写的数字。
一切正常,但是我的标签为one_hot 。 tflearn
是否有一个函数,与argmax()
一样?
您只需执行以下操作即可:
pred = model.predict(test_data)
print([ np.where(r==1)[0][0] for r in np.round(pred) ])
最好。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.