簡體   English   中英

如何使用不可微的損失函數?

[英]How to use a loss function that is not differentiable?

我試圖在一個完全連接的神經網絡的輸出端找到一個密碼本,該神經網絡的選擇點應使所產生的密碼本之間的最小距離(歐幾里得范數)最大。 神經網絡的輸入是需要映射到輸出空間的更高維度的點。

例如,如果輸入維為2,輸出維為3,則以下映射(以及任何排列)最有效:00-000、01-011、10-101、11-110

import tensorflow as tf
import numpy as np
import itertools


input_bits = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='input_bits')
code_out = tf.placeholder(dtype=tf.float32, shape=[None, 3], name='code_out')
np.random.seed(1331)


def find_code(message):
    weight1 = np.random.normal(loc=0.0, scale=0.01, size=[2, 3])
    init1 = tf.constant_initializer(weight1)
    out = tf.layers.dense(inputs=message, units=3, activation=tf.nn.sigmoid, kernel_initializer=init1)
    return out


code = find_code(input_bits)

distances = []
for i in range(0, 3):
    for j in range(i+1, 3):
        distances.append(tf.linalg.norm(code_out[i]-code_out[j]))
min_dist = tf.reduce_min(distances)
# avg_dist = tf.reduce_mean(distances)

loss = -min_dist

opt = tf.train.AdamOptimizer().minimize(loss)

init_variables = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_variables)

saver = tf.train.Saver()

count = int(1e4)

for i in range(count):
    input_bit = [list(k) for k in itertools.product([0, 1], repeat=2)]
    code_preview = sess.run(code, feed_dict={input_bits: input_bit})
    sess.run(opt, feed_dict={input_bits: input_bit, code_out: code_preview})

由於損失函數本身不可微,所以我得到了錯誤

ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables 

我是在做傻事還是有辦法避免這種情況? 在這方面的任何幫助表示贊賞。 提前致謝。

您的損失函數在某些參數上必須是可微的。 在您的情況下,沒有參數,因此您將計算常數函數的導數為0。此外,在您的代碼中,您有以下幾行:

code = find_code(input_bits)

不再使用。 根據代碼,假設您要更改此行:

distances.append(tf.linalg.norm(code_out[i]-code_out[j]))

至:

distances.append(tf.linalg.norm(code[i]-code_out[j]))

因此,您將使用自己擁有的tf.layers.dense ,從而包含一個可用於計算損耗相對於該參數的梯度的參數。


此外,您不必擔心TF操作是否可微。 實際上,所有TF操作都是可區分的。 當涉及到tf.reduce_min() ,請檢查此鏈接

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM