I am training a neural network using Keras. I want to set selected weights in a weight matrix to zero. However, I can't seem to figure out how to do this. Below, I have added a self-contained runnable example of what I want to do:
The example passes a `kernel_constraint` called `weight_constraint`, which is supposed to set certain weight matrix values to zero. This should happen inside its `__call__` function. I've tried several approaches, but I can't wrap my head around how to set the entries at the indices
row_indices = [1, 2, 3]
col_indices = [10, 11, 12]
to zero in the weight matrix `w`.
import keras
import numpy
import tensorflow as tf
class weight_constraint(tf.keras.constraints.Constraint):
    """Kernel constraint intended to zero selected weight-matrix entries.

    As written this is still a placeholder: ``__call__`` works out the flat
    positions of the target entries but hands the weights back unmodified.
    """

    def __init__(self):
        pass

    def __call__(self, w):
        # (row, col) pairs that should end up as zero: (1, 10), (2, 11), (3, 12).
        rows, cols = [1, 2, 3], [10, 11, 12]
        dims = w.shape.as_list()
        # Flat (row-major) positions of those entries within a matrix of
        # shape `dims`.
        indices = numpy.ravel_multi_index((rows, cols), dims=dims, order='C')
        # Question: how to set the values indicated by "indices" to zero?
        return w
# A single 26 -> 26 dense layer; its kernel is passed through
# weight_constraint.
input_layer = keras.layers.Dense(
    units=26,
    input_shape=(26,),
    activation='tanh',
    kernel_initializer='zeros',
    bias_initializer='zeros',
    kernel_constraint=weight_constraint(),
)
full_model = keras.Sequential([input_layer])

loss = keras.losses.MeanSquaredError()
full_model.compile('RMSprop', loss=loss)

# Random data of shape (100, 26) for both inputs and regression targets.
inputs = numpy.random.random((100, 26))
targets = numpy.random.random((100, 26))
training_history = full_model.fit(inputs, targets, epochs=10)
You can try using `tf.tensor_scatter_nd_update` to update a tensor based on the row and column indices:
# Random integer matrix with entries drawn from [1, 4), so any zero in the
# second printout must come from the scatter update.
w = tf.random.uniform((26, 26), minval=1, maxval=4, dtype=tf.int32)
row_indices = tf.constant([1, 2, 3])
col_indices = tf.constant([10, 11, 12])
print(w)

# Pair the indices up as (row, col) coordinates, one per entry to overwrite.
coordinates = tf.stack([row_indices, col_indices], axis=1)
# One zero per coordinate.
zeros = tf.repeat(tf.constant(0), tf.shape(row_indices)[0])
w = tf.tensor_scatter_nd_update(w, coordinates, zeros)
print(w)
Notice the zero values in the second output:
tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
[1 3 2 3 1 1 1 2 1 2 3 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
[3 1 2 3 2 2 2 2 1 3 2 3 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
[1 3 2 1 3 1 3 3 3 2 2 3 1 2 3 1 2 2 2 2 3 2 1 2 1 2]
[1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
[2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
[2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
[2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
[1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
[1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
[3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
[3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
[3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
[1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
[3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
[2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
[2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
[3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
[1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
[1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
[3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
[3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
[2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
[2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
[2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
[3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)
tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
[1 3 2 3 1 1 1 2 1 2 0 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
[3 1 2 3 2 2 2 2 1 3 2 0 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
[1 3 2 1 3 1 3 3 3 2 2 3 0 2 3 1 2 2 2 2 3 2 1 2 1 2]
[1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
[2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
[2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
[2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
[1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
[1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
[3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
[3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
[3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
[1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
[3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
[2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
[2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
[3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
[1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
[1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
[3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
[3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
[2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
[2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
[2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
[3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)
So, your code would look something like this:
class weight_constraint(tf.keras.constraints.Constraint):
    """Constraint that sets selected entries of a 2-D weight matrix to zero.

    The entries are given as parallel row/column index sequences: entry
    ``(row_indices[i], col_indices[i])`` is zeroed for every ``i``.

    Args:
        row_indices: Row index of each entry to zero. Defaults to
            ``(1, 2, 3)``.
        col_indices: Column index of each entry to zero. Defaults to
            ``(10, 11, 12)``.

    Raises:
        ValueError: If the two index sequences differ in length.
    """

    def __init__(self, row_indices=(1, 2, 3), col_indices=(10, 11, 12)):
        super().__init__()
        # Stored as plain Python ints so get_config() stays serialisable.
        self.row_indices = [int(i) for i in row_indices]
        self.col_indices = [int(i) for i in col_indices]
        if len(self.row_indices) != len(self.col_indices):
            raise ValueError(
                "row_indices and col_indices must have the same length, got "
                f"{len(self.row_indices)} and {len(self.col_indices)}"
            )

    def __call__(self, w):
        """Return `w` with the configured entries replaced by zero."""
        if not self.row_indices:
            # Nothing to zero; skip the scatter.
            return w
        # Shape (n, 2): one (row, col) coordinate per entry to overwrite.
        coordinates = tf.stack(
            [
                tf.constant(self.row_indices, dtype=tf.int32),
                tf.constant(self.col_indices, dtype=tf.int32),
            ],
            axis=1,
        )
        # Build the zeros in the dtype of `w`; a hard-coded float32 constant
        # would not match a float16/float64 kernel.
        zeros = tf.zeros([len(self.row_indices)], dtype=w.dtype)
        return tf.tensor_scatter_nd_update(w, coordinates, zeros)

    def get_config(self):
        """Return the constructor arguments needed to recreate this constraint."""
        return {
            "row_indices": list(self.row_indices),
            "col_indices": list(self.col_indices),
        }
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.