
Keras: inconsistent behavior when using model.predict to access intermediate layer outputs with a Spektral GCN

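I am trying to access the output of an intermediate layer of a Graph Convolutional Network (GCN). model.predict throws an InvalidArgumentError for my input values, while model.fit works fine on the same input.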

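Here is my code. It uses the "CORA" citation dataset provided by the spektral library, which offers algorithms and examples for graph convolutional networks. My code is based on one of the examples in that library, here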

from spektral.datasets import citation
from spektral.layers import GraphConv
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dropout, Dense
import numpy as np

A, X, y, train_mask, val_mask, test_mask = citation.load_data('cora')

At = A.transpose()

N = A.shape[0]
F = X.shape[-1]
n_classes = y.shape[-1]

X_in = Input(shape=(F, ))
A_in = Input((N, ), sparse=True)
X_1 = GraphConv(16, 'relu', name="layer1")([X_in, A_in])
X_1 = Dropout(0.5, name="layer2")(X_1)
X_2 = GraphConv(n_classes, 'softmax', name="output")([X_1, A_in])
model = Model(inputs=[X_in, A_in], outputs=X_2)

A = GraphConv.preprocess(A).astype('f4')
At = GraphConv.preprocess(At).astype('f4')

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])
model.summary()

# Prepare data
X = X.toarray()
A = A.astype('f4')
At = At.astype('f4')
validation_data = ([X, A], y, val_mask)

# Train model
model.fit([X, A], 
          y,
          sample_weight=train_mask,
          validation_data=validation_data,
          epochs=1,
          batch_size=N,
          shuffle=False
)

# Access intermediate layers of the model
layer_name = 'layer2'
intermediate_layer_model = Model(inputs=model.input,
                                 outputs=model.get_layer(layer_name).output)

model_input = [X,A]
intermediate_output = intermediate_layer_model.predict(model_input)
print("\n\nIntermediate_output=",intermediate_output,"\n\n")

Here is the error message:

Traceback (most recent call last):
  File "PLGcn_example4_stackflow_debug.py", line 53, in <module>
    intermediate_output = intermediate_layer_model.predict(model_input)
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 130, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1599, in predict
    tmp_batch_outputs = predict_function(iterator)
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 846, in _call
    return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)
  File "/home/mansoor4/.local/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Cannot multiply A and B because inner dimension does not match: 2708 vs. 32.  Did you forget a transpose?  Dimensions of A: [32, 2708).  Dimensions of B: [32,16]
         [[node functional_3/layer1/SparseTensorDenseMatMul/SparseTensorDenseMatMul (defined at /home/mansoor4/.local/lib/python3.7/site-packages/spektral/layers/ops/matmul.py:33) ]] [Op:__inference_predict_function_22928]

Errors may have originated from an input operation.
Input Source operations connected to node functional_3/layer1/SparseTensorDenseMatMul/SparseTensorDenseMatMul:
 stack (defined at PLGcn_example4_stackflow_debug.py:53)
 functional_3/layer1/MatMul (defined at /home/mansoor4/.local/lib/python3.7/site-packages/spektral/layers/ops/matmul.py:45)

Function call stack:
predict_function

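The error message is about a mismatch in the inner dimensions of a matrix multiplication. I tried to work around it by passing the transpose as input, e.g. model_input = [X, At], but I still get the same error.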

I am new to Keras and Spektral. I searched Stack Overflow for related posts and tried many things, but I cannot get the intermediate outputs out of the network.

Solution

The predict method of a Keras Model has a default argument batch_size=32. You can fix this in two ways:

intermediate_output = intermediate_layer_model.predict(model_input, batch_size=N)

or

intermediate_output = intermediate_layer_model.predict_on_batch(model_input)

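In your code, the first dimension of both the adjacency matrix and the node feature matrix gets split into batches of 32. However, the model expects to always receive the whole graph, so you should set the batch size to N (which is exactly what you did when calling model.fit). predict_on_batch avoids the split entirely by running the model on the whole input as a single batch.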
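To make the batching concrete, here is a small pure-Python sketch of how splitting the first axis into batches of 32 carves up the N=2708 rows. The helper batch_row_slices is hypothetical and only models "slice axis 0 into chunks of batch_size"; it is not Keras's actual data pipeline. The shape it predicts for the first adjacency slice, (32, 2708), matches the "Dimensions of A" reported in the error.

```python
def batch_row_slices(n_rows, batch_size):
    """Hypothetical helper: the (start, stop) row ranges of a loop over axis 0."""
    return [(start, min(start + batch_size, n_rows))
            for start in range(0, n_rows, batch_size)]

N = 2708  # number of nodes in Cora, as reported in the error message

default = batch_row_slices(N, 32)  # like predict() with its default batch_size=32
full = batch_row_slices(N, N)      # like predict(..., batch_size=N)

print(len(default), default[0], default[-1])  # 85 (0, 32) (2688, 2708)
print(full)                                   # [(0, 2708)]

# The adjacency slice in the first default batch keeps all N columns
# but only 32 rows, so it is no longer a square (N, N) matrix.
start, stop = default[0]
print((stop - start, N))                      # (32, 2708)
```

With batch_size=N there is exactly one slice covering every node, which is the only layout in which the graph convolution sees the whole graph.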

Explanation

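To understand why this is needed, consider what a GCN layer does under the hood: A @ X @ W. This is a matrix multiplication with shapes (N, N) x (N, F) x (F, F'). Note how the inner dimensions of the multiplication always match: N with N, and F with F.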
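A toy NumPy sketch of the shapes involved, assuming a made-up 4-node path graph (with self-loops), F=3 input features and F'=2 output units. It computes the product in both groupings; the traceback suggests Spektral computes X @ W first (layer1/MatMul feeding SparseTensorDenseMatMul), but either grouping gives the same (N, F') result.

```python
import numpy as np

# Toy graph (assumption): N=4 nodes on a path 0-1-2-3, F=3 features, F_out=2 units.
N, F, F_out = 4, 3, 2
A = np.array([[1, 1, 0, 0],
              [1, 1, 1, 0],
              [0, 1, 1, 1],
              [0, 0, 1, 1]], dtype=np.float32)        # adjacency with self-loops, (N, N)
X = np.arange(N * F, dtype=np.float32).reshape(N, F)  # node features, (N, F)
W = np.ones((F, F_out), dtype=np.float32)             # layer weights, (F, F_out)

# (N, N) x (N, F) -> (N, F), then (N, F) x (F, F_out) -> (N, F_out)
left_first = (A @ X) @ W
# (N, F) x (F, F_out) -> (N, F_out), then (N, N) x (N, F_out) -> (N, F_out)
right_first = A @ (X @ W)

print(left_first.shape)                      # (4, 2)
print(np.allclose(left_first, right_first))  # True
```

Row i of the result is a weighted sum over every row j of X @ W for which A[i, j] is non-zero, i.e. over the neighbours of node i. That is why each output row needs access to all N rows of X, not just the rows in the current batch.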

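Now, if you batch the inputs, the first dimension of A and X becomes B=32. That gives you a multiplication of shapes (B, N) x (B, F) x (F, F'). See how the inner dimensions of the first multiplication no longer match? That is the error TF raises. It tells you: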

Cannot multiply A and B because inner dimension does not match: 2708 vs. 32

Here, N=2708 and B=32.
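The same failure can be reproduced with plain NumPy, without Keras or Spektral. This sketch assumes, as the error message indicates, that predict feeds the layer the first B=32 rows of both A and X. It uses the real N=2708 and the 16 units of layer1, but a toy feature size F=8 and an identity matrix as a stand-in for the preprocessed adjacency (both assumptions). Following the traceback, it computes X @ W before the adjacency product.

```python
import numpy as np

N, B = 2708, 32   # nodes in Cora, default predict batch size
F, F_out = 8, 16  # F is a toy feature size (assumption); F_out=16 matches layer1

rng = np.random.default_rng(0)
A = np.eye(N, dtype=np.float32)                # stand-in for the preprocessed adjacency, (N, N)
X = rng.random((N, F), dtype=np.float32)       # node features, (N, F)
W = rng.random((F, F_out), dtype=np.float32)   # layer weights, (F, F_out)

# Full graph: (N, N) x (N, F_out) works.
full = A @ (X @ W)
print(full.shape)                # (2708, 16)

# What one default-sized batch looks like: the first B rows of A and X.
A_b, X_b = A[:B], X[:B]
XW_b = X_b @ W                   # (32, 16), the "Dimensions of B" in the TF message
print(A_b.shape, XW_b.shape)     # (32, 2708) (32, 16)

try:
    A_b @ XW_b                   # inner dimensions 2708 vs. 32 do not match
    error = None
except ValueError as exc:        # NumPy's counterpart of TF's InvalidArgumentError
    error = str(exc)
print(error is not None)         # True
```

NumPy reports the mismatch as a ValueError rather than an InvalidArgumentError, but the shapes involved are the same ones TF prints.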

Cheers
