
Setting selected values of a weight matrix to zero

I am training a neural network using Keras. I want to set selected weights in a weight matrix to zero. However, I can't seem to figure out how to do this. Below, I have added a self-contained runnable example of what I want to do:

  1. I create a simple neural network with 26 inputs and 26 outputs (no hidden layers).
  2. I add a constraint to the layer via the kernel_constraint argument.
  3. The constraint is an instance of the class weight_constraint, which is supposed to set certain weight-matrix values to zero; this should happen inside its __call__ method.

I've tried several approaches, but I can't wrap my head around how to set the entries at

 row_indices = [1, 2, 3]
 col_indices = [10, 11, 12]

to zero in the weight matrix w.

import keras
import numpy
import tensorflow as tf


class weight_constraint(tf.keras.constraints.Constraint):
    """Constrains weight tensors to be centered around `ref_value`."""

    def __init__(self):
        pass

    def __call__(self, w):
        shape = w.shape.as_list()
        row_indices = [1, 2, 3]
        col_indices = [10, 11, 12]
        indices = numpy.ravel_multi_index((row_indices, col_indices), dims=shape, order='C')
        # Question: how to set the values indicated by "indices" to zero?
        return w


full_model = keras.Sequential()
input_layer = keras.layers.Dense(units=26, input_shape=(26,),
                                 activation='tanh',
                                 kernel_initializer='zeros',
                                 bias_initializer='zeros',
                                 kernel_constraint=weight_constraint())

full_model.add(input_layer)
loss = keras.losses.MeanSquaredError()
full_model.compile('RMSprop', loss=loss)

inputs = numpy.random.random((100, 26))
targets = numpy.random.random((100, 26))

training_history = full_model.fit(inputs, targets, epochs=10)

You can try tf.tensor_scatter_nd_update to overwrite a tensor at the given row/column index pairs:

# Demo on a random integer tensor: zero out the entries at (1, 10), (2, 11) and (3, 12).
w = tf.random.uniform((26, 26), minval=1, maxval=4, dtype=tf.int32)
row_indices = tf.constant([1, 2, 3])
col_indices = tf.constant([10, 11, 12])
print(w)
# Build the (3, 2) index pairs [[1, 10], [2, 11], [3, 12]] and scatter three zeros into them.
w = tf.tensor_scatter_nd_update(
    w,
    tf.stack([row_indices, col_indices], axis=1),
    tf.repeat(tf.constant(0), tf.shape(row_indices)[0]))
print(w)

Notice the zeros at positions (1, 10), (2, 11) and (3, 12) in the second output:

tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
 [1 3 2 3 1 1 1 2 1 2 3 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
 [3 1 2 3 2 2 2 2 1 3 2 3 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
 [1 3 2 1 3 1 3 3 3 2 2 3 1 2 3 1 2 2 2 2 3 2 1 2 1 2]
 [1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
 [2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
 [2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
 [2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
 [1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
 [1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
 [3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
 [3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
 [3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
 [1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
 [3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
 [2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
 [2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
 [3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
 [1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
 [1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
 [3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
 [3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
 [2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
 [2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
 [2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
 [3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)
tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
 [1 3 2 3 1 1 1 2 1 2 0 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
 [3 1 2 3 2 2 2 2 1 3 2 0 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
 [1 3 2 1 3 1 3 3 3 2 2 3 0 2 3 1 2 2 2 2 3 2 1 2 1 2]
 [1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
 [2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
 [2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
 [2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
 [1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
 [1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
 [3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
 [3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
 [3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
 [1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
 [3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
 [2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
 [2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
 [3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
 [1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
 [1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
 [3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
 [3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
 [2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
 [2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
 [2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
 [3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)

So, your code would look something like this:

class weight_constraint(tf.keras.constraints.Constraint):
    """Constraint that resets selected weight-matrix entries to zero."""

    def __call__(self, w):
        row_indices = tf.constant([1, 2, 3])
        col_indices = tf.constant([10, 11, 12])
        # Scatter zeros into the (row, col) positions and return the updated kernel.
        return tf.tensor_scatter_nd_update(
            w,
            tf.stack([row_indices, col_indices], axis=1),
            tf.repeat(tf.constant(0.0), tf.shape(row_indices)[0]))
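
As a quick sanity check (a minimal sketch reusing the model from the question; the printed result assumes the constraint above is in place), you can fit the model and read the kernel back to confirm the targeted entries stay at zero:

# Same model as in the question, now with the constraint attached.
full_model = keras.Sequential()
full_model.add(keras.layers.Dense(units=26, input_shape=(26,),
                                  activation='tanh',
                                  kernel_initializer='zeros',
                                  bias_initializer='zeros',
                                  kernel_constraint=weight_constraint()))
full_model.compile('RMSprop', loss=keras.losses.MeanSquaredError())

inputs = numpy.random.random((100, 26))
targets = numpy.random.random((100, 26))
full_model.fit(inputs, targets, epochs=10)

# get_weights()[0] is the (26, 26) kernel of the Dense layer.
kernel = full_model.get_weights()[0]
print(kernel[[1, 2, 3], [10, 11, 12]])  # expected: [0. 0. 0.]

Keep in mind that Keras applies kernel_constraint after each optimizer update, so this projects the weights back onto zero at every training step; it does not mask the corresponding gradients.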
