keras-gcn fit model ValueError
I'm using this library to create a model to learn graphs. Here is the code (from the repository):
import numpy as np
from keras_gcn.backend import keras
from keras_gcn import GraphConv

# feature matrix
input_data = np.array([[[0, 1, 2],
                        [2, 3, 4],
                        [4, 5, 6],
                        [7, 7, 8]]])
# adjacency matrix
input_edge = np.array([[[1, 1, 1, 0],
                        [1, 1, 0, 0],
                        [1, 0, 1, 0],
                        [0, 0, 0, 1]]])
labels = np.array([[[1],
                    [0],
                    [1],
                    [0]]])

data_layer = keras.layers.Input(shape=(None, 3), name='Input-Data')
edge_layer = keras.layers.Input(shape=(None, None), dtype='int32', name='Input-Edge')
conv_layer = GraphConv(units=4, step_num=1, kernel_initializer='ones',
                       bias_initializer='ones', name='GraphConv')([data_layer, edge_layer])
model = keras.models.Model(inputs=[data_layer, edge_layer], outputs=conv_layer)
model.compile(optimizer='adam', loss='mae', metrics=['mae'])
model.fit([input_data, input_edge], labels)
However, when I run the code I get the following error:
ValueError: Error when checking target: expected GraphConv to have 3 dimensions, but got array with shape (4, 1)
while the shape of labels is (1, 4, 1).
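To see where the mismatch comes from, here is a rough numpy sketch of the shapes involved. It assumes (this is an approximation, not keras-gcn's actual implementation) that a graph convolution aggregates neighbour features via the adjacency matrix and then projects them with a kernel, roughly A @ X @ W:

```python
import numpy as np

# Feature matrix X: (batch=1, nodes=4, features=3) and adjacency A: (1, 4, 4),
# mirroring the arrays from the question.
X = np.array([[[0, 1, 2],
               [2, 3, 4],
               [4, 5, 6],
               [7, 7, 8]]], dtype=float)
A = np.array([[[1, 1, 1, 0],
               [1, 1, 0, 0],
               [1, 0, 1, 0],
               [0, 0, 0, 1]]], dtype=float)

# With units=4 the layer projects the 3 input features to 4 outputs per node;
# a (3, 4) kernel of ones mirrors kernel_initializer='ones' in the question.
W = np.ones((3, 4))

out = A @ X @ W
print(out.shape)  # (1, 4, 4)
```

The layer's output per sample is (4, 4), while the target per sample is (4, 1), which is exactly the mismatch the ValueError reports.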
You should encode your labels with a one-hot encoder, something like the following:
labels = np.array([[[0, 1],
                    [1, 0],
                    [0, 1],
                    [1, 0]]])
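If you'd rather build that one-hot array programmatically from the original integer labels, a small numpy sketch (indexing an identity matrix) does it:

```python
import numpy as np

labels = np.array([[[1], [0], [1], [0]]])  # original shape (1, 4, 1)

# Row i of np.eye(2) is the one-hot vector for class i,
# so label 1 -> [0, 1] and label 0 -> [1, 0].
onehot = np.eye(2, dtype=int)[labels.reshape(-1)].reshape(1, 4, 2)
print(onehot)  # [[[0 1] [1 0] [0 1] [1 0]]]
```

keras.utils.to_categorical would give the same result for the last axis.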
Also, the number of units in the GraphConv layer should equal the number of unique labels, which is 2 in your case.
I think the issue is a mismatch between the shapes of your edge_layer and data_layer.
When you use the function keras.layers.Input, you give data_layer a shape of shape=(None, 3), and then you give edge_layer a shape of shape=(None, None).
Match the shapes and let me know how it goes.