
Train for a parameter in the weight matrix in Tensorflow

I have a neural network. For simplicity, there is only one layer and the weight matrix has shape 2-by-2. I need the output of the network to be a rotated version of the input, i.e., the weight matrix should be a valid rotation matrix. I have tried the following:

def rotate(val):
    w1 = tf.constant_initializer([[cos45, -sin45], [sin45, cos45]])
    return tf.layers.dense(inputs=val, units=2, kernel_initializer=w1, activation=tf.nn.tanh) 

While training, I do not want to lose properties of the rotation matrix. In other words, I need the layer(s) to estimate only the angle (argument) of trigonometric functions in the matrix.

I read that kernel_constraint can help here by normalizing the values. But applying kernel_constraint does not guarantee that the diagonal entries are equal and that the off-diagonal entries are negatives of each other (in this case). In general, the two properties that need to be satisfied are that the determinant is 1 and that R^T*R = I.
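To make the two conditions concrete, here is a small NumPy sketch that checks them numerically. The helper names `rotation_matrix` and `is_rotation` are made up for illustration and are not part of any library.

```python
import numpy as np


def rotation_matrix(theta):
    """Return the 2-by-2 counterclockwise rotation matrix for angle theta (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def is_rotation(m, tol=1e-8):
    """Check both conditions: m^T m = I and det(m) = 1."""
    orthogonal = np.allclose(m.T @ m, np.eye(m.shape[0]), atol=tol)
    unit_det = abs(np.linalg.det(m) - 1.0) < tol
    return bool(orthogonal and unit_det)


R = rotation_matrix(np.pi / 4)
print(is_rotation(R))        # a matrix built from a single angle passes both checks
print(is_rotation(1.5 * R))  # same sign/equality pattern, but scaled, so it fails
```

The second call shows why equal diagonals and opposite off-diagonals are not sufficient on their own: the entries must also come from the same angle.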

Is there any other way to achieve this?

You could define your own custom Keras layer. Something along the lines of:

from tensorflow.keras.layers import Layer
import tensorflow as tf

class Rotate(Layer):
    def build(self, input_shape):
        # The feature size is the last dimension; input_shape[0] is the batch size
        self.sh = int(input_shape[-1])

        # Initial weight matrix (only its lower triangle ends up being used)
        self.w = self.add_weight(name='w', shape=[self.sh, self.sh],
                                 initializer='random_uniform')

        # Single weight shared by every diagonal entry
        self.diag_w = self.add_weight(name='diag_w', shape=(1,),
                                      initializer='random_uniform')

    @property
    def kernel(self):
        # Rebuilt from the weights on each access so that gradients reach them
        # Set upper diagonal elements to negative of lower diagonal elements
        mask = tf.linalg.band_part(tf.ones([self.sh, self.sh]), -1, 0)
        w = mask * self.w
        w -= tf.transpose(w)  # skew-symmetric, so the diagonal is already zero

        # Set the same weight to the diagonal
        return w + tf.eye(self.sh) * self.diag_w

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.kernel)

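Note that the matrix of learnable weights, self.kernel, has the form [[D, -L], [L, D]] in the 2-by-2 case. This enforces the sign and equality pattern, but D and L are trained independently, so the determinant D^2 + L^2 is not forced to equal 1.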
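Since the question asks for only the angle to be learned, another option is to make the angle the sole trainable weight and rebuild the matrix from it on every call. The following is a minimal sketch, assuming TensorFlow 2.x with tf.keras and 2-D inputs of shape (batch, 2). The class name `AngleRotation`, the `initial_angle` argument, and the `rotation_matrix` helper are hypothetical names chosen for illustration, not part of the original answer or of Keras.

```python
import math

import numpy as np
import tensorflow as tf


class AngleRotation(tf.keras.layers.Layer):
    """2-D rotation whose only trainable parameter is the angle theta (radians)."""

    def __init__(self, initial_angle=math.pi / 4, **kwargs):
        super().__init__(**kwargs)
        self.initial_angle = initial_angle  # hypothetical argument, in radians

    def build(self, input_shape):
        # A single scalar weight; the matrix itself is never stored as a variable.
        self.theta = self.add_weight(
            name="theta",
            shape=(),
            initializer=tf.keras.initializers.Constant(self.initial_angle),
            trainable=True,
        )

    def rotation_matrix(self):
        # Rebuilt from theta each time so that gradients flow back to the angle.
        c, s = tf.cos(self.theta), tf.sin(self.theta)
        return tf.stack([tf.stack([c, -s]), tf.stack([s, c])])

    def call(self, inputs):
        # Inputs are row vectors; x @ R^T applies R to each row.
        return tf.matmul(inputs, self.rotation_matrix(), transpose_b=True)


layer = AngleRotation(initial_angle=math.pi / 2)
rotated = layer(tf.constant([[1.0, 0.0]]))
R = layer.rotation_matrix().numpy()
print(rotated.numpy())
print(np.linalg.det(R))
```

Because every entry is derived from the same theta, the matrix satisfies R^T*R = I and has determinant 1 (up to floating-point error) for any value the optimizer reaches. The tanh activation from the question is left out here; applying it afterwards would mean the output is no longer a pure rotation of the input.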
