繁体   English   中英

即使输出层有任何大小,keras 神经网络如何找到要属性的类?

[英]How does keras neural network find the class to attribute even if the output layer has any size?

我有一个带有二进制类(真或假)的数据样本。 神经网络赋予每个类权重,最大值将决定属性类。 但是为什么即使输出层没有适当数量的神经元,keras 也能工作? (= 班级数 = 2 在我的情况下,0 或 1)。

import keras
from model import *

X_train, X_test, y_train, y_test = train_test_split(df_features, df_labels, test_size=0.25, random_state=10)

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(len(X_test.columns),)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(128, activation='softmax') # Shouldn't be two here ?
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# len(y_train.columns) == 1
history = model.fit(X_train, y_train, epochs=100, validation_split=0.25)

scores = model.evaluate(X_test, y_test, verbose=0)

print(model.metrics_names)
print('scores=', scores)

假设:它在最后添加了一个隐式层,或者它可能忽略了一些神经元,或者完全是别的什么?

编辑:添加数据

>>> print(y_train)
[0 0 0 ... 0 1 0]

>>> print(y_test)
      Class
1424      0
3150      1
2149      0
1700      0
4330      0
4200      0
# etc, ~1000 entries
>>> print('len(y_train)=', len(y_train))
len(y_train)= 2678
>>> print('len(y_test)=', len(y_test))
len(y_test)= 893

我相信问题在于您的损失sparse_cartegorical_crossentropy如何工作。 这种损失(与categorical_crossentropy相反)假设 y_actual 将作为标签编码格式而不是单热编码格式提供 意思是如果您要预测 5 个类,则 y_actual 数组提供为[0,2,4,1,2,2,3,3,1...]其中一维数组中的每个值代表一个5 个可能的班级中的班级编号。

让我们直接从tf2 文档中查看有关此损失的独立使用的示例 -

y_true = [1, 2] #class number from 0 to 2
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] #3 class classification output
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
assert loss.shape == (2,)
loss.numpy()
[0.0513, 2.253]

在您的情况下,这意味着当您的模型返回 128 维输出时,它假定此分类问题中有 128 个类。 但是,由于损失是sparse_categorical_crossentropy ,它等待接收 0-127 之间的单个数字,然后将其用于计算其误差。

由于您总是在所有情况下都给它 0 或 1,因此它假定样本所属的实际类是 128 个类中的 0 类或 1 类,而没有其他类。 因此,它的代码运行但它是错误的,因为它不是从y_train (或y_test )作为二进制类读取单个数字,而是假定它属于 128 个其他类中的一个类。

print(y_train)
[0 0 0 ... 0 1 0]

#The first 0 here, is being considered as one class out of 128 other classes. 
#The code would still work if u changed that to say 105 instead of 0.
#Similarly for all the other 0s and 1s. 

希望这是有道理的。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM