Using a custom step activation function in Keras results in “'tuple' object has no attribute '_keras_shape'” error. How to resolve this?

I'm trying to implement a binary custom activation function in the output layer of a Keras model.

This is my attempt:

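import tensorflow as tf
from keras.backend import switch  # assumed import; not shown in the original question

def binary_activation(x):
    ones = tf.ones(tf.shape(x), dtype=x.dtype.base_dtype)
    zeros = tf.zeros(tf.shape(x), dtype=x.dtype.base_dtype)
    def grad(dy):
        return dy
    return switch(x > 0.5, ones, zeros), grad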

This is similar to the approach here, but I get the following error:

  File "/usr/local/lib/python3.6/dist-packages/spyder_kernels/customize/spydercustomize.py", line 786, in runfile
    execfile(filename, namespace)
  File "/usr/local/lib/python3.6/dist-packages/spyder_kernels/customize/spydercustomize.py", line 110, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "/home/marlon/Área de Trabalho/omj_project/predicting_change.py", line 85, in <module>
    model = baseline_model()
  File "/home/marlon/Área de Trabalho/omj_project/predicting_change.py", line 80, in baseline_model
    model.add(Dense(1, activation=binary_activation))
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py", line 181, in add
    output_tensor = layer(self.outputs[0])
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 497, in __call__
    arguments=user_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/keras/engine/base_layer.py", line 565, in _add_inbound_node
    output_tensors[i]._keras_shape = output_shapes[i]
AttributeError: 'tuple' object has no attribute '_keras_shape'
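The last traceback frame points at the cause: Keras tries to set _keras_shape on the activation's output, and that output is a tuple. A minimal sketch, assuming TensorFlow 2.x in eager mode, of what the undecorated function actually returns. binary_activation_undecorated is a hypothetical name used only here, and tf.where stands in for the question's switch (whose import is not shown) so the snippet is self-contained.

```python
import tensorflow as tf

# Same logic as the question's function, with no decorator.
# tf.where replaces `switch` only to keep this sketch self-contained;
# both select elementwise between the two tensors.
def binary_activation_undecorated(x):
    ones = tf.ones_like(x)
    zeros = tf.zeros_like(x)

    def grad(dy):
        return dy

    return tf.where(x > 0.5, ones, zeros), grad

out = binary_activation_undecorated(tf.constant([0.2, 0.7]))
# Without @tf.custom_gradient the "activation output" is a plain Python
# tuple (tensor, function), which is what Keras then tries to annotate.
print(type(out))  # <class 'tuple'>
```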

Thanks for any help.

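You need to add the @tf.custom_gradient decorator on top of your function, as in the other comment you mentioned. Without it, binary_activation returns the plain Python tuple (res, grad), Keras treats that tuple as the layer output, and setting _keras_shape on it fails. With the decorator, TensorFlow uses res as the output and grad as its gradient function: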

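import tensorflow as tf

@tf.custom_gradient
def binary_activation(x):
    ones = tf.ones(tf.shape(x), dtype=x.dtype.base_dtype)
    zeros = tf.zeros(tf.shape(x), dtype=x.dtype.base_dtype)
    # Forward pass: 1 where x > 0.5, otherwise 0
    res = tf.keras.backend.switch(x > 0.5, ones, zeros)
    def grad(dy):
        # Backward pass: pass the upstream gradient through unchanged
        return dy
    return res, grad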
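A minimal sketch, assuming TensorFlow 2.x with tf.keras, that checks the decorated function behaves as intended: the forward pass is a hard 0/1 step, while the backward pass hands the upstream gradient through unchanged. tf.where, tf.ones_like and tf.zeros_like are swapped in for tf.keras.backend.switch and tf.ones/tf.zeros(tf.shape(x)); they are not what the answer used, but compute the same values.

```python
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def binary_activation(x):
    # Forward pass: hard threshold at 0.5.
    res = tf.where(x > 0.5, tf.ones_like(x), tf.zeros_like(x))

    def grad(dy):
        # Backward pass: return the upstream gradient as-is
        # (identity / straight-through gradient).
        return dy

    return res, grad

# 1) Forward values and gradients on a plain tensor.
x = tf.constant([-1.0, 0.3, 0.6, 2.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = binary_activation(x)
    loss = tf.reduce_sum(y)
dx = tape.gradient(loss, x)
print(y.numpy())   # [0. 0. 1. 1.]
print(dx.numpy())  # [1. 1. 1. 1.]

# 2) The same function as a Dense activation, as in the question.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(1, activation=binary_activation),
])
preds = model(np.random.rand(4, 3).astype("float32"))
print(preds.shape)  # (4, 1)
```

The identity backward pass is the answer's choice rather than the only option: the true derivative of a step function is zero almost everywhere, so passing dy through is what lets the Dense weights below the activation still receive gradients during training.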

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM