[英]How tensorflow graph regularization (NSL) affects triplet semihard loss (TFA)
[英]tensorflow triplet_semihard_loss doesnt change after multiple epochs
我正在編寫一個訓練自定義人臉重新識別系統的基本版本(使用 mnist 數據作為構建塊,tensorflow 定義了半硬三元組損失函數),但損失 /acc 顯示在多個時期后絕對沒有變化。 下面的代碼
def kerasTriplet( label, pred ):
print('-------------------------')
print( label )
print( pred )
def lossFunc( y_true, y_pred ):
return tf.contrib.losses.metric_learning.triplet_semihard_loss( label, pred, 0.6 )
#return nonTFTripletLoss.batch_hard_triplet_loss( label, pred, 0.6 )
return lossFunc
def gen( trg, tgt ):
batch_sz = BATCH_SZ
start = np.random.randint( 0, len( trg ) - BATCH_SZ )
return trg[ start: start+batch_sz] , tgt[ start: start+batch_sz ]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
n_train, height, width = x_train.shape
x_train = x_train.reshape(n_train, height, width, 1).astype('float32')
x_train = x_train[ :(int(len(x_train)/BATCH_SZ))*BATCH_SZ ]
x_train /= 255
num_classes = 10
y_train_orig = y_train
y_train_orig = y_train_orig[ :(int(len(x_train)/BATCH_SZ))*BATCH_SZ ]
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
input_shape = (28, 28, 1)
sequence_input = tf.keras.layers.Input(shape=input_shape , dtype='float32')
batch_inp, batch_tgt = gen( x_train, y_train_orig )
x = tf.keras.layers.Conv2D( 512, (3,3), activation='relu')( batch_inp )
x = tf.keras.layers.Conv2D( 256, (3,3), activation='relu')( x )
x = tf.keras.layers.Conv2D( 128, (3,3), activation='relu')( x )
x = tf.keras.layers.Flatten()(x)
img_embedding = tf.keras.layers.Dense( 128 )(x)
## since triplet loss requires embedding to be l2 normalized
l2_embed = tf.keras.backend.l2_normalize( img_embedding, -1 )
model = tf.keras.models.Model( sequence_input , l2_embed )
model.compile( loss=kerasTriplet( batch_tgt, img_embedding ) , optimizer='adam', metrics=['acc'] )
model.fit(x_train, y_train_orig, batch_size=BATCH_SZ, epochs=10 , verbose=1)
我預計損失和 acc 會發生變化,即使幅度不大(因為我只運行了 10 個 epoch),但它完全一樣。 我確信這與我的代碼有關。 就是指不上它
您計算 l2_embedding 不正確。 嘗試這個
l2_embed = tf.keras.backend.l2_normalize(img_embedding,軸= 1)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.