I've been trying to train a custom style transfer net with AdaIN. The only problem I'm facing now is that the gradients I'm getting are all NaN, right from the first epoch. I'm currently using TF 2.6.1.
Here's the custom training loop and loss function:
def _compute_mean_std(self, feats : tf.Tensor, eps=1e-8):
    """
    Compute the mean and standard deviation of a feature map over axes 1 and 2.

    feats: Features should be in shape N x H x W x C
    eps: Small constant added to the variance before the square root is taken.

    Returns a ``(mean, std)`` tuple; both are reduced over axes 1 and 2 with
    ``keepdims=True``.
    """
    mean = tf.math.reduce_mean(feats, axis=[1, 2], keepdims=True)
    variance = tf.math.reduce_variance(feats, axis=[1, 2], keepdims=True)
    # eps has to sit INSIDE the square root. d/dv sqrt(v) = 1 / (2 * sqrt(v))
    # is infinite at v == 0, and a channel that is constant over H x W (for
    # example an all-zero ReLU channel) has exactly zero variance. The chain
    # rule then multiplies inf by 0 and every upstream gradient becomes NaN.
    # Adding eps after reduce_std() only shifts the value, not the gradient.
    std = tf.sqrt(variance + eps)
    return mean, std
def criterion(self, stylized_img : tf.Tensor, style_img : tf.Tensor, t : tf.Tensor):
    """
    Combine a content loss and a weighted style loss into one scalar.

    stylized_img: Image that is re-encoded and compared against both targets.
    style_img: Style reference image.
    t: Target features; compared with the encoding of ``stylized_img``.

    The content term is the MSE between ``t`` and the encoding of
    ``stylized_img``. The style term sums, over every feature map returned by
    ``encode(..., return_all=True)``, the MSE between the means and between
    the standard deviations of the stylized and style features.
    """
    content_feats = self.model.encode(stylized_img)
    stylized_levels = self.model.encode(stylized_img, return_all=True)
    style_levels = self.model.encode(style_img, return_all=True)

    content_loss = self.mse_loss(t, content_feats)

    def _statistics_gap(generated, reference):
        # Distance between the first- and second-order statistics of one level.
        gen_mean, gen_std = self._compute_mean_std(generated)
        ref_mean, ref_std = self._compute_mean_std(reference)
        return self.mse_loss(gen_mean, ref_mean) + self.mse_loss(gen_std, ref_std)

    style_loss = sum(
        _statistics_gap(generated, reference)
        for generated, reference in zip(stylized_levels, style_levels)
    )
    return content_loss + self.style_weight * style_loss
def train(self):
    """
    Run ``self.num_iter`` optimisation steps of the style-transfer model.

    Each step draws one content batch and one style batch, runs the model and
    ``self.criterion`` under a gradient tape, and applies the gradients to the
    model's trainable weights. Every 200 steps the loss is printed and a
    checkpoint is written; the final weights are saved once the loop ends.
    """
    def _draw(iterator):
        # A batch whose leading dimension differs from batch_size is replaced
        # by the next one (a single retry, exactly as before).
        batch = iterator.get_next()
        if batch.shape[0] != self.batch_size:
            batch = iterator.get_next()
        return batch

    step = 0
    while step < self.num_iter:
        content_batch = _draw(self.content_iter)
        style_batch = _draw(self.style_iter)

        with tf.GradientTape() as tape:
            model_inputs = {
                'content_imgs': content_batch,
                'style_imgs': style_batch,
                'alpha': 1.0,
            }
            stylized_imgs, t = self.model(model_inputs)
            loss = self.criterion(stylized_imgs, style_batch, t)

        weights = self.model.trainable_weights
        gradients = tape.gradient(loss, weights)
        self.optimizer.apply_gradients(zip(gradients, weights))

        # log and save every 200 batches
        if step % 200 == 0:
            print(f'Training loss (for one batch) at step {step}: {loss}')
            print(f'Seen so far: {(step+1)*self.batch_size} samples')
            self.model.save_weights(f'./checkpoints/adain_e{step}.ckpt')

        step += 1

    print("Finished training...")
    self.model.save_weights('saved_model/adain_weights.h5')
I can't figure out why it's doing that. The NaN gradients do not appear when _compute_mean_std calculates the mean/std as if the features were laid out as N x C x H x W, but that is not what I want either. Adding a transpose so that the statistics are calculated in the correct layout brings the NaN gradients back.
It is possible to do it this way:
[ Sample ]:
import os
from os.path import exists
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
None
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Require at least one GPU and let TensorFlow grow its memory use on demand
# instead of reserving the whole device up front.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert physical_devices, "Not enough GPU hardware devices available"

# set_memory_growth() has no return value, so `config` is None; it is printed
# only to reproduce the sample output shown in the banner above.
config = tf.config.experimental.set_memory_growth(physical_devices[0], True)
print(physical_devices, config, sep='\n')
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Variables
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Number of training steps and batch sizes used by the generator and the loop.
num_iter = 1000
train_generator_batch_size = 1
batch_size = 1

# Input image geometry.
WIDTH = 256
HEIGHT = 256
CHANNEL = 3

# Checkpoints live in a folder named after this script (file name up to the
# first dot).
checkpoint_path = "F:\\models\\checkpoint\\" + os.path.basename(__file__).split('.')[0] + "\\TF_DataSets_01.h5"
checkpoint_dir = os.path.dirname(checkpoint_path)

if not exists(checkpoint_dir):
    # makedirs rather than mkdir: the per-script folder is nested inside other
    # folders that may not exist yet, and mkdir raises FileNotFoundError when
    # a parent is missing.
    os.makedirs(checkpoint_dir, exist_ok=True)
    print("Create directory: " + checkpoint_dir)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Definition / Class
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
def create_image_generator( ):
    """
    Build the augmented image iterator described by the Excel workbook.

    The workbook is read with every cell as ``str``; its ``Image`` column is
    passed as ``x_col`` and its ``Label`` column as ``y_col``. Images are
    rescaled by 1/255 and augmented with shear, zoom and horizontal flips.
    A 0.2 validation split is configured, but ``subset=None`` means no subset
    is selected here.

    Returns the iterator produced by ``flow_from_dataframe``.
    """
    sheet = pd.read_excel(
        'F:\\temp\\Python\\excel\\Book 7.xlsx',
        index_col=None,
        header=[0],
        dtype=str,
    )

    augmenter = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        validation_split=0.2,
    )

    flow_options = dict(
        dataframe=sheet,
        directory=None,
        x_col='Image',
        y_col='Label',
        weight_col=None,
        target_size=(WIDTH, HEIGHT),
        color_mode='rgb',
        classes=None,
        class_mode='categorical',
        batch_size=train_generator_batch_size,
        shuffle=True,
        seed=None,
        save_to_dir=None,
        save_prefix='',
        save_format='png',
        subset=None,
        interpolation='nearest',
        validate_filenames=True,
    )
    return augmenter.flow_from_dataframe(**flow_options)
class gradient_tape_optimizer:
    """
    Custom training loop for a classifier, driven by ``tf.GradientTape``.

    The loop reads the module-level ``WIDTH``, ``HEIGHT``, ``CHANNEL`` and
    ``checkpoint_path`` values defined by this script.
    """

    def __init__(self, model, num_iter, content_iter, batch_size):
        """
        model: Keras model to train; its outputs are treated as logits.
        num_iter: Number of optimisation steps to run (an integer).
        content_iter: Iterator yielding ``(images, labels)`` pairs.
        batch_size: Expected size of axis 1 of the image tensor in each batch.
        """
        self.num_iter = num_iter
        self.content_iter = content_iter
        # Same iterator object as content_iter; kept so the attribute that
        # existing code may read is still present.
        self.style_iter = content_iter
        self.batch_size = batch_size
        self.model = model
        # from_logits=True: the model built in this script ends in a Dense
        # layer with no activation, so its outputs are raw scores rather than
        # probabilities.
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.AUTO,
            name='sparse_categorical_crossentropy',
        )
        self.optimizer = tf.keras.optimizers.Nadam(
            learning_rate=0.00001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, name='Nadam'
        )

    def _compute_mean_std(self, feats : tf.Tensor, eps=1e-8):
        """
        Compute the mean and standard deviation over axes 1 and 2.

        feats: Features should be in shape N x H x W x C
        eps: Small constant added to the variance before the square root.

        Returns ``(mean, std)``, both reduced with ``keepdims=True``.
        """
        mean = tf.math.reduce_mean(feats, axis=[1, 2], keepdims=True)
        variance = tf.math.reduce_variance(feats, axis=[1, 2], keepdims=True)
        # eps goes inside the square root: the derivative of sqrt is infinite
        # at zero, so a constant channel (zero variance) would otherwise turn
        # every upstream gradient into NaN.
        std = tf.sqrt(variance + eps)
        return mean, std

    def criterion(self, stylized_img : tf.Tensor, style_img : tf.Tensor, t : tf.Tensor):
        """
        Content loss plus weighted style loss for style-transfer training.

        NOTE(review): this method reads ``self.mse_loss`` and
        ``self.style_weight`` and calls ``self.model.encode``; none of them is
        set up by ``__init__``, and ``train`` does not call this method.
        Provide them before using it.
        """
        stylized_content_feats = self.model.encode(stylized_img)
        stylized_feats = self.model.encode(stylized_img, return_all=True)
        style_feats = self.model.encode(style_img, return_all=True)
        content_loss = self.mse_loss(t, stylized_content_feats)
        style_loss = 0
        for f1, f2 in zip(stylized_feats, style_feats):
            m1, s1 = self._compute_mean_std(f1)
            m2, s2 = self._compute_mean_std(f2)
            style_loss += self.mse_loss(m1, m2) + self.mse_loss(s1, s2)
        return content_loss + self.style_weight * style_loss

    def _next_batch(self, iterator):
        """Fetch a batch, retrying once when axis 1 of the images is not batch_size."""
        batch = iterator.get_next()
        if batch[0].shape[1] != self.batch_size:
            batch = iterator.get_next()
        return batch

    def train(self):
        """
        Run ``self.num_iter`` gradient steps on ``self.model``.

        The forward pass and the loss are both computed from tensors inside
        the tape, and the gradient is taken with respect to the model's
        trainable weights, so the optimizer update actually changes the model.
        Every 200 steps the loss value is printed and the weights are saved to
        ``checkpoint_path``; the final weights are saved when the loop ends.

        Only one batch is drawn per step. ``style_iter`` is the same iterator
        as ``content_iter``, so drawing an extra, unused style batch would
        throw away every other sample.
        """
        for step in range(self.num_iter):
            content_batch = self._next_batch(self.content_iter)

            # Collapse any extra leading axes into a single batch axis.
            images = tf.reshape(content_batch[0], (-1, WIDTH, HEIGHT, CHANNEL))
            # Assumes one-hot labels (the generator in this script uses
            # class_mode='categorical'); the sparse loss needs class indices.
            one_hot = tf.reshape(content_batch[1], (tf.shape(images)[0], -1))
            labels = tf.argmax(one_hot, axis=-1)

            with tf.GradientTape() as tape:
                # No .numpy() / tf.constant round-trips in here: they detach
                # the values from the tape and leave nothing to differentiate.
                logits = self.model(images, training=True)
                loss_value = self.loss(labels, logits)  # (y_true, y_pred)

            weights = self.model.trainable_weights
            gradients = tape.gradient(loss_value, weights)
            self.optimizer.apply_gradients(zip(gradients, weights))

            # log and save every 200 batches
            if step % 200 == 0:
                print(f'Training loss (for one batch) at step {step}: {float(loss_value)}')
                print(f'Seen so far: {(step+1)*self.batch_size} samples')
                self.model.save_weights(checkpoint_path)

        print("Finished training...")
        self.model.save_weights(checkpoint_path)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Dataset
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Module-level copy of the workbook (create_image_generator reads it again
# on its own).
variables = pd.read_excel(
    'F:\\temp\\Python\\excel\\Book 7.xlsx',
    index_col=None,
    header=[0],
    dtype=str,
)

# Wrap the Keras iterator in a tf.data pipeline. Each generator element is an
# (images, labels) pair matching the signature below; .batch(1) then adds one
# more leading axis to both tensors.
train_image_ds = tf.data.Dataset.from_generator(
    create_image_generator,
    output_types=None,
    output_shapes=None,
    args=None,
    output_signature=(
        tf.TensorSpec(shape=(1, WIDTH, HEIGHT, CHANNEL), dtype=tf.float32, name=None),
        tf.TensorSpec(shape=(1, 2), dtype=tf.float32, name=None),
    ),
    name='train_image_ds',
).batch(1)

iterator = iter(train_image_ds)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Convolutional front end followed by two bidirectional LSTMs and a dense
# head producing two output values per image.
model = tf.keras.models.Sequential(layers=[
    tf.keras.layers.InputLayer(input_shape=(WIDTH, HEIGHT, CHANNEL)),

    # Two fixed (non-adapted) normalisation stages.
    tf.keras.layers.Normalization(mean=3., variance=2.),
    tf.keras.layers.Normalization(mean=4., variance=6.),

    # Feature extraction; the Reshape below expects 127 x 127 positions with
    # 128 features each at this point.
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Reshape((128, 127 * 127)),

    # Sequence model over the reshaped features.
    tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(96, return_sequences=True, return_state=False)
    ),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(96)),

    # Classification head; the final layer has no activation.
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(192, activation='relu'),
    tf.keras.layers.Dense(2),
])
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Nadam with a small learning rate; the remaining values are spelled out
# explicitly.
optimizer = tf.keras.optimizers.Nadam(
    learning_rate=1e-05,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    name='Nadam',
)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Loss Fn
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# from_logits=True because the model this loss is compiled with ends in a
# Dense(2) layer with no activation, i.e. it emits raw scores rather than
# probabilities.
# NOTE(review): this is the sparse variant, which expects integer class ids;
# the image generator in this script is configured for one-hot labels, so
# confirm the label format before using it with fit()/evaluate().
lossfn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
    reduction=tf.keras.losses.Reduction.AUTO,
    name='sparse_categorical_crossentropy',
)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Summary
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Attach the optimizer, loss and accuracy metric to the model.
model.compile(
    optimizer=optimizer,
    loss=lossfn,
    metrics=['accuracy'],
)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Training
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
# Build the trainer (this rebinds the class name to the instance, as before)
# and run the loop.
gradient_tape_optimizer = gradient_tape_optimizer(
    model=model,
    num_iter=num_iter,
    content_iter=iterator,
    batch_size=batch_size,
)
result = gradient_tape_optimizer.train()

# Keep the console window open until the user presses Enter.
input('...')
The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.