InvalidArgumentError: cannot compute Sub as input #1(zero-based) was expected to be a uint8 tensor but is a float tensor [Op:Sub]

Question

Please help me understand the cause of this error and how to resolve it.

Code

import tensorflow as tf
import numpy as np

fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_full = np.concatenate((x_train, x_test), axis=0)

layer = tf.keras.layers.experimental.preprocessing.Normalization()
layer.adapt(x_full)
layer(x_train)

Error

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-16-699c47b6db55> in <module>
----> 1 ds = layer(x_train)

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
    966           with base_layer_utils.autocast_context_manager(
    967               self._compute_dtype):
--> 968             outputs = self.call(cast_inputs, *args, **kwargs)
    969           self._handle_activity_regularization(inputs, outputs)
    970           self._set_mask_metadata(inputs, outputs, input_masks)

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/keras/layers/preprocessing/normalization.py in call(self, inputs)
    109     mean = array_ops.reshape(self.mean, self._broadcast_shape)
    110     variance = array_ops.reshape(self.variance, self._broadcast_shape)
--> 111     return (inputs - mean) / math_ops.sqrt(variance)
    112 
    113   def compute_output_shape(self, input_shape):

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
    982     with ops.name_scope(None, op_name, [x, y]) as name:
    983       if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 984         return func(x, y, name=name)
    985       elif not isinstance(y, sparse_tensor.SparseTensor):
    986         try:

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py in sub(x, y, name)
  10098         pass  # Add nodes to the TensorFlow graph.
  10099     except _core._NotOkStatusException as e:
> 10100       _ops.raise_from_not_ok_status(e, name)
  10101   # Add nodes to the TensorFlow graph.
  10102   _, _, _op, _outputs = _op_def_library._apply_op_helper(

~/conda/envs/tensorflow/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6651   message = e.message + (" name: " + name if name is not None else "")
   6652   # pylint: disable=protected-access
-> 6653   six.raise_from(core._status_to_exception(e.code, message), None)
   6654   # pylint: enable=protected-access
   6655 

~/conda/envs/tensorflow/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: cannot compute Sub as input #1(zero-based) was expected to be a uint8 tensor but is a float tensor [Op:Sub]

Attempts

I tried passing the dtype argument, but got the same error.

layer = tf.keras.layers.experimental.preprocessing.Normalization(dtype='float32')

Dividing by 1.0 fixed the issue, but I am not sure what the original cause was.

x_full = np.concatenate((x_train, x_test), axis=0) / 1.0
x_train = x_train / 1.0

Does Keras only work with float32?

Related issues

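The cause is that preprocessing.Normalization stores its mean and variance as float32, but your data is uint8. As the traceback shows, the layer computes (inputs - mean) / sqrt(variance), so the subtraction receives a uint8 tensor and a float32 tensor, and TensorFlow refuses to mix the two dtypes in one op. Hence the error.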
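The mismatch can be reproduced without the layer. The following is a minimal sketch, assuming TensorFlow 2.x in eager mode; the tensors `pixels` and `mean` and their values are made-up stand-ins for the image data and the layer's stored mean, not values taken from the question.

```python
import tensorflow as tf

# Hypothetical stand-ins: uint8 "pixels" and a float32 "mean" (illustrative values).
pixels = tf.constant([0, 128, 255], dtype=tf.uint8)
mean = tf.constant([100.0, 100.0, 100.0], dtype=tf.float32)

try:
    pixels - mean  # uint8 - float32: TensorFlow does not promote dtypes here
    mismatch_failed = False
except (tf.errors.InvalidArgumentError, TypeError) as err:
    # Eager mode raises InvalidArgumentError; TypeError is caught as a precaution.
    mismatch_failed = True
    print(type(err).__name__)

# Casting the integer tensor first makes both operands float32.
centered = tf.cast(pixels, tf.float32) - mean
print(centered.dtype)
```

The explicit tf.cast is what the layer was not doing for integer inputs in the version shown in the traceback.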

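This behavior comes from TensorFlow rather than from Keras itself: TensorFlow ops require both operands to have the same dtype and do not convert them implicitly, a design choice that keeps computation fast by avoiding hidden conversions.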

Reminder: floating-point and integer arithmetic run on different units in a processor, and each processor performs differently on different data types. For example, some NVIDIA GPUs are faster with float32 than with float16, while some ARM CPUs are faster with float16.

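PyTorch has similar restrictions: older versions required both operands to have the same dtype, and although recent versions promote types automatically for basic arithmetic, some operations still expect matching dtypes.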

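Dividing an integer array by a float in NumPy automatically gives you a new float array, so x_train = x_train / 1.0 makes x_train float64. Keras then casts floating-point inputs to the layer's own dtype (float32 here; the default comes from floatx in ~/.keras/keras.json), whereas integer inputs such as uint8 are not cast automatically. That is why dividing by 1.0 makes the error go away.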
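The dtype change can be checked with NumPy alone. This is a small sketch using synthetic uint8 data in place of Fashion-MNIST (the array shape and random values are assumptions for illustration); it also shows an explicit cast with astype, which states the intended dtype directly instead of relying on division.

```python
import numpy as np

# Synthetic stand-in for Fashion-MNIST images: uint8 pixels in [0, 255].
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

divided = x_train / 1.0              # true division promotes uint8 to float64
casted = x_train.astype(np.float32)  # explicit cast picks the dtype directly

print(x_train.dtype, divided.dtype, casted.dtype)
```

Passing an array like `casted` to both layer.adapt and the layer call should avoid the mismatch in the same way as dividing by 1.0, without creating an intermediate float64 copy.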
