[英]Setting indices values of a weight matrix to zero
我正在使用 Keras 训练神经网络,想将权重矩阵中选定的权重设置为零,但一直没弄清楚该怎么做。下面是一个可独立运行的示例,说明我想要做什么:
我通过 kernel_constraint 传入了一个自定义约束 weight_constraint,它应该将权重矩阵中的某些值设置为零,这一操作应在其 __call__ 函数内部完成。我尝试了几种方法,但不知道如何按索引设置这些值。
例如,我希望权重矩阵 w 中由
row_indices = [1, 2, 3]
col_indices = [10, 11, 12]
逐对指定的元素(即 w[1, 10]、w[2, 11]、w[3, 12])为零。
import keras
import numpy
import tensorflow as tf
class weight_constraint(tf.keras.constraints.Constraint):
    """Keras weight constraint that forces selected entries of a 2-D kernel to zero.

    Intended to be passed as ``kernel_constraint`` to a layer such as ``Dense``.
    Keras applies a constraint to its variable after each gradient update, so
    the selected entries are reset to zero after every training step.

    Args:
        row_indices: Row index of every entry to zero. Defaults to ``(1, 2, 3)``.
        col_indices: Column index of every entry to zero, paired element-wise
            with ``row_indices``. Defaults to ``(10, 11, 12)``.

    Raises:
        ValueError: If ``row_indices`` and ``col_indices`` differ in length.
    """

    def __init__(self, row_indices=(1, 2, 3), col_indices=(10, 11, 12)):
        # Plain Python ints keep get_config() JSON-serializable even when the
        # caller passes numpy integers.
        row_indices = [int(i) for i in row_indices]
        col_indices = [int(j) for j in col_indices]
        if len(row_indices) != len(col_indices):
            raise ValueError(
                "row_indices and col_indices must have the same length, got "
                f"{len(row_indices)} and {len(col_indices)}"
            )
        self.row_indices = row_indices
        self.col_indices = col_indices

    def __call__(self, w):
        """Return ``w`` with the configured (row, col) entries replaced by zero.

        ``w`` is expected to be a 2-D weight tensor. On GPU,
        ``tf.tensor_scatter_nd_update`` silently ignores out-of-range indices
        (on CPU it raises), so the indices must lie inside ``w.shape``.
        """
        if not self.row_indices:
            return w
        # One (row, col) pair per entry -> shape (n, 2), the layout
        # tensor_scatter_nd_update expects for element-wise updates.
        indices = tf.stack(
            [
                tf.constant(self.row_indices, dtype=tf.int32),
                tf.constant(self.col_indices, dtype=tf.int32),
            ],
            axis=1,
        )
        # Zeros in w's own dtype so the scatter works for any float kernel dtype.
        updates = tf.zeros([len(self.row_indices)], dtype=w.dtype)
        return tf.tensor_scatter_nd_update(w, indices, updates)

    def get_config(self):
        """Return the constructor arguments so models using this constraint can be saved and reloaded."""
        return {"row_indices": self.row_indices, "col_indices": self.col_indices}
# A single Dense layer (26 -> 26, tanh) with zero-initialized weights. Its
# kernel is passed through weight_constraint, which is defined earlier in this
# script. The variable is named "input_layer" but is a Dense layer; the
# input shape is declared on it via input_shape.
input_layer = keras.layers.Dense(
    units=26,
    input_shape=(26,),
    activation='tanh',
    kernel_initializer='zeros',
    bias_initializer='zeros',
    kernel_constraint=weight_constraint(),
)
full_model = keras.Sequential([input_layer])

# Mean-squared-error regression trained with RMSprop.
loss = keras.losses.MeanSquaredError()
full_model.compile(optimizer='RMSprop', loss=loss)

# Random (unseeded) toy data: 100 samples, 26 features in and out.
inputs = numpy.random.random(size=(100, 26))
targets = numpy.random.random(size=(100, 26))
training_history = full_model.fit(x=inputs, y=targets, epochs=10)
您可以尝试使用 tf.tensor_scatter_nd_update,根据行索引和列索引更新张量:
# Demo: zero out w[1, 10], w[2, 11] and w[3, 12] with tf.tensor_scatter_nd_update.
w = tf.random.uniform((26, 26), minval=1, maxval=4, dtype=tf.int32)
row_indices = tf.constant([1, 2, 3])
col_indices = tf.constant([10, 11, 12])
print(w)

# Pair rows with columns: one (row, col) index per update, shape (3, 2).
scatter_indices = tf.stack([row_indices, col_indices], axis=1)
# One int32 zero per index pair, matching w's dtype.
zero_updates = tf.zeros_like(row_indices)
w = tf.tensor_scatter_nd_update(w, scatter_indices, zero_updates)
print(w)
注意第二个输出中 (1, 10)、(2, 11)、(3, 12) 位置上的零值:
tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
[1 3 2 3 1 1 1 2 1 2 3 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
[3 1 2 3 2 2 2 2 1 3 2 3 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
[1 3 2 1 3 1 3 3 3 2 2 3 1 2 3 1 2 2 2 2 3 2 1 2 1 2]
[1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
[2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
[2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
[2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
[1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
[1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
[3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
[3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
[3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
[1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
[3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
[2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
[2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
[3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
[1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
[1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
[3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
[3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
[2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
[2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
[2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
[3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)
tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
[1 3 2 3 1 1 1 2 1 2 0 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
[3 1 2 3 2 2 2 2 1 3 2 0 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
[1 3 2 1 3 1 3 3 3 2 2 3 0 2 3 1 2 2 2 2 3 2 1 2 1 2]
[1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
[2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
[2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
[2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
[1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
[1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
[3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
[3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
[3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
[1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
[3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
[2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
[2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
[3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
[1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
[1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
[3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
[3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
[2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
[2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
[2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
[3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)
因此,您的代码将如下所示:
class weight_constraint(tf.keras.constraints.Constraint):
    """Keras weight constraint that forces selected entries of a 2-D kernel to zero.

    Intended to be passed as ``kernel_constraint`` to a layer such as ``Dense``.
    Keras applies a constraint to its variable after each gradient update, so
    the selected entries are reset to zero after every training step.

    Args:
        row_indices: Row index of every entry to zero. Defaults to ``(1, 2, 3)``.
        col_indices: Column index of every entry to zero, paired element-wise
            with ``row_indices``. Defaults to ``(10, 11, 12)``.

    Raises:
        ValueError: If ``row_indices`` and ``col_indices`` differ in length.
    """

    def __init__(self, row_indices=(1, 2, 3), col_indices=(10, 11, 12)):
        # Plain Python ints keep get_config() JSON-serializable even when the
        # caller passes numpy integers.
        row_indices = [int(i) for i in row_indices]
        col_indices = [int(j) for j in col_indices]
        if len(row_indices) != len(col_indices):
            raise ValueError(
                "row_indices and col_indices must have the same length, got "
                f"{len(row_indices)} and {len(col_indices)}"
            )
        self.row_indices = row_indices
        self.col_indices = col_indices

    def __call__(self, w):
        """Return ``w`` with the configured (row, col) entries replaced by zero.

        ``w`` is expected to be a 2-D weight tensor. On GPU,
        ``tf.tensor_scatter_nd_update`` silently ignores out-of-range indices
        (on CPU it raises), so the indices must lie inside ``w.shape``.
        """
        if not self.row_indices:
            return w
        # One (row, col) pair per entry -> shape (n, 2).
        indices = tf.stack(
            [
                tf.constant(self.row_indices, dtype=tf.int32),
                tf.constant(self.col_indices, dtype=tf.int32),
            ],
            axis=1,
        )
        # Build the zeros in w's dtype. A hard-coded tf.constant(0.0) is
        # float32 and makes the scatter fail for float64/float16 kernels.
        updates = tf.zeros([len(self.row_indices)], dtype=w.dtype)
        return tf.tensor_scatter_nd_update(w, indices, updates)

    def get_config(self):
        """Return the constructor arguments so models using this constraint can be saved and reloaded."""
        return {"row_indices": self.row_indices, "col_indices": self.col_indices}
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.