The problem: I'm loading a simple VGG16 from a saved checkpoint, and I want to generate a saliency map for an image during inference. When I compute the gradients this requires (of the loss with respect to the input image), every gradient comes back as zero. Any ideas about what I'm missing would be much appreciated!
tf version: tensorflow-2.0alpha-gpu
The model:
import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16 as KerasVGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Flatten, Dense
class VGG16(Model):
    def __init__(self, num_classes, use_pretrained=True):
        super(VGG16, self).__init__()
        self.num_classes = num_classes
        self.use_pretrained = use_pretrained
        if use_pretrained:
            # ImageNet-pretrained convolutional base, frozen.
            self.base_model = KerasVGG16(weights='imagenet', include_top=False)
            for layer in self.base_model.layers:
                layer.trainable = False
        else:
            # Randomly initialized base; without weights=None, KerasVGG16
            # would still download and load the ImageNet weights.
            self.base_model = KerasVGG16(weights=None, include_top=False)
        # Classification head on top of the convolutional features.
        self.flatten1 = Flatten(name='flatten')
        self.dense1 = Dense(4096, activation='relu', name='fc1')
        self.dense2 = Dense(100, activation='relu', name='fc2')
        self.dense3 = Dense(self.num_classes, activation='softmax', name='predictions')

    def call(self, inputs):
        x = self.base_model(tf.cast(inputs, tf.float32))
        x = self.flatten1(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)
        return x
I train this model, save it to a checkpoint, and load it back via:
model = VGG16(num_classes=2, use_pretrained=False)
checkpoint = tf.train.Checkpoint(net=model)
status = checkpoint.restore(tf.train.latest_checkpoint('./my_checkpoint'))
status.assert_consumed()
I verify the weights are correctly loaded.
Get a test image:
# Load my image and make sure it's float.
img = tf.convert_to_tensor(image, dtype=tf.float64)
support_class = tf.convert_to_tensor(support_class, dtype=tf.float64)
Get the gradients:
with tf.GradientTape(persistent=True) as g_tape:
    g_tape.watch(img)
    # g_tape.watch(model.base_model.trainable_variables)
    # g_tape.watch(model.trainable_variables)
    loss = tf.losses.CategoricalCrossentropy()(support_class, model(img))
gradients_wrt_image = g_tape.gradient(
    loss, img, unconnected_gradients=tf.UnconnectedGradients.NONE)
When I inspect the gradients, they are all zero! Any idea what I am missing? Thanks in advance!
The gradients are not zero, although they are very small:
import numpy as np
import tensorflow as tf


def almost_equals(a, b, decimal=6):
    try:
        np.testing.assert_almost_equal(a, b, decimal=decimal)
    except AssertionError:
        return False
    return True


# 20 random 32x32 RGB images and one-hot labels for 2 classes.
image = [abs(np.random.normal(size=(32, 32, 3))) for _ in range(20)]
label = [[0, 1] if i % 3 == 0 else [1, 0] for i in range(20)]
img = tf.convert_to_tensor(image, dtype=tf.float64)
support_class = tf.convert_to_tensor(label, dtype=tf.float64)

# `model` is the restored VGG16 from the question.
loss_fn = tf.losses.CategoricalCrossentropy()
with tf.GradientTape(persistent=True) as tape:
    tape.watch(img)
    softmaxed = model(img)
    loss = loss_fn(support_class, softmaxed)
grads = tape.gradient(loss, img, unconnected_gradients=tf.UnconnectedGradients.NONE)

# Sum all gradients, reducing over every dimension:
print(tf.reduce_sum(grads, axis=None).numpy())  # 0.07137820225818814

# Compare against zeros at decreasing precision:
zeros_like_grads = np.zeros_like(grads.numpy())
for decimal in range(10, 0, -1):
    print('decimal: {0}: {1}'.format(decimal,
                                     almost_equals(zeros_like_grads,
                                                   grads.numpy(),
                                                   decimal=decimal)))
# decimal: 10: False
# decimal: 9: False
# decimal: 8: False
# decimal: 7: False
# decimal: 6: False
# decimal: 5: False
# decimal: 4: False
# decimal: 3: True
# decimal: 2: True
# decimal: 1: True
As you can see, it only starts returning True at decimal=3.
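The threshold behind that result can be made explicit. NumPy documents that assert_almost_equal passes when abs(desired - actual) < 1.5 * 10**(-decimal) elementwise, so passing at decimal=3 but failing at decimal=4 means the largest absolute gradient is at least 1.5e-4 and below 1.5e-3. A minimal sketch that reads the magnitude off directly, using a hypothetical synthetic gradient array rather than the real one (max_abs and passes_at are illustrative helpers, not library functions):

```python
import numpy as np


def max_abs(grads):
    # Largest absolute entry: the quantity that assert_almost_equal
    # effectively tests when the reference array is all zeros.
    return float(np.max(np.abs(grads)))


def passes_at(grads, decimal):
    # Same criterion NumPy documents for assert_almost_equal against zeros.
    return max_abs(grads) < 1.5 * 10.0 ** (-decimal)


# Hypothetical gradient array: all zeros except one entry of 5e-4.
grads = np.zeros((4, 4, 3))
grads[0, 0, 0] = 5e-4

print(max_abs(grads))       # 0.0005
print(passes_at(grads, 4))  # False
print(passes_at(grads, 3))  # True
```

Printing np.max(np.abs(grads.numpy())) on the real gradients gives the same information as the decimal loop in a single number.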
So it turns out there is nothing wrong with the network. The problem is related to the behavior of the softmax activation that I use in my final Dense layer. I didn't consider that very confident softmax predictions (e.g. one of my predictions was [[1.0000000e+00 1.9507678e-25]]) make the gradients vanish: in theory they are just very close to zero, but in practice they are zero. A useful thread that discusses this and how to counter it: https://github.com/keras-team/keras/issues/5881
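To see why a confident softmax starves the gradient: for cross-entropy on top of a softmax, the gradient with respect to the logits is softmax(z) - y. A minimal NumPy sketch in float32 (the logit values are made up, chosen so the second probability lands near the 1.95e-25 quoted above) shows that once the top probability rounds to exactly 1.0, that gradient shrinks to roughly 1e-25, compared with roughly 0.1 for a moderately confident prediction:

```python
import numpy as np


def softmax(z):
    # Subtract the max logit before exponentiating for numerical stability.
    e = np.exp(z - np.max(z))
    return e / e.sum()


def ce_grad_wrt_logits(z, y):
    # d/dz of -sum(y * log(softmax(z))) simplifies to softmax(z) - y.
    return softmax(z) - y


y = np.array([1.0, 0.0], dtype=np.float32)           # true class is 0
confident = np.array([57.0, 0.0], dtype=np.float32)  # hypothetical logits
moderate = np.array([2.0, 0.0], dtype=np.float32)    # hypothetical logits

p = softmax(confident)
print(p)            # roughly [1.0, 1.8e-25]
print(p[0] == 1.0)  # True: 1 + 1.8e-25 rounds to 1.0 in float32
print(np.abs(ce_grad_wrt_logits(confident, y)).max())  # roughly 1.8e-25
print(np.abs(ce_grad_wrt_logits(moderate, y)).max())   # roughly 0.12
```

Backpropagating a signal of order 1e-25 through the dense and convolutional layers then produces input gradients that are, for practical purposes, zero.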
My solution: turn off the softmax activation when I want to compute gradients with respect to the input image.
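A minimal TensorFlow sketch of that idea, under stated assumptions: TinyNet is a hypothetical small stand-in for the VGG16 above (so the example runs without downloading weights), its last Dense layer has no activation, and softmax is applied in call() only when probabilities are requested. The hypothetical saliency_map helper differentiates the raw pre-softmax score of the chosen class with respect to the input and takes the maximum absolute gradient over the color channels. For the VGG16 class, the analogous change would be to drop activation='softmax' from dense3 and apply tf.nn.softmax in call() instead.

```python
import numpy as np
import tensorflow as tf


class TinyNet(tf.keras.Model):
    """Hypothetical stand-in for the VGG16 above; last Dense has no activation."""

    def __init__(self, num_classes):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.hidden = tf.keras.layers.Dense(16, activation='relu')
        self.scores = tf.keras.layers.Dense(num_classes)  # raw logits, no softmax

    def call(self, inputs, return_logits=False):
        x = self.flatten(tf.cast(inputs, tf.float32))
        x = self.hidden(x)
        logits = self.scores(x)
        # Softmax is applied only when probabilities are requested.
        return logits if return_logits else tf.nn.softmax(logits)


def saliency_map(model, images, class_index):
    """Gradient of the class logit w.r.t. the input, max |.| over channels."""
    images = tf.convert_to_tensor(images, dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(images)
        logits = model(images, return_logits=True)
        # Each image's logit depends only on that image, so summing over
        # the batch still yields one independent gradient per image.
        class_score = tf.reduce_sum(logits[:, class_index])
    grads = tape.gradient(class_score, images)
    return tf.reduce_max(tf.abs(grads), axis=-1)


tf.random.set_seed(0)
model = TinyNet(num_classes=2)
images = np.abs(np.random.default_rng(0).normal(size=(2, 8, 8, 3)))
saliency = saliency_map(model, images, class_index=0)
print(saliency.shape)  # (2, 8, 8)
```

Differentiating the pre-softmax class score is the usual choice for plain gradient saliency. Simply switching the loss to from_logits=True would still give a gradient of softmax(z) - y with respect to the logits, which is almost zero for a confident, correct prediction.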