簡體   English   中英

將單位矩陣連接到每個向量

[英]Concatenate identity matrix to each vector

我想通過向輸入向量添加幾個不同的后綴來修改我的輸入。 例如,如果(單個)輸入是[1, 5, 9, 3]我想像這樣創建三個向量(存儲為矩陣):

[[1, 5, 9, 3, 1, 0, 0],
 [1, 5, 9, 3, 0, 1, 0],
 [1, 5, 9, 3, 0, 0, 1]]

當然,這只是一個觀察結果,因此在這種情況下(None, 4)模型的輸入是(None, 4) 簡單的方法是在其他地方准備輸入數據(最有可能是numpy)並相應地調整輸入的形狀。 我可以做,但我更喜歡在 TensorFlow/Keras 中做。

我已將此問題隔離到此代碼中:

import keras.backend as K
from keras import Input, Model
from keras.layers import Lambda


def build_model(dim_input: int, dim_eye: int):
    input = Input((dim_input,))
    concat = Lambda(lambda x: concat_eye(x, dim_input, dim_eye))(input)
    return Model(inputs=[input], outputs=[concat])


def concat_eye(x, dim_input, dim_eye):
    x = K.reshape(x, (-1, 1, dim_input))
    x = K.repeat_elements(x, dim_eye, axis=1)
    eye = K.expand_dims(K.eye(dim_eye), axis=0)
    eye = K.tile(eye, (-1, 1, 1))
    out = K.concatenate([x, eye], axis=2)
    return out


def main():
    import numpy as np

    n = 100
    dim_input = 20
    dim_eye = 3

    model = build_model(dim_input, dim_eye)
    model.compile(optimizer='sgd', loss='mean_squared_error')

    x_train = np.zeros((n, dim_input))
    y_train = np.zeros((n, dim_eye, dim_eye + dim_input))
    model.fit(x_train, y_train)


if __name__ == '__main__':
    main()

問題似乎出在tile函數中的shape參數-1中。 我試圖用1None替換它。 每個都有自己的錯誤:

  • -1model.fit期間model.fit

     tensorflow.python.framework.errors_impl.InvalidArgumentError: Expected multiples[0] >= 0, but got -1
  • 1 : 誤差model.fit

     tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [32,3,20] vs. shape[1] = [1,3,3]
  • None :在build_model期間build_model

     Failed to convert object of type <class 'tuple'> to Tensor. Contents: (None, 1, 1). Consider casting elements to a supported type.

您需要使用K.shape()來獲取輸入張量的符號形狀 這是因為批量大小為None ,因此將K.int_shape(x)[0]None-1作為K.tile()的第二個參數的K.tile()將不起作用:

eye = K.tile(eye, (K.shape(x)[0], 1, 1))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM