簡體   English   中英

圖像理解-CNN三重態損失

[英]Image understanding - CNN Triplet loss

我是NN的新手,正在嘗試創建一個用於圖像理解的簡單NN。

我嘗試使用三重態損失方法,但是不斷出現錯誤,使我覺得我缺少一些基本概念。

我的代碼是:

def triplet_loss(x):
  anchor, positive, negative = tf.split(x, 3)

  pos_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, positive)), 1)
  neg_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, negative)), 1)

  basic_loss = tf.add(tf.subtract(pos_dist, neg_dist), ALPHA)
  loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)

  return loss


def build_model(input_shape):
  K.set_image_data_format('channels_last')

  positive_example = Input(shape=input_shape)
  negative_example = Input(shape=input_shape)
  anchor_example = Input(shape=input_shape)

  embedding_network = create_embedding_network(input_shape)

  positive_embedding = embedding_network(positive_example)
  negative_embedding = embedding_network(negative_example)
  anchor_embedding = embedding_network(anchor_example)

  merged_output = concatenate([anchor_embedding, positive_embedding, negative_embedding])
  loss = Lambda(triplet_loss, (1,))(merged_output)

  model = Model(inputs=[anchor_example, positive_example, negative_example],
              outputs=loss)
  model.compile(loss='mean_absolute_error', optimizer=Adam())

  return model



def create_embedding_network(input_shape):
  input_shape = Input(input_shape)
  x = Conv2D(32, (3, 3))(input_shape)
  x = PReLU()(x)
  x = Conv2D(64, (3, 3))(x)
  x = PReLU()(x)

  x = Flatten()(x)
  x = Dense(10, activation='softmax')(x)
  model = Model(inputs=input_shape, outputs=x)
  return model

使用以下命令讀取每個圖像:

imageio.imread(imagePath, pilmode="RGB")

以及每個圖像的形狀:

(1024, 1024, 3)

然后我使用自己的三元組方法(僅創建3組錨點(正向和負向))

triplets = get_triplets(data)
triplets.shape

形狀為(示例數,三元組,x_image,y_image,通道數(RGB)):

(20, 3, 1024, 1024, 3)

然后我使用build_model函數:

model = build_model((1024, 1024, 3))

問題從這里開始:

model.fit(triplets, y=np.zeros(len(triplets)), batch_size=1)

對於這行代碼,當我嘗試訓練模型時,出現此錯誤:

錯誤

有關更多詳細信息,我的代碼在此便攜式筆記本中

我使用的圖片可以在此驅動器中找到,以便無縫運行-將該文件夾放置在

我的雲端硬盤/ Colab筆記本/圖像/

對於任何也掙扎的人

我的問題實際上是每個觀察的維度。 通過更改注釋中建議的尺寸

(?, 1024, 1024, 3)

解決方案已更新的colab筆記本

附言:我還將圖片的大小更改為256 * 256,以便代碼在我的PC上運行得更快。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM