Custom Keras loss function using the discriminator of a GAN neural network
I load the discriminator of a GAN neural network with the following code:
import tensorflow as tf
import numpy as np
from IPython.display import display, Audio
tf.reset_default_graph()
saver = tf.train.import_meta_graph('./infer/infer.meta')
graph = tf.get_default_graph()
sess = tf.InteractiveSession()
saver.restore(sess, tf.train.latest_checkpoint('model/'))
# The variable below is "z" with a leading underscore; it does not render correctly on Stack Overflow.
# Random data is used here to test the discriminator.
_z = np.random.uniform(-1., 1., size=[5, 257])
x = graph.get_tensor_by_name('x:0')
D_z = graph.get_tensor_by_name('D_z:0')
D_z = sess.run(D_z, {x: _z})
print(D_z)
I want to wrap this in a function that serves as a custom Keras loss function:
# Load the graph
tf.reset_default_graph()
saver = tf.train.import_meta_graph('./infer/infer.meta')
graph = tf.get_default_graph()
sess = tf.InteractiveSession()
saver.restore(sess, tf.train.latest_checkpoint('model/'))
def gan_loss(y_true, y_pred):
    _z = y_pred
    x = graph.get_tensor_by_name('x:0')
    D_z = graph.get_tensor_by_name('D_z:0')
    D_z = sess.run(D_z, {x: _z})
    return D_z
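I get an error saying that a tensor cannot be fed and that the feed must be a NumPy array or another supported type:
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, or numpy ndarrays.
I found a related question on Stack Overflow: Converting Tensor to np.array using K.eval() in Keras returns InvalidArgumentError
For reference, the discriminator graph was exported with this code: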
X = tf.placeholder(tf.float32, [None, 257], name='x')
D_z, h3 = discriminator(X)
D_z = tf.identity(D_z, name='D_z')
D_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='GAN/Discriminator')
# global_step = tf.train.get_or_create_global_step()
saver = tf.train.Saver(D_vars)
infer_dir = './infer/'
tf.train.write_graph(tf.get_default_graph(), infer_dir, 'infer.pbtxt')
infer_metagraph_fp = os.path.join(infer_dir, 'infer.meta')
tf.train.export_meta_graph(
    filename=infer_metagraph_fp,
    clear_devices=True,
    saver_def=saver.as_saver_def())
tf.reset_default_graph()
I was able to reproduce your error with the simple code below, in which I pass a tensor to feed_dict.
Code to reproduce the error -
%tensorflow_version 1.x
import tensorflow as tf
print(tf.__version__)
import numpy as np
x = tf.placeholder(tf.float32)
y = x * 42
with tf.Session() as sess:
    a = tf.constant(2)
    train_accuracy = y.eval(session=sess,feed_dict={x: a})
    print(train_accuracy)
Output -
1.15.2
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-18-44556cb4551b> in <module>()
10 with tf.Session() as sess:
11 a = tf.constant(2)
---> 12 train_accuracy = y.eval(session=sess,feed_dict={x: a})
13 print(train_accuracy)
3 frames
/tensorflow-1.15.2/python3.6/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1129 'For reference, the tensor object was ' +
1130 str(feed_val) + ' which was passed to the '
-> 1131 'feed with key ' + str(feed) + '.')
1132
1133 subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles. For reference, the tensor object was Tensor("Const_6:0", shape=(), dtype=int32) which was passed to the feed with key Tensor("Placeholder_9:0", dtype=float32).
I was able to fix it by converting the tensor to a NumPy type before passing it to feed_dict. So in your case, convert y_pred to a NumPy type.
Fixed code -
%tensorflow_version 1.x
import tensorflow as tf
print(tf.__version__)
import numpy as np
x = tf.placeholder(tf.float32)
y = x * 42
with tf.Session() as sess:
    a = tf.constant(2)
    # Evaluate the tensor to a NumPy value before feeding it.
    b = np.array(a.eval())
    train_accuracy = y.eval(session=sess, feed_dict={x: b})
    print(train_accuracy)
Output -
1.15.2
84.0
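One caveat for the original question: inside a Keras loss function, y_pred is a symbolic tensor that depends on the model inputs, so it generally cannot be evaluated to NumPy while the loss is being built, and feeding a NumPy copy would also disconnect the discriminator output from y_pred for gradient purposes. A possible alternative is to skip feed_dict and remap the placeholder when importing the meta graph, using the input_map argument of tf.train.import_meta_graph. The sketch below is only an illustration under assumptions: the one-layer "discriminator", its weights, the 3-feature input, and the temporary checkpoint path are all hypothetical stand-ins; only the tensor names 'x' and 'D_z' are taken from the question.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF 1.x graph-mode API

tf.disable_eager_execution()

# --- Stand-in for the training script: export a toy "discriminator". ---
# The real discriminator is not shown in the question; this one-layer
# graph is hypothetical and only reuses the tensor names 'x' and 'D_z'.
export_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(export_dir, "toy.ckpt")

tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 3], name="x")
w = tf.Variable([[1.0], [2.0], [3.0]], name="w")
D_z = tf.identity(tf.matmul(x, w), name="D_z")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_path)  # also writes toy.ckpt.meta

# --- Loss side: wire an existing tensor into the imported graph. ---
tf.reset_default_graph()
# Stand-in for y_pred; in a Keras loss this would be the symbolic argument.
y_pred = tf.constant([[1.0, 1.0, 1.0]], name="y_pred")
# input_map replaces the placeholder 'x:0' with y_pred during import.
saver = tf.train.import_meta_graph(
    ckpt_path + ".meta", input_map={"x:0": y_pred})
D_z = tf.get_default_graph().get_tensor_by_name("D_z:0")

with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    result = sess.run(D_z)  # no feed_dict needed
print(result)  # [[6.]]
```

In a real Keras setup the import would presumably have to happen in the graph and session that Keras itself uses, and the restored discriminator weights would need to survive any later variable initialization. I have not tested this against the model in the question.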
Hope this answers your question. Happy learning.