
TensorFlow: "Shapes must be equal rank, but are 4 and 1" for tf.norm

I'm trying to find the norm of all the filters in a conv2d layer. My code is below:

conv1 = tf.layers.conv2d(
                         inputs=input_layer,
                         filters=32,
                         strides=(1, 1),
                         kernel_size=[3, 3],
                         padding="valid",
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_regularizer=tf.nn.l2_loss,
                         bias_regularizer=tf.nn.l2_loss,
                         name="conv1")

var = [v for v in tf.trainable_variables() if "conv1" in v.name]
print(tf.norm(var,axis=4))

Shapes must be equal rank, but are 4 and 1 From merging shape 0 with other shapes. for 'norm/packed' (op: 'Pack') with input shapes: [3,3,3,32], [32].

I have tried axis values from None to 4, and none of them work. Can someone explain what the problem is and how it can be solved?

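There are two bugs in your code. First, var is a list of two tensors (the kernel, shape [3, 3, 3, 32], and the bias, shape [32]), while tf.norm() expects a single tensor. When it is given a list, TensorFlow tries to pack the list into one tensor, and that fails because the two tensors have different ranks; this is what the 'norm/packed' (op: 'Pack') error is reporting. Second, the kernel is a rank-4 tensor and axes are numbered from 0, so its last (fourth) axis is axis 3; axis=4 does not exist. This code (tested):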

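import tensorflow as tf  # TensorFlow 1.x API (tf.layers, tf.random_uniform)

# Dummy input: a batch of 2 images, 10x10 pixels, 2 channels
input_layer = tf.random_uniform(shape=(2, 10, 10, 2))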

conv1 = tf.layers.conv2d(
                         inputs=input_layer,
                         filters=32,
                         strides=(1, 1),
                         kernel_size=[3, 3],
                         padding="valid",
                         activation=tf.nn.relu,
                         use_bias=True,
                         kernel_regularizer=tf.nn.l2_loss,
                         bias_regularizer=tf.nn.l2_loss,
                         name="conv1")

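# The list holds conv1/kernel:0 and conv1/bias:0; the kernel is created first,
# so index 0 selects it. Its shape is (3, 3, 2, 32), so the last axis is 3.
var = [v for v in tf.trainable_variables() if "conv1" in v.name][0]
print(tf.norm(var, axis=3))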

will output:

Tensor("norm/Squeeze:0", shape=(3, 3, 2), dtype=float32)

with no error.
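Both bugs can be reproduced without TensorFlow. The sketch below is a NumPy analogue, not TensorFlow itself: np.stack stands in for the implicit Pack op, np.linalg.norm with an integer axis computes the same Euclidean vector norm that tf.norm computes, and the random/zero arrays are hypothetical stand-ins for the conv1 kernel and bias. It also makes explicit that axis=3 reduces across the 32 filters (one value per row, column and input channel, hence the (3, 3, 2) shape). If the goal is one norm per filter, one option is to flatten each filter first; in TensorFlow that would correspond to tf.norm(tf.reshape(var, [-1, 32]), axis=0).

```python
import numpy as np

# Hypothetical stand-ins for the two variables matched by "conv1":
# kernel: 3x3 window, 2 input channels, 32 filters; bias: one value per filter.
rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3, 2, 32)).astype(np.float32)
bias = np.zeros(32, dtype=np.float32)

# Bug 1: packing a rank-4 and a rank-1 array into one array fails,
# like the implicit Pack op that tf.norm(var) runs on a list.
try:
    np.stack([kernel, bias])
    stack_failed = False
except ValueError:
    stack_failed = True

# Bug 2: a rank-4 array only has axes 0..3, so axis=4 is out of range.
try:
    np.linalg.norm(kernel, axis=4)
    axis4_failed = False
except ValueError:  # NumPy's AxisError is a subclass of ValueError
    axis4_failed = True

# axis=3 reduces *across* the 32 filters: one value per (row, col, in_channel).
across_filters = np.linalg.norm(kernel, axis=3)

# One L2 norm per filter: flatten each filter's 3*3*2 weights, then reduce axis 0.
per_filter = np.linalg.norm(kernel.reshape(-1, 32), axis=0)

print(stack_failed, axis4_failed)              # True True
print(across_filters.shape, per_filter.shape)  # (3, 3, 2) (32,)
```

The reshape keeps the filter axis intact because both NumPy and TensorFlow reshape in row-major order, so column j of the flattened array holds exactly the weights of filter j.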


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM