[英]Error when using custom precision metric in Keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.pyplot import figure
model = Sequential()
model.add(Dense(90, input_dim=900, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(90, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
m = keras.metrics.Precision(class_id=1)
# example data of suitable dimension, to offer MRE to SO
X_train = np.eye(900)
Y_train = np.ones((900, 1))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=[m])
model.fit(X_train, Y_train, epochs=50, batch_size=1000)
When I use model.compile(loss='binary_crossentropy, optimizer='adam', metrics=['Precision'])
当我使用
model.compile(loss='binary_crossentropy, optimizer='adam', metrics=['Precision'])
it works, but when I use Precision(class_id=1) (regardless of whether I substitute it as a variable), I get它有效,但是当我使用 Precision(class_id=1) (无论我是否将其替换为变量)时,我得到
ValueError: slice index 1 of dimension 1 out of bounds.
for '{{node strided_slice_1}} = StridedSlice[Index=DT_INT32, T=DT_FLOAT, begin_mask=0, ellipsis_mask=1, end_mask=0,
new_axis_mask=0, shrink_axis_mask=2](Cast_1, strided_slice_1/stack, strided_slice_1/stack_1, strided_slice_1/stack_2)'
with input shapes: [?,1], [2], [2], [2] and with computed
input tensors: input[1] = <0 1>, input[2] = <0 2>, input[3] = <1 1>.
I don't know what any of this stuff means.我不知道这些东西是什么意思。 Slice of WHAT is out of bounds?
WHAT 的切片超出范围? (I defined X_train and Y_train, of course, and they work when I just write metric=['Precision']).
(当然,我定义了 X_train 和 Y_train,当我只写 metric=['Precision'] 时它们就起作用了)。 FYI I'm doing this in SageMaker, if it makes any difference.
仅供参考,我在 SageMaker 中执行此操作,如果它有什么不同的话。 There is no other code, so if I am failing to define some config thing, I don't know about that.
没有其他代码,所以如果我没有定义一些配置的东西,我不知道。
As far as I can tell (im not confident), your data doesn't seem to be long enough?据我所知(我不自信),您的数据似乎不够长? Like i think its having an issue trying to slice your data because it cant find an index 1 in the first dimension, meaning there is 1 or 0 items of data.
就像我认为它在尝试对数据进行切片时遇到问题,因为它在第一维中找不到索引 1,这意味着有 1 或 0 项数据。
Try checking your data and see if what you expect is there.试着检查你的数据,看看你所期望的是否在那里。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.