簡體   English   中英

Tensorflow - Keras 斷開連接圖

[英]Tensorflow - Keras disconnected graph

Tensorflow 版本:2.x

蟒蛇:3.7.4

斷開連接圖:我試圖復制下面的模型架構,但是當我嘗試在 Keras 中繪制模型時,右側部分似乎斷開連接。 我已經將隱藏矩陣 HQ(For question) 和 HA(For answer) 作為輸入到注意力層(我們可以在下面總結中看到 Coattention 層的輸入 - 輸入形狀是 (512,600) 和 (512, 600) 和 Coattention矩陣 CQ 和 CA 的輸出形狀也相同)。 請幫助我理解這種斷開連接。 這需要更正還是可以忽略?

最終模型:

inputs = [input_text1, input_text2]
outputs = score_oq_oa
model = Model(inputs=inputs, outputs=outputs)

model.summary()

在此處輸入圖片說明

預期模型架構: 在此處輸入圖片說明

模型生成圖:為什么右側斷開連接? 請幫我理解。 我沒有在問答的雙向層之后使用連接層,但我只是將兩個雙向層的輸出矩陣作為輸入傳遞給注意層,如上所述。 在此處輸入圖片說明

問題更新為 Coattention 層的代碼如下:

這里 HQ 和 HA 是我們在模型架構中看到的兩個獨立雙向層的隱藏狀態矩陣/輸出。

class coattention(tf.keras.layers.Layer):

    def __init__(self):
        super(coattention, self).__init__()

    def call(self, HQ, HA):  

        L = tf.linalg.matmul(HA, HQ, transpose_a = True, transpose_b = False)
        AQ = tf.nn.softmax(L, axis = 1)
        AA = tf.nn.softmax(tf.transpose(L), axis = 1)

        CQ = tf.linalg.matmul(HA, AQ, transpose_a = False, transpose_b = False)
        CA = tf.linalg.matmul(HQ, AA, transpose_a = False, transpose_b = False)

        return CQ, CA


coattention_layer = coattention()
CQ, CA = coattention_layer(HQ, HA)
print ("Shape of Context vector of Question (CQ): ", CQ.shape)
print ("Shape of Context vector of Answer   (CA): ", CA.shape)

問題(CQ)的上下文向量的形狀:(512, 600)

Answer (CA) 的上下文向量的形狀:(512, 600)

因為你沒有提供代碼,我相信你忘記調用以 Bidirectional_7 層作為輸入的塗層。

這是示例代碼

Ha = Input(shape=(1,2,3), name='Ha')
Hq = Input(shape=(1,2,3), name='Hq')

your_coattention_layer = Dense(12, name='your_coattention_layer')

# this part that I think you forgot
Ca = your_coattention_layer(Ha)
cQ = your_coattention_layer(Hq)


out1 = Dense(123, name='your_Ca_layer')(Ca)
out2 = Dense(123, name='your_Cq_later')(cQ)
M = Model(inputs=[Ha,Hq], outputs=[out1,out2])
M.summary()

from keras.utils import plot_model
plot_model(M, to_file='Example.png')

這是模型圖。

在此處輸入圖片說明

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM