简体   繁体   中英

Write a custom layer with trainable parameters in TensorFlow

I want to write a simple custom exponentiation layer in TensorFlow 2. It should take n inputs [x_1, ..., x_n] and output the powers [x_1^e_1, ..., x_n^e_n], where e_1, ..., e_n are trainable parameters.

For example, this exponentiation layer together with a Dense layer (with output dimension 1) could learn any function of the form a_1 x_1^e_1 + ... + a_n x_n^e_n, which is a simple extension of ordinary linear regression.

However, I have had no luck getting it to work. So far, I have written the following:

import tensorflow as tf

class Exponent(tf.keras.layers.Layer):
    """Layer that raises its inputs to trainable exponents, element-wise."""

    def __init__(self):
        super().__init__()
        
    def build(self, input_shape):
        # NOTE: the weight is created with the full `input_shape`. When Keras
        # builds the layer inside a model, that shape carries a leading batch
        # dimension of None (e.g. (None, 3)), which the constant initializer
        # cannot convert to a tensor -- the ValueError reported below.
        self.exp = self.add_weight(name='Exponent', shape=input_shape, initializer=tf.constant_initializer(value=1.), trainable=True)
        super().build(input_shape)
        
    def call(self, inputs, training=False):
        # `training` is accepted but not used: the same power is applied
        # in every mode.
        return tf.math.pow(inputs, self.exp)

I can instantiate an object of this class, and everything works fine.

# Build the layer by hand with a shape that has no batch dimension, so the
# exponent weight is created with the fully defined shape (2,).
e = Exponent()
e.build(input_shape=(2,))
e([[1., 2.]]) # works fine

If I try to embed it in a model, it throws an error when trying to predict (or also fit):

# Exponent layer followed by a Dense layer with a single output unit.
model = tf.keras.Sequential([
    Exponent(),
    tf.keras.layers.Dense(1)
])

model.compile(
    loss='mse',
    optimizer='sgd'
)

# Keras builds the layers on this first call; Exponent.build is where the
# ValueError in the traceback that follows is raised.
model.predict([[1., 2., 3.]])
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-401-2ea9d74ffa37> in <module>
      9 )
     10 
---> 11 model.predict([[1., 2., 3.]])

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
    128       raise ValueError('{} is not supported in multi-worker mode.'.format(
    129           method.__name__))
--> 130     return method(self, *args, **kwargs)
    131 
    132   return tf_decorator.make_decorator(

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1597           for step in data_handler.steps():
   1598             callbacks.on_predict_batch_begin(step)
-> 1599             tmp_batch_outputs = predict_function(iterator)
   1600             if data_handler.should_sync:
   1601               context.async_wait()

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    821       # This is the first call of __call__, so we have to initialize.
    822       initializers = []
--> 823       self._initialize(args, kwds, add_initializers_to=initializers)
    824     finally:
    825       # At this point we know that the initialization is complete (or less

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in _initialize(self, args, kwds, add_initializers_to)
    695     self._concrete_stateful_fn = (
    696         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 697             *args, **kwds))
    698 
    699     def invalid_creator_scope(*unused_args, **unused_kwds):

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2853       args, kwargs = None, None
   2854     with self._lock:
-> 2855       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2856     return graph_function
   2857 

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _maybe_define_function(self, args, kwargs)
   3211 
   3212       self._function_cache.missed.add(call_context_key)
-> 3213       graph_function = self._create_graph_function(args, kwargs)
   3214       self._function_cache.primary[cache_key] = graph_function
   3215       return graph_function, args, kwargs

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3073             arg_names=arg_names,
   3074             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075             capture_by_value=self._capture_by_value),
   3076         self._function_attributes,
   3077         function_spec=self.function_spec,

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984         _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py in wrapped_fn(*args, **kwds)
    598         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    599         # the function a weak reference to itself to avoid a reference cycle.
--> 600         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    601     weak_wrapped_fn = weakref.ref(wrapped_fn)
    602 

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\func_graph.py in wrapper(*args, **kwargs)
    971           except Exception as e:  # pylint:disable=broad-except
    972             if hasattr(e, "ag_error_metadata"):
--> 973               raise e.ag_error_metadata.to_exception(e)
    974             else:
    975               raise

ValueError: in user code:

    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:1462 predict_function  *
        return step_function(self, iterator)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:1452 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:1211 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:2585 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:2945 _call_for_each_replica
        return fn(*args, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:1445 run_step  **
        outputs = model.predict_step(data)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\training.py:1418 predict_step
        return self(x, training=False)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\base_layer.py:985 __call__
        outputs = call_fn(inputs, *args, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\sequential.py:386 call
        outputs = layer(inputs, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\base_layer.py:982 __call__
        self._maybe_build(inputs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\base_layer.py:2643 _maybe_build
        self.build(input_shapes)  # pylint:disable=not-callable
    <ipython-input-357-5a69bae7457e>:8 build
        self.exp = self.add_weight(name='Exponent', shape=input_shape, initializer=tf.constant_initializer(value=2.), trainable=True)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\base_layer.py:614 add_weight
        caching_device=caching_device)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\training\tracking\base.py:750 _add_variable_with_custom_getter
        **kwargs_for_getter)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\keras\engine\base_layer_utils.py:145 make_variable
        shape=variable_shape if variable_shape else None)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py:260 __call__
        return cls._variable_v1_call(*args, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py:221 _variable_v1_call
        shape=shape)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:2857 creator
        return next_creator(**kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:2857 creator
        return next_creator(**kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\distribute\distribute_lib.py:2857 creator
        return next_creator(**kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py:685 variable_capturing_scope
        lifted_initializer_graph=lifted_initializer_graph, **kwds)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\eager\def_function.py:226 __init__
        initial_value() if init_from_fn else initial_value,
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\ops\init_ops_v2.py:263 __call__
        self.value, dtype=dtype, shape=shape)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\constant_op.py:264 constant
        allow_broadcast=True)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\constant_op.py:275 _constant_impl
        return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\constant_op.py:321 _constant_eager_impl
        return _eager_fill(shape.as_list(), t, ctx)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\constant_op.py:54 _eager_fill
        dims = convert_to_eager_tensor(dims, ctx, dtypes.int32)
    C:\Users\robkuble\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\constant_op.py:98 convert_to_eager_tensor
        return ops.EagerTensor(value, ctx.device_name, dtype)

    ValueError: Attempt to convert a value (None) with an unsupported type (<class 'NoneType'>) to a Tensor.

Do you see why a NoneType is appearing somewhere? Thank you very much!

Best Robert

The first dimension in the input_shape argument of the build method corresponds to the batch size, which is typically None (this means that the network can work with batches of any size). For example, when calling model.predict([[1., 2., 3.]]), the input_shape will be (None, 3).

This means you need to change your implementation of the build method slightly so that it uses only the last dimension, input_shape[-1], instead of the whole input_shape:

...

  def build(self, input_shape):
    """Create one trainable exponent per input feature.

    Args:
      input_shape: Shape of the layer input. Its first entry is the batch
        size (None when Keras builds the layer inside a model), so only
        the last entry -- the number of features -- is used for the weight.
    """
    # Pass the shape as a one-element tuple rather than a bare int:
    # `add_weight` expects a shape tuple, and a plain integer is not
    # guaranteed to be accepted by every Keras version.
    self.exp = self.add_weight(
        name='Exponent',
        shape=(input_shape[-1],),
        initializer=tf.constant_initializer(value=1.),
        trainable=True,
    )
    super().build(input_shape)

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM