简体   繁体   English

在 Keras 中使用自定义精度度量时出错

[英]Error when using custom precision metric in Keras

from keras.models import Sequential
from keras.layers import Dense, Dropout
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import figure

model = Sequential()
model.add(Dense(90, input_dim=900, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(90, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

m = keras.metrics.Precision(class_id=1)
# example data of suitable dimension, to offer MRE to SO
X_train = np.eye(900)
Y_train = np.ones((900, 1))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[m])
model.fit(X_train, Y_train, epochs=50, batch_size=1000)

When I use model.compile(loss='binary_crossentropy, optimizer='adam', metrics=['Precision'])当我使用model.compile(loss='binary_crossentropy, optimizer='adam', metrics=['Precision'])

it works, but when I use Precision(class_id=1) (regardless of whether I substitute it as a variable), I get它有效,但是当我使用 Precision(class_id=1) (无论我是否将其替换为变量)时,我得到

ValueError: slice index 1 of dimension 1 out of bounds. 
for '{{node strided_slice_1}} = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=1, end_mask=0, 
new_axis_mask=0, shrink_axis_mask=2](Cast_1, strided_slice_1/stack, strided_slice_1/stack_1, strided_slice_1/stack_2)' 
with input shapes: [?,1], [2], [2], [2] and with computed 
input tensors: input[1] = <0 1>, input[2] = <0 2>, input[3] = <1 1>.

I don't know what any of this stuff means.我不知道这些东西是什么意思。 Slice of WHAT is out of bounds? WHAT 的切片超出范围? (I defined X_train and Y_train, of course, and they work when I just write metric=['Precision']). (当然,我定义了 X_train 和 Y_train,当我只写 metric=['Precision'] 时它们就起作用了)。 FYI I'm doing this in SageMaker, if it makes any difference.仅供参考,我在 SageMaker 中执行此操作,如果它有什么不同的话。 There is no other code, so if I am failing to define some config thing, I don't know about that.没有其他代码,所以如果我没有定义一些配置的东西,我不知道。

As far as I can tell (im not confident), your data doesn't seem to be long enough?据我所知(我不自信),您的数据似乎不够长? Like i think its having an issue trying to slice your data because it cant find an index 1 in the first dimension, meaning there is 1 or 0 items of data.就像我认为它在尝试对数据进行切片时遇到问题,因为它在第一维中找不到索引 1,这意味着有 1 或 0 项数据。

Try checking your data and see if what you expect is there.试着检查你的数据,看看你所期望的是否在那里。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM