簡體   English   中英

函數式 API 中的 Keras Multiply() 層

[英]Keras Multiply() layer in functional API

在新的 API 變化下,如何在 Keras 中進行元素級的層乘法? 在舊的 API 下,我會嘗試這樣的事情:

merge([dense_all, dense_att], output_shape=10, mode='mul')

我試過這個(MWE):

from keras.models import Model
from keras.layers import Input, Dense, Multiply

def sample_model():
        model_in = Input(shape=(10,))
        dense_all = Dense(10,)(model_in)
        dense_att = Dense(10, activation='softmax')(model_in)
        att_mull = Multiply([dense_all, dense_att]) #merge([dense_all, dense_att], output_shape=10, mode='mul')
        model_out = Dense(10, activation="sigmoid")(att_mull)
        return 0

if __name__ == '__main__':
        sample_model()

完整跟蹤:

Using TensorFlow backend.
Traceback (most recent call last):
  File "testJan17.py", line 13, in <module>
    sample_model()
  File "testJan17.py", line 8, in sample_model
    att_mull = Multiply([dense_all, dense_att]) #merge([dense_all, dense_att], output_shape=10, mode='mul')
TypeError: __init__() takes exactly 1 argument (2 given)

編輯:

我嘗試實現 tensorflow 的元素乘法函數。 當然,結果不是一個Layer()實例,所以它不起作用。 這是對后代的嘗試:

def new_multiply(inputs): #assume two only - bad practice, but for illustration...
        return tf.multiply(inputs[0], inputs[1])


def sample_model():
        model_in = Input(shape=(10,))
        dense_all = Dense(10,)(model_in)
        dense_att = Dense(10, activation='softmax')(model_in) #which interactions are important?
        new_mult = new_multiply([dense_all, dense_att])
        model_out = Dense(10, activation="sigmoid")(new_mult)
        model = Model(inputs=model_in, outputs=model_out)
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        return model

使用keras > 2.0:

from keras.layers import multiply
output = multiply([dense_all, dense_att])

在函數式 API 下,您只需使用multiply函數,注意小寫的“m”。 如您所見,Multiply 類是一個層,旨在與順序 API 一起使用。

https://keras.io/layers/merge/#multiply_1中的更多信息

您需要在前面再添加一個左括號。

from keras.layers import Multiply
att_mull = Multiply()([dense_all, dense_att])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM