简体   繁体   中英

How can I implement KL-divergence regularization for Keras?

This is a follow-up question for this question Keras backend mean function: " 'float' object has no attribute 'dtype' "?

I am trying to make a new regularizer for Keras. Here is my code

import keras
from keras import initializers
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Activation
from keras import regularizers
from keras import optimizers
from keras import backend as K

kullback_leibler_divergence = keras.losses.kullback_leibler_divergence

def kl_divergence_regularizer(inputs):
        means = K.mean((inputs))
        rho=0.05
        down = 0.05 * K.ones_like(means)
        up = (1 - 0.05) * K.ones_like(means)
        return 0.5 *(0.01 * (kullback_leibler_divergence(down, means)
                      + kullback_leibler_divergence(up, 1 - means)))

model = Sequential([
    Dense(900, input_shape=(x_train_s.shape[1],),kernel_initializer='random_uniform',kernel_regularizer=kl_divergence_regularizer),
    Activation('elu'),
    Dense(x_train_s.shape[1],kernel_initializer='random_uniform'),
    Activation('tanh')
])

model.compile(optimizer='adam',loss='mean_squared_error')

model.fit(x_train_s, y_train_s, epochs=5)

Here is the error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1658   try:
-> 1659     c_op = c_api.TF_FinishOperation(op_desc)
   1660   except errors.InvalidArgumentError as e:

InvalidArgumentError: Invalid reduction dimension -1 for input with 0 dimensions. for 'dense_3/weight_regularizer/Sum' (op: 'Sum') with input shapes: [], [] and with computed input tensors: input[1] = <-1>.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-4-9f4dfbe34659> in <module>
     39     Activation('elu'),
     40     Dense(x_train_s.shape[1],kernel_initializer='random_uniform'),
---> 41     Activation('tanh')
     42 ])
     43 

C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\sequential.py in __init__(self, layers, name)
     91         if layers:
     92             for layer in layers:
---> 93                 self.add(layer)
     94 
     95     @property

C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\sequential.py in add(self, layer)
    163                     # and create the node connecting the current layer
    164                     # to the input layer we just created.
--> 165                     layer(x)
    166                     set_inputs = True
    167             else:

C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
    429                                          'You can build it manually via: '
    430                                          '`layer.build(batch_input_shape)`')
--> 431                 self.build(unpack_singleton(input_shapes))
    432                 self.built = True
    433 

C:\ProgramData\Anaconda3\lib\site-packages\keras\layers\core.py in build(self, input_shape)
    864                                       name='kernel',
    865                                       regularizer=self.kernel_regularizer,
--> 866                                       constraint=self.kernel_constraint)
    867         if self.use_bias:
    868             self.bias = self.add_weight(shape=(self.units,),

C:\ProgramData\Anaconda3\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\base_layer.py in add_weight(self, name, shape, dtype, initializer, regularizer, trainable, constraint)
    253         if regularizer is not None:
    254             with K.name_scope('weight_regularizer'):
--> 255                 self.add_loss(regularizer(weight))
    256         if trainable:
    257             self._trainable_weights.append(weight)

<ipython-input-4-9f4dfbe34659> in kl_divergence_regularizer(inputs)
     15         down = 0.05 * K.ones_like(means)
     16         up = (1 - 0.05) * K.ones_like(means)
---> 17         return 0.5 *(0.01 * (kullback_leibler_divergence(down, means)
     18                       + kullback_leibler_divergence(up, 1 - means)))
     19 

C:\ProgramData\Anaconda3\lib\site-packages\keras\losses.py in kullback_leibler_divergence(y_true, y_pred)
     81     y_true = K.clip(y_true, K.epsilon(), 1)
     82     y_pred = K.clip(y_pred, K.epsilon(), 1)
---> 83     return K.sum(y_true * K.log(y_true / y_pred), axis=-1)
     84 
     85 

C:\ProgramData\Anaconda3\lib\site-packages\keras\backend\tensorflow_backend.py in sum(x, axis, keepdims)
   1286         A tensor with sum of `x`.
   1287     """
-> 1288     return tf.reduce_sum(x, axis, keepdims)
   1289 
   1290 

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\math_ops.py in reduce_sum_v1(input_tensor, axis, keepdims, name, reduction_indices, keep_dims)
   1284   keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
   1285                                                     "keep_dims", keep_dims)
-> 1286   return reduce_sum(input_tensor, axis, keepdims, name)
   1287 
   1288 

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\util\dispatch.py in wrapper(*args, **kwargs)
    178     """Call target, and fall back on dispatchers if there is a TypeError."""
    179     try:
--> 180       return target(*args, **kwargs)
    181     except (TypeError, ValueError):
    182       # Note: convert_to_eager_tensor currently raises a ValueError, not a

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\math_ops.py in reduce_sum(input_tensor, axis, keepdims, name)
   1332       gen_math_ops._sum(
   1333           input_tensor, _ReductionDims(input_tensor, axis), keepdims,
-> 1334           name=name))
   1335 
   1336 

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\ops\gen_math_ops.py in _sum(input, axis, keep_dims, name)
   9607   _, _, _op = _op_def_lib._apply_op_helper(
   9608         "Sum", input=input, reduction_indices=axis, keep_dims=keep_dims,
-> 9609                name=name)
   9610   _result = _op.outputs[:]
   9611   _inputs_flat = _op.inputs

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    786         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    787                          input_types=input_types, attrs=attr_protos,
--> 788                          op_def=op_def)
    789       return output_structure, op_def.is_stateful, op
    790 

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in create_op(***failed resolving arguments***)
   3298           input_types=input_types,
   3299           original_op=self._default_original_op,
-> 3300           op_def=op_def)
   3301       self._create_op_helper(ret, compute_device=compute_device)
   3302     return ret

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in __init__(self, node_def, g, inputs, output_types, control_inputs, input_types, original_op, op_def)
   1821           op_def, inputs, node_def.attr)
   1822       self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
-> 1823                                 control_input_ops)
   1824 
   1825     # Initialize self._outputs.

C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py in _create_c_op(graph, node_def, inputs, control_inputs)
   1660   except errors.InvalidArgumentError as e:
   1661     # Convert to ValueError for backwards compatibility.
-> 1662     raise ValueError(str(e))
   1663 
   1664   return c_op

ValueError: Invalid reduction dimension -1 for input with 0 dimensions. for 'dense_3/weight_regularizer/Sum' (op: 'Sum') with input shapes: [], [] and with computed input tensors: input[1] = <-1>.

How can I fix this? I need the KL divergence between 0.05 and mean calculate the following sum over i:

KL=sum(0.05*\\log(0.05/mean[i]))

In order to print means,

means = K.means((input), axis=1)
... 
means_ = sess.run(means, feed_dict={x:  , y:   })
print(means_) 

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM