I am trying to implement normalizing flows embedded in a Keras model. In all the examples I can find, such as the documentation of MAF, the bijectors that constitute the normalizing flow are embedded in a TransformedDistribution and exposed directly for training, etc.
I am trying to embed this TransformedDistribution in a Keras Model to match the architecture of my other models, which inherit from Keras Model.
Unfortunately, all my attempts so far (see code) fail to transfer the trainable variables inside the transformed distribution to the Keras Model.
I have tried making the bijector inherit from tf.keras.layers.Layer, which did not change anything.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
class Flow(tfb.Bijector, tf.Module):
    """Placeholder elementwise bijector that owns one trainable variable.

    ``tf.Module`` is mixed in to register ``trainable_variables``. The
    forward and inverse maps are currently the identity; the variable ``u``
    is created but not yet used by the transformation.

    Args:
        d: Length of the trainable vector ``u`` (the flow maps R^d to R^d).
        init_sigma: Standard deviation of the normal initializer for ``u``.
        **kwargs: Forwarded to the ``tfb.Bijector`` constructor (e.g. ``name``).
    """

    def __init__(self, d, init_sigma=0.1, **kwargs):
        super(Flow, self).__init__(
            dtype=tf.float32,
            forward_min_event_ndims=0,
            inverse_min_event_ndims=0,
            **kwargs
        )
        # Shape of the flow goes from Rd to Rd
        self.d = d
        # Weights/Variables initializer
        self.init_sigma = init_sigma
        w_init = tf.random_normal_initializer(stddev=self.init_sigma)
        # Variables (created eagerly here, at construction time)
        self.u = tf.Variable(
            w_init(shape=[1, self.d], dtype=tf.float32),
            dtype=tf.float32,
            name='u',
            trainable=True,
        )

    def _forward(self, x):
        """Identity forward map."""
        return x

    def _inverse(self, y):
        """Identity inverse map."""
        return y

    def _forward_log_det_jacobian(self, x):
        """Return log|det J| of the forward map, elementwise.

        The identity map has a Jacobian determinant of 1, so the log is 0
        for every element. Without this method (or
        ``_inverse_log_det_jacobian``) TFP cannot evaluate
        ``TransformedDistribution.log_prob`` and raises NotImplementedError.
        """
        return tf.zeros_like(x)
class Flows(tf.keras.Model):
# Keras model wrapping a TransformedDistribution whose bijector is a tfb.Chain
# of `n_flows` Flow bijectors over a zero-mean MultivariateNormalDiag base.
# NOTE(review): this is the version the question reports as broken; the
# author observes that q.summary() shows no trainable params, even though
# self.flow.trainable_variables lists the Flow variables, i.e. Keras does not
# pick them up from the `self.flow` attribute on its own.
def __init__(self, d=2, shape=(100, 2), n_flows=10, ):
# d: dimensionality passed to every Flow; shape: shape of the base
# distribution's all-zero `loc`; n_flows: number of chained Flow bijectors.
super(Flows, self).__init__()
# Parameters
self.d = d
self.shape = shape
self.n_flows = n_flows
# Base distribution - MF = Multivariate normal diag
base_distribution = tfd.MultivariateNormalDiag(
loc=tf.zeros(shape=shape, dtype=tf.float32)
)
# Flows as chain of bijector
flows = []
for n in range(n_flows):
flows.append(Flow(self.d, name=f"flow_{n + 1}"))
# tfb.Chain applies its last element first, so reversing makes flow_1 the
# first bijector applied in the forward direction.
bijector = tfb.Chain(list(reversed(flows)))
self.flow = tfd.TransformedDistribution(
distribution=base_distribution,
bijector=bijector
)
def call(self, *inputs):
# Forward pass: push the inputs through the chained bijector's forward map.
return self.flow.bijector.forward(*inputs)
def log_prob(self, *inputs):
# Log-density of the inputs under the transformed distribution.
return self.flow.log_prob(*inputs)
def sample(self, num):
# Draw `num` samples from the transformed distribution.
return self.flow.sample(num)
q = Flows()
# Call the model once so Keras builds it (summary() requires a built model).
# The Flow variables themselves are already created in Flow.__init__.
q(tf.zeros(q.shape))
# summary() prints its table itself and returns None, so it is not wrapped
# in print(). Here it reports no trainable params.
q.summary()
# Prints expected trainable params
print(q.flow.trainable_variables)
Any idea if this is even possible? Thanks!
I bumped into this issue as well. It seems to be caused by incompatibilities between TFP and TF 2.0 (see a couple of relevant issues: https://github.com/tensorflow/probability/issues/355 and https://github.com/tensorflow/probability/issues/946).
As a workaround, you need to add the (trainable) variables of your transformed distribution / bijector as an attribute to your Keras Model:
class Flows(tf.keras.Model):
    """Keras model around a TransformedDistribution built from chained Flow bijectors.

    Args:
        d: Dimensionality handed to every ``Flow``.
        shape: Shape of the all-zero ``loc`` of the MultivariateNormalDiag base.
        n_flows: How many ``Flow`` bijectors are chained together.
    """

    def __init__(self, d=2, shape=(100, 2), n_flows=10):
        super(Flows, self).__init__()
        self.d = d
        self.shape = shape
        self.n_flows = n_flows

        base = tfd.MultivariateNormalDiag(
            loc=tf.zeros(shape=shape, dtype=tf.float32)
        )

        # Create flow_1 .. flow_n in that order, then reverse them for Chain,
        # which applies its last element first: flow_1 runs first going forward.
        stages = [Flow(self.d, name=f"flow_{i}") for i in range(1, n_flows + 1)]
        self.flow = tfd.TransformedDistribution(
            distribution=base,
            bijector=tfb.Chain(stages[::-1]),
        )

        # Workaround for https://github.com/tensorflow/probability/issues/355
        # and https://github.com/tensorflow/probability/issues/946: Keras does not
        # collect the distribution's variables by itself, so store them on the
        # model as an attribute to get them tracked as trainable weights.
        self._variables = self.flow.variables

    def call(self, *inputs):
        """Map inputs through the chained bijector's forward transformation."""
        return self.flow.bijector.forward(*inputs)

    def log_prob(self, *inputs):
        """Evaluate the log-density of inputs under the transformed distribution."""
        return self.flow.log_prob(*inputs)

    def sample(self, num):
        """Draw ``num`` samples from the transformed distribution."""
        return self.flow.sample(num)
After adding this, your model should have trainable variables and weights to optimize.
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.