
Setting the values at given indices of a weight matrix to zero


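I'm training a neural network with Keras and want to set selected weights in the weight matrix to zero, but I can't figure out how to do it. Below is a self-contained, runnable example of what I'm trying to do: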

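  1. I create a simple neural network with 26 inputs and 26 outputs (no hidden layers).
  2. I attach a constraint to the layer through kernel_constraint.
  3. That constraint is an instance of the class weight_constraint, whose __call__ method should set certain entries of the weight matrix to zero.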

I've tried several approaches, but I can't work out how to set the entries at

 row_indices = [1, 2, 3]
 col_indices = [10, 11, 12]

in the weight matrix w to zero.
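The two lists are meant to be read pairwise, so the target entries are (1, 10), (2, 11) and (3, 12), not the 3 x 3 block of all row/column combinations. A minimal NumPy sketch of the intended result on a plain array, assuming an all-ones 26 x 26 matrix purely for illustration:

```python
import numpy as np

# Illustrative 26x26 matrix (all ones) standing in for the Dense kernel.
w = np.ones((26, 26), dtype=np.float32)
row_indices = [1, 2, 3]
col_indices = [10, 11, 12]

# Fancy indexing pairs the lists element-wise: (1, 10), (2, 11), (3, 12).
w[row_indices, col_indices] = 0.0
print(np.argwhere(w == 0).tolist())  # [[1, 10], [2, 11], [3, 12]]

# For contrast: np.ix_ selects every row/column combination (a 3x3 block).
block = np.ones((26, 26), dtype=np.float32)
block[np.ix_(row_indices, col_indices)] = 0.0
print(int((block == 0).sum()))  # 9
```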

import keras
import numpy
import tensorflow as tf


class weight_constraint(tf.keras.constraints.Constraint):
    """Should set selected entries of the kernel to zero."""

    def __init__(self):
        pass

    def __call__(self, w):
        shape = w.shape.as_list()
        row_indices = [1, 2, 3]
        col_indices = [10, 11, 12]
        indices = numpy.ravel_multi_index((row_indices, col_indices), dims=shape, order='C')
        # Question: how to set the values indicated by "indices" to zero?
        return w


full_model = keras.Sequential()
input_layer = keras.layers.Dense(units=26, input_shape=(26,),
                                 activation='tanh',
                                 kernel_initializer='zeros',
                                 bias_initializer='zeros',
                                 kernel_constraint=weight_constraint())

full_model.add(input_layer)
loss = keras.losses.MeanSquaredError()
full_model.compile('RMSprop', loss=loss)

inputs = numpy.random.random((100, 26))
targets = numpy.random.random((100, 26))

training_history = full_model.fit(inputs, targets, epochs=10)
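The flat indices computed with numpy.ravel_multi_index in the question can be used directly. A TensorFlow tensor cannot be assigned in place, but it can be flattened, updated with tf.tensor_scatter_nd_update and reshaped back. A sketch under those assumptions; the helper name zero_flat is hypothetical:

```python
import numpy as np
import tensorflow as tf

shape = (26, 26)
row_indices = [1, 2, 3]
col_indices = [10, 11, 12]

# Row-major flat positions, computed exactly as in the question.
flat_indices = np.ravel_multi_index((row_indices, col_indices), dims=shape, order='C')
print(flat_indices.tolist())  # [36, 63, 90]


def zero_flat(w, flat_indices):
    """Hypothetical helper: zero the entries of w at row-major flat positions."""
    w_flat = tf.reshape(w, [-1])  # view w as a 1-D vector
    # tensor_scatter_nd_update expects indices of shape (n, 1) for a 1-D tensor.
    idx = tf.reshape(tf.constant(flat_indices, dtype=tf.int64), [-1, 1])
    updates = tf.zeros([len(flat_indices)], dtype=w.dtype)
    w_flat = tf.tensor_scatter_nd_update(w_flat, idx, updates)
    return tf.reshape(w_flat, tf.shape(w))  # restore the original 2-D shape


w = tf.ones(shape)
w_new = zero_flat(w, flat_indices)
print(w_new.numpy()[row_indices, col_indices].tolist())  # [0.0, 0.0, 0.0]
```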

You can try using tf.tensor_scatter_nd_update to update the tensor at the given row and column indices:

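import tensorflow as tf

w = tf.random.uniform((26, 26), minval=1, maxval=4, dtype=tf.int32)
row_indices = tf.constant([1, 2, 3])
col_indices = tf.constant([10, 11, 12])
print(w)
# One (row, col) pair per entry to update: [[1, 10], [2, 11], [3, 12]]
indices = tf.stack([row_indices, col_indices], axis=1)
# One zero per index pair, with the same dtype as w (int32 here)
updates = tf.repeat(tf.constant(0), tf.shape(row_indices)[0])
w = tf.tensor_scatter_nd_update(w, indices, updates)
print(w)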

Note the zeros at (1, 10), (2, 11) and (3, 12) in the second output:

tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
 [1 3 2 3 1 1 1 2 1 2 3 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
 [3 1 2 3 2 2 2 2 1 3 2 3 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
 [1 3 2 1 3 1 3 3 3 2 2 3 1 2 3 1 2 2 2 2 3 2 1 2 1 2]
 [1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
 [2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
 [2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
 [2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
 [1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
 [1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
 [3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
 [3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
 [3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
 [1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
 [3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
 [2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
 [2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
 [3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
 [1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
 [1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
 [3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
 [3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
 [2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
 [2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
 [2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
 [3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)
tf.Tensor(
[[3 3 2 3 1 2 3 1 2 2 2 3 1 3 2 2 2 2 1 2 3 1 1 3 2 3]
 [1 3 2 3 1 1 1 2 1 2 0 3 3 2 2 2 3 1 2 1 1 1 3 2 1 2]
 [3 1 2 3 2 2 2 2 1 3 2 0 1 1 2 3 2 1 1 3 1 2 1 2 1 2]
 [1 3 2 1 3 1 3 3 3 2 2 3 0 2 3 1 2 2 2 2 3 2 1 2 1 2]
 [1 1 2 3 3 1 3 3 1 2 1 3 2 2 1 3 3 2 1 3 1 3 2 2 1 3]
 [2 1 2 3 1 1 1 1 2 3 3 2 1 3 1 1 3 3 3 1 2 3 1 2 2 3]
 [2 1 2 3 1 3 2 2 2 3 1 2 3 3 2 1 1 3 1 2 1 2 1 3 2 1]
 [2 3 2 2 1 2 3 2 3 3 2 3 3 2 3 2 2 3 3 3 1 1 3 1 2 1]
 [1 2 1 2 1 1 1 3 2 3 3 1 2 1 1 1 1 2 3 3 2 3 2 2 1 3]
 [1 3 1 2 3 1 1 3 3 3 3 3 2 1 2 2 3 1 1 3 1 3 2 1 1 1]
 [3 2 2 2 1 3 3 2 3 1 2 3 2 2 2 1 1 2 3 3 3 2 3 3 3 1]
 [3 3 3 2 3 2 1 1 3 1 3 2 3 3 1 3 2 2 2 3 3 2 1 1 2 2]
 [3 3 3 1 3 3 2 2 1 3 3 2 1 2 1 1 1 2 3 3 1 1 1 2 2 2]
 [1 2 2 3 3 2 1 1 1 1 2 1 2 2 1 2 2 3 3 3 3 1 2 1 2 2]
 [3 2 3 3 2 1 2 3 2 2 2 1 1 1 3 3 1 1 3 3 3 2 1 2 3 2]
 [2 3 1 1 1 1 3 3 1 3 2 3 1 3 3 2 3 2 3 1 1 3 1 3 2 3]
 [2 3 2 2 3 1 2 2 3 1 3 3 1 2 1 3 3 2 2 1 2 1 3 2 2 3]
 [3 1 3 1 1 2 3 1 1 1 3 2 2 3 2 1 1 1 1 2 1 1 3 2 3 3]
 [1 2 2 1 1 2 1 3 2 1 3 3 1 3 3 2 1 2 1 1 1 1 2 2 1 2]
 [1 1 3 2 2 2 2 2 1 3 3 3 1 2 3 2 1 1 3 2 2 3 2 1 2 1]
 [3 2 2 3 3 3 1 3 3 3 3 2 2 1 2 3 3 2 1 1 3 1 2 3 1 1]
 [3 1 2 1 1 2 1 1 3 3 1 1 2 2 1 3 2 2 3 3 3 3 3 2 3 2]
 [2 2 3 3 1 3 2 1 2 3 2 2 2 3 1 3 2 1 1 1 2 2 3 1 3 1]
 [2 2 2 3 1 3 2 2 2 3 1 2 1 2 2 3 2 1 3 1 2 1 3 3 2 2]
 [2 2 2 2 1 2 3 2 1 1 3 2 1 3 3 1 2 3 1 3 2 2 1 2 2 1]
 [3 3 2 2 3 3 3 2 2 1 2 2 3 2 2 2 2 3 2 3 3 1 3 1 1 1]], shape=(26, 26), dtype=int32)
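tf.tensor_scatter_nd_update requires the updates to have the same dtype as the tensor being updated, which is why this demo uses int32 zeros for an int32 tensor, while a float kernel needs float zeros. A small sketch of a helper that takes the zero dtype from w, so the same code works for both; the name zero_entries is hypothetical:

```python
import tensorflow as tf


def zero_entries(w, row_indices, col_indices):
    """Hypothetical helper: return a copy of w with w[r_i, c_i] = 0 for each pair."""
    rows = tf.convert_to_tensor(row_indices, dtype=tf.int32)
    cols = tf.convert_to_tensor(col_indices, dtype=tf.int32)
    indices = tf.stack([rows, cols], axis=1)           # shape (n, 2)
    updates = tf.zeros(tf.shape(rows), dtype=w.dtype)  # zeros in w's own dtype
    return tf.tensor_scatter_nd_update(w, indices, updates)


w_int = tf.fill((4, 4), 7)      # int32 tensor
w_float = tf.fill((4, 4), 0.5)  # float32 tensor
print(zero_entries(w_int, [1, 2], [3, 0]).numpy())
print(zero_entries(w_float, [1, 2], [3, 0]).dtype)  # <dtype: 'float32'>
```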

So your constraint would look like this:

import tensorflow as tf


class weight_constraint(tf.keras.constraints.Constraint):
    """Sets the kernel entries (1, 10), (2, 11) and (3, 12) to zero."""

    def __init__(self):
        pass

    def __call__(self, w):
        row_indices = tf.constant([1, 2, 3])
        col_indices = tf.constant([10, 11, 12])
        # Pair rows and columns element-wise: [[1, 10], [2, 11], [3, 12]]
        indices = tf.stack([row_indices, col_indices], axis=1)
        # Zeros in the kernel's own dtype (e.g. float32), as the update requires
        updates = tf.zeros_like(row_indices, dtype=w.dtype)
        return tf.tensor_scatter_nd_update(w, indices, updates)
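A different technique with the same effect, swapped in here as an alternative rather than taken from the answer, is to multiply the kernel by a fixed 0/1 mask. The sketch assumes a 26 x 26 kernel as in the question; the class name MaskConstraint is hypothetical, and the short training run only checks that the masked entries stay at zero:

```python
import numpy as np
import tensorflow as tf


class MaskConstraint(tf.keras.constraints.Constraint):
    """Hypothetical alternative: multiply the kernel by a fixed 0/1 mask."""

    def __init__(self, shape, row_indices, col_indices):
        mask = np.ones(shape, dtype=np.float32)
        mask[row_indices, col_indices] = 0.0  # element-wise pairs, as in the question
        self.mask = tf.constant(mask)

    def __call__(self, w):
        # Keep every weight except the masked ones, which become zero.
        return w * tf.cast(self.mask, w.dtype)


model = tf.keras.Sequential([
    tf.keras.Input(shape=(26,)),
    tf.keras.layers.Dense(
        26, activation='tanh',
        kernel_constraint=MaskConstraint((26, 26), [1, 2, 3], [10, 11, 12])),
])
model.compile('rmsprop', loss='mse')
model.fit(np.random.random((100, 26)), np.random.random((100, 26)),
          epochs=2, verbose=0)

kernel = model.layers[0].get_weights()[0]
print(np.count_nonzero(kernel[[1, 2, 3], [10, 11, 12]]))  # 0
```

Keras applies kernel constraints after each optimizer step, so the masked entries are zero after fit while the other weights train freely. Multiplying a negative weight by 0.0 yields -0.0, which still compares equal to zero.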


