簡體   English   中英

Keras model 帶輸入乘法密集層

[英]Keras model with input multiply dense layer

嘗試創建一個簡單的 keras model,其中 model 的 output 是輸入乘以密集層元素。


inputs = tf.keras.Input(shape=256)

weightLayer = tf.keras.layers.Dense(256)
multipled = tf.keras.layers.Dot(axes=1)([inputs,weightLayer])
model = tf.keras.Model(inputs, multipled)

但是,這給了我“N.netype object 不可訂閱”錯誤。 我假設這是因為點層的輸入形狀面臨問題? 我該如何解決這個問題?

Dense層必須接收某種輸入:

import tensorflow as tf

inputs = tf.keras.layers.Input(shape=256)
weightLayer = tf.keras.layers.Dense(256)
multipled = tf.keras.layers.Dot(axes=1)([inputs, weightLayer(inputs)])
model = tf.keras.Model(inputs, multipled)

否則只需定義一個權重矩陣並將其與您的輸入元素相乘。 例如,通過使用自定義圖層:

import tensorflow as tf

class WeightedLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(WeightedLayer, self).__init__()
    self.num_outputs = num_outputs
    self.dot_layer = tf.keras.layers.Dot(axes=1)

  def build(self, input_shape):
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])

  def call(self, inputs):
    return self.dot_layer([inputs, self.kernel])


inputs = tf.keras.layers.Input(shape=256)
weighted_layer = WeightedLayer(256)
multipled = weighted_layer(inputs)
model = tf.keras.Model(inputs, multipled)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM