
Input 'y' of 'Mul' Op has type float32 that does not match type int32 of argument 'x'

This code works when I run it on Linux, but not on Windows. My Python version on Windows is 3.5.

    with graph.as_default():
        train_inputs = tf.placeholder(tf.int32, shape=[batch_size])
        train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])
        valid_dataset = tf.constant(valid_examples, dtype=tf.int32)

        with tf.device('/cpu:0'):
            embeddings = tf.Variable(
                tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0))
            embed = tf.nn.embedding_lookup(embeddings, train_inputs)

            nce_weights = tf.Variable(
                tf.truncated_normal([vocabulary_size, embedding_size],
                                    stddev=1.0 / math.sqrt(embedding_size)))
            nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

        loss = tf.reduce_mean(
            tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,
                           num_sampled, vocabulary_size))

Newer versions of TensorFlow changed the parameter order of nce_loss: labels now comes before inputs.

Try changing the call to:

    tf.nn.nce_loss(nce_weights, nce_biases, train_labels, embed, num_sampled, vocabulary_size)
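The effect of that parameter swap can be shown without TensorFlow. Below is a pure-Python sketch. It uses two hypothetical stand-in functions, `nce_loss_old` and `nce_loss_new`, whose parameter lists only mimic the old (`inputs, labels`) and new (`labels, inputs`) orderings. They are not the real TensorFlow API. `inspect.signature(...).bind` simply reports which value each parameter would receive.

```python
import inspect


# Hypothetical stand-ins, NOT the TensorFlow API: only the parameter
# order mirrors the old and new nce_loss signatures; bodies do nothing.
def nce_loss_old(weights, biases, inputs, labels, num_sampled, num_classes):
    pass


def nce_loss_new(weights, biases, labels, inputs, num_sampled, num_classes):
    pass


def binding(func, *args, **kwargs):
    """Map each parameter name of func to the value a call would pass it."""
    return dict(inspect.signature(func).bind(*args, **kwargs).arguments)


# Strings stand in for the question's tensors; 64 and 50000 are placeholder
# values for num_sampled and vocabulary_size.
question_call = ("nce_weights", "nce_biases", "embed", "train_labels", 64, 50000)
fixed_call = ("nce_weights", "nce_biases", "train_labels", "embed", 64, 50000)

old = binding(nce_loss_old, *question_call)
new = binding(nce_loss_new, *question_call)
fixed = binding(nce_loss_new, *fixed_call)

print(old["inputs"])    # embed
print(new["inputs"])    # train_labels (the int32 labels become the inputs)
print(fixed["inputs"])  # embed
```

Under the new ordering, the question's positional call passes the int32 train_labels where float32 inputs are expected. This is consistent with the 'Mul' dtype error in the title.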

I ran into this error as well, with code very similar to yours. When I ran it on FloydHub using env=tensorflow (TensorFlow 1.1.0 + Keras 2.0.4 on Python 3), it threw the above error.

However, it ran fine after I changed the environment to tensorflow-1.0 (TensorFlow 1.0.0 + Keras 1.2.2 on Python 3).

You need to convert train_labels to float32. (You already mentioned that train_labels is of type int32 and embed is of type float32.)

This is how to convert an int32 tensor to float32:

    tf.cast(train_labels, tf.float32)

Then calculate the loss.

I ran into the same problem, but with a different loss function. You are passing the arguments positionally; pass each one by its parameter name and the error will go away. For example:

    # Keyword arguments bind to the right parameters regardless of their order.
    loss = tf.reduce_mean(
        tf.nn.nce_loss(weights=nce_weights, biases=nce_biases,
                       inputs=embed, labels=train_labels,
                       num_sampled=num_sampled,
                       num_classes=vocabulary_size))
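Naming the arguments works because Python binds keyword arguments by name, not by position. If you wrap such a call in your own helper, you can make Python enforce this by declaring the parameters keyword-only with a bare `*`. The sketch below uses a hypothetical wrapper, `nce_loss_kw`, that just returns its arguments. In real code it would forward them to `tf.nn.nce_loss` by name.

```python
def nce_loss_kw(*, weights, biases, labels, inputs, num_sampled, num_classes):
    """Hypothetical wrapper: every parameter after the bare * is keyword-only.

    A real version would call tf.nn.nce_loss(weights=weights, ...); this
    sketch only returns the arguments so the binding can be inspected.
    """
    return {"weights": weights, "biases": biases, "labels": labels,
            "inputs": inputs, "num_sampled": num_sampled,
            "num_classes": num_classes}


# Keyword order does not matter: each value goes to the parameter it names.
# Strings stand in for tensors; 64 and 50000 are placeholder values.
args = nce_loss_kw(inputs="embed", labels="train_labels",
                   weights="nce_weights", biases="nce_biases",
                   num_sampled=64, num_classes=50000)
print(args["inputs"], args["labels"])  # embed train_labels

# A positional call is rejected instead of silently swapping tensors.
try:
    nce_loss_kw("nce_weights", "nce_biases", "embed", "train_labels", 64, 50000)
except TypeError as err:
    print("TypeError:", err)
```

The trade-off is a slightly more verbose call site. In return, a future reordering of the underlying signature cannot silently swap labels and inputs.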
