
keras-gcn fit model ValueError

I'm using the keras-gcn library to build a model that learns from graphs. Here is the code (from the repository):

import numpy as np

from keras_gcn.backend import keras
from keras_gcn import GraphConv

# feature matrix
input_data = np.array([[[0, 1, 2],
                        [2, 3, 4],
                        [4, 5, 6],
                        [7, 7, 8]]])

# adjacency matrix
input_edge = np.array([[[1, 1, 1, 0],
                        [1, 1, 0, 0],
                        [1, 0, 1, 0],
                        [0, 0, 0, 1]]])

labels = np.array([[[1],
                    [0],
                    [1],
                    [0]]])

data_layer = keras.layers.Input(shape=(None, 3), name='Input-Data')
edge_layer = keras.layers.Input(shape=(None, None), dtype='int32', name='Input-Edge')
conv_layer = GraphConv(units=4, step_num=1, kernel_initializer='ones', 
                       bias_initializer='ones', name='GraphConv')([data_layer, edge_layer])
model = keras.models.Model(inputs=[data_layer, edge_layer], outputs=conv_layer)
model.compile(optimizer='adam', loss='mae', metrics=['mae'])

model.fit([input_data, input_edge], labels)

However, when I run the code I get the following error:

ValueError: Error when checking target: expected GraphConv to have 3 dimensions, but got array with shape (4, 1)

even though the shape of labels is (1, 4, 1).
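The mismatch is easier to see from the layer's output shape. Below is a minimal numpy sketch of a mean-aggregation graph convolution. It is a simplified stand-in, not keras-gcn's actual implementation, and the function name graph_conv_sketch is hypothetical. It assumes, as in the example, that GraphConv returns one vector of length units per node. That makes the output (batch, nodes, units), which is (1, 4, 4) here, while the labels are (1, 4, 1).

```python
import numpy as np

def graph_conv_sketch(features, adjacency, kernel, bias):
    """Hypothetical, simplified graph convolution (not keras-gcn's code).

    features:  (batch, nodes, in_dim)  node feature matrix
    adjacency: (batch, nodes, nodes)   0/1 adjacency with self-loops
    kernel:    (in_dim, units)         shared weight matrix
    bias:      (units,)
    returns:   (batch, nodes, units)
    """
    projected = features @ kernel + bias            # (batch, nodes, units)
    degree = adjacency.sum(axis=2, keepdims=True)   # neighbours per node
    return (adjacency @ projected) / degree         # mean over neighbours

input_data = np.array([[[0, 1, 2],
                        [2, 3, 4],
                        [4, 5, 6],
                        [7, 7, 8]]])
input_edge = np.array([[[1, 1, 1, 0],
                        [1, 1, 0, 0],
                        [1, 0, 1, 0],
                        [0, 0, 0, 1]]])
labels = np.array([[[1], [0], [1], [0]]])

units = 4  # same as GraphConv(units=4, ...) in the question
output = graph_conv_sketch(input_data, input_edge,
                           kernel=np.ones((3, units)), bias=np.ones(units))
print(output.shape)  # (1, 4, 4)
print(labels.shape)  # (1, 4, 1): the last dimension does not match
```

The last axis of the output always equals units, so the targets passed to fit need the same last dimension.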

You should one-hot encode your labels, something like the following:

labels = np.array([[[0, 1],
                    [1, 0],
                    [0, 1],
                    [1, 0]]])

Also, the number of units in the GraphConv layer should equal the number of unique labels, which is 2 in your case.
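Here is a minimal numpy sketch of both steps. It assumes integer class labels stored as (batch, nodes, 1), like in the question, numbered 0 to num_classes - 1. np.unique gives the number of classes, which is the value to pass as GraphConv(units=...). Indexing an identity matrix with the class indices produces the one-hot targets shown above. keras.utils.to_categorical is an alternative for the encoding step.

```python
import numpy as np

labels = np.array([[[1], [0], [1], [0]]])   # (batch, nodes, 1) integer classes

# Assumes classes are numbered 0 .. num_classes - 1.
num_classes = len(np.unique(labels))         # 2 -> use GraphConv(units=num_classes)
# Drop the trailing axis and look up one row of the identity matrix per node.
one_hot = np.eye(num_classes, dtype=int)[labels[..., 0]]

print(num_classes)     # 2
print(one_hot.shape)   # (1, 4, 2)
print(one_hot[0])
# [[0 1]
#  [1 0]
#  [0 1]
#  [1 0]]
```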

I think the issue is a mismatch between the shapes of your edge_layer and data_layer.

When you call keras.layers.Input, you give data_layer the shape shape=(None, 3), and then you give edge_layer the shape shape=(None, None).
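The shapes given to keras.layers.Input exclude the batch dimension, and a None entry matches any size. The helper below is hypothetical (fits_input_shape is not part of Keras). It compares each array's per-sample shape with the shape declared in Input before fit is called. The arrays are zero-filled placeholders with the same shapes as the question's inputs.

```python
import numpy as np

def fits_input_shape(array, declared_shape):
    """Hypothetical helper: does `array` fit keras.layers.Input(shape=declared_shape)?

    The batch axis (axis 0) is skipped, since Input shapes exclude it;
    a None entry in declared_shape accepts any size.
    """
    sample_shape = array.shape[1:]
    if len(sample_shape) != len(declared_shape):
        return False
    return all(d is None or d == s for d, s in zip(declared_shape, sample_shape))

input_data = np.zeros((1, 4, 3))   # same shape as the question's feature matrix
input_edge = np.zeros((1, 4, 4))   # same shape as the question's adjacency matrix

print(fits_input_shape(input_data, (None, 3)))      # True
print(fits_input_shape(input_edge, (None, None)))   # True
# The node counts must also agree between the two inputs:
print(input_data.shape[1] == input_edge.shape[1] == input_edge.shape[2])  # True
```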

Match the shapes and let me know how it goes.
