繁体   English   中英

如何使用不可微的损失函数?

[英]How to use a loss function that is not differentiable?

我试图在一个完全连接的神经网络的输出端找到一个密码本,该神经网络的选择点应使所产生的密码本之间的最小距离(欧几里得范数)最大。 神经网络的输入是需要映射到输出空间的更高维度的点。

例如,如果输入维为2,输出维为3,则以下映射(以及任何排列)最有效:00-000、01-011、10-101、11-110

import tensorflow as tf
import numpy as np
import itertools


input_bits = tf.placeholder(dtype=tf.float32, shape=[None, 2], name='input_bits')
code_out = tf.placeholder(dtype=tf.float32, shape=[None, 3], name='code_out')
np.random.seed(1331)


def find_code(message):
    weight1 = np.random.normal(loc=0.0, scale=0.01, size=[2, 3])
    init1 = tf.constant_initializer(weight1)
    out = tf.layers.dense(inputs=message, units=3, activation=tf.nn.sigmoid, kernel_initializer=init1)
    return out


code = find_code(input_bits)

distances = []
for i in range(0, 3):
    for j in range(i+1, 3):
        distances.append(tf.linalg.norm(code_out[i]-code_out[j]))
min_dist = tf.reduce_min(distances)
# avg_dist = tf.reduce_mean(distances)

loss = -min_dist

opt = tf.train.AdamOptimizer().minimize(loss)

init_variables = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_variables)

saver = tf.train.Saver()

count = int(1e4)

for i in range(count):
    input_bit = [list(k) for k in itertools.product([0, 1], repeat=2)]
    code_preview = sess.run(code, feed_dict={input_bits: input_bit})
    sess.run(opt, feed_dict={input_bits: input_bit, code_out: code_preview})

由于损失函数本身不可微,所以我得到了错误

ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables 

我是在做傻事还是有办法避免这种情况? 在这方面的任何帮助表示赞赏。 提前致谢。

您的损失函数在某些参数上必须是可微的。 在您的情况下,没有参数,因此您将计算常数函数的导数为0。此外,在您的代码中,您有以下几行:

code = find_code(input_bits)

不再使用。 根据代码,假设您要更改此行:

distances.append(tf.linalg.norm(code_out[i]-code_out[j]))

至:

distances.append(tf.linalg.norm(code[i]-code_out[j]))

因此,您将使用自己拥有的tf.layers.dense ,从而包含一个可用于计算损耗相对于该参数的梯度的参数。


此外,您不必担心TF操作是否可微。 实际上,所有TF操作都是可区分的。 当涉及到tf.reduce_min() ,请检查此链接

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM