[英]Tensorflow implementation of NT_Xent contrastive loss function?
[英]How to send 2 images into 1 network in Tensorflow and calculate contrastive loss?
我目前需要做與https://github.com/tensorflow/tensorflow/issues/6776類似的操作
例如,我有一批圖像A,B,C ..,並且我將為它們生成增強的批次,如a,b,c...。
然后,我將A(B或C)發送到Inception網絡中以獲取輸出張量“ output_1”,並且我需要將A(b或c)發送到同一Inception網絡中以獲取輸出張量“ output_2”,並且我將使用||“ output1- output2” || 作為對比損失。
目前,我不確定人們通常如何在Tensorflow中處理此類操作。 我在網上搜索,但沒有找到答案(盡管我想這與網絡的“重用”有關)。
這是我的源代碼的樣子(對不起,因為我可以在此處粘貼一個簡化的版本):
class MyModel:
......
def define_my_net(self):
self.inputs_from_bloader = tf.placeholder(...)
self.input = self.inputs_from_bloader
self.output = slim.conv2d(self.input,...)
......
def update(sess, inputs):
feed_dict = utility.build_feed_dict(self.inputs_from_bloader, inputs)
sess.run([my_op_list], feed_dict = feed_dict)
......
def train():
data = importlib.import_module('some.datasets.reader')
data = data.DataReader()
model = importlib.import_module('MyModel')
model.MyModel()
model.define_my_net() ### This is where network is defined
batch = data.get_batch() ### This is where A,B,C and a,b,c are generated.
model.update(sess, batch) ### This is where training is done
我想我可以從“ batch = data.get_batch”輸出類似AaBbCc的批處理,也可以將其更改為“ batch1,batch2 = data.get_batch”,但是我不知道如何將batch1和batch2傳遞到定義的網絡中,因為它可能需要對框架進行一些架構上的修改。
如果您認為上述源代碼過於混亂,那么任何簡單的示例也將不勝感激。
您可以實例化網絡兩次。 這些實例通常稱為“塔”。 兩個塔將使用相同的變量,但具有不同的輸入和操作。
根據您使用的高級API,您應該查找一些用於控制變量reuse
標志,以便在構建第二個塔時不會創建新變量。 例如,在此處https://www.tensorflow.org/guide/variables搜索reuse
。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.