I have a Keras model with 1 input and 2 outputs in TensorFlow 2. When calling model.fit, I want to pass the dataset as x=train_dataset and call model.fit only once. The train_dataset is made with tf.data.Dataset.from_generator, which yields x1, y1, y2.
The only way I can run training is the following:
    for x1, y1, y2 in train_dataset:
        model.fit(x=x1, y=[y1, y2], ...)
How can I tell TensorFlow to unpack the variables and train without the explicit for loop? Using the for loop makes many things less practical, as does using train_on_batch.
If I run model.fit(train_dataset, ...), the function doesn't understand what x and y are, even though the model is defined like:
    model = Model(name='Joined_Model', inputs=self.x, outputs=[self.network.y1, self.network.y2])
It throws an error saying that it expects 2 targets but got 1, even though the dataset has 3 variables, which can be iterated through in the loop.
The dataset and mini-batch are generated as:
    def dataset_joined(self, n_epochs, buffer_size=32):
        dataset = tf.data.Dataset.from_generator(
            self.mbatch_gen_joined,
            (tf.float32, tf.float32, tf.int32),
            (tf.TensorShape([None, None, self.n_feat]),
             tf.TensorShape([None, None, self.n_feat]),
             tf.TensorShape([None, None])),
            [tf.constant(n_epochs)]
        )
        dataset = dataset.prefetch(buffer_size)
        return dataset

    def mbatch_gen_joined(self, n_epochs):
        for _ in range(n_epochs):
            random.shuffle(self.train_s_list)
            start_idx, end_idx = 0, self.mbatch_size
            for _ in range(self.n_iter):
                s_mbatch_list = self.train_s_list[start_idx:end_idx]
                d_mbatch_list = random.sample(self.train_d_list, end_idx - start_idx)
                s_mbatch, d_mbatch, s_mbatch_len, d_mbatch_len, snr_mbatch, label_mbatch, _ = \
                    self.wav_batch(s_mbatch_list, d_mbatch_list)
                x_STMS_mbatch, xi_bar_mbatch, _ = \
                    self.training_example(s_mbatch, d_mbatch, s_mbatch_len,
                                          d_mbatch_len, snr_mbatch)
                # seq_mask_mbatch = tf.cast(tf.sequence_mask(n_frames_mbatch), tf.float32)
                start_idx += self.mbatch_size
                end_idx += self.mbatch_size
                if end_idx > self.n_examples:
                    end_idx = self.n_examples
                yield x_STMS_mbatch, xi_bar_mbatch, label_mbatch
Keras models expect Python generators or tf.data.Dataset objects to provide each batch as a tuple of the form (input_data, target_data) (or (input_data, target_data, sample_weights)). A flat three-element tuple is therefore read as (input, target, sample_weights), which is why only one target is found. Each of input_data and target_data can, and should, be a list/tuple if the model has multiple input/output layers. Therefore, the data generated in your code should also follow this expected format:
    yield x_STMS_mbatch, (xi_bar_mbatch, label_mbatch)  # <- the second element is a tuple itself
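To make the difference concrete, here is a small pure-Python sketch of how a dataset element is split by its tuple length. The helpers unpack_element and count_targets are hypothetical stand-ins written for this illustration, not Keras functions, and plain strings stand in for the real arrays.

```python
def unpack_element(element):
    """Illustrative stand-in (not Keras code): split an element by tuple length."""
    if not isinstance(element, tuple):
        return element, None, None
    if len(element) == 1:
        return element[0], None, None
    if len(element) == 2:
        return element[0], element[1], None
    if len(element) == 3:
        return element[0], element[1], element[2]
    raise ValueError("expected a tuple of length 1, 2 or 3")


def count_targets(targets):
    """A tuple/list counts one target per item; anything else is a single target."""
    return len(targets) if isinstance(targets, (tuple, list)) else 1


# Strings stand in for the real arrays.
flat = ("x1", "y1", "y2")        # what the question's generator yields
nested = ("x1", ("y1", "y2"))    # what the corrected generator yields

_, flat_targets, flat_weights = unpack_element(flat)
_, nested_targets, nested_weights = unpack_element(nested)

print(count_targets(flat_targets), flat_weights)      # 1 y2
print(count_targets(nested_targets), nested_weights)  # 2 None
```

With the flat triple, y2 lands in the sample-weights slot and only one target remains, which matches the "expecting 2 targets while getting 1" error. With the nested form, both targets stay together.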
The same nesting also has to be reflected in the arguments passed to the from_generator method:
    dataset = tf.data.Dataset.from_generator(
        self.mbatch_gen_joined,
        output_types=(
            tf.float32,
            (tf.float32, tf.int32)
        ),
        output_shapes=(
            tf.TensorShape([None, None, self.n_feat]),
            (
                tf.TensorShape([None, None, self.n_feat]),
                tf.TensorShape([None, None])
            )
        ),
        args=(tf.constant(n_epochs),)
    )
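If changing the generator itself is inconvenient, a different route that should give the same structure is to leave it yielding a flat triple and regroup the elements afterwards with Dataset.map. The sketch below is only an illustration under assumptions: flat_gen and N_FEAT are made-up stand-ins for mbatch_gen_joined and self.n_feat, the arrays are zeros, and it uses the output_signature argument available in recent TensorFlow 2 releases instead of output_types/output_shapes.

```python
import numpy as np
import tensorflow as tf

N_FEAT = 4  # hypothetical feature size, standing in for self.n_feat


def flat_gen():
    """Toy generator yielding flat (x, y1, y2) triples, as in the question."""
    for _ in range(3):
        x = np.zeros((2, 5, N_FEAT), dtype=np.float32)
        y1 = np.zeros((2, 5, N_FEAT), dtype=np.float32)
        y2 = np.zeros((2, 5), dtype=np.int32)
        yield x, y1, y2


flat_ds = tf.data.Dataset.from_generator(
    flat_gen,
    output_signature=(
        tf.TensorSpec(shape=(None, None, N_FEAT), dtype=tf.float32),
        tf.TensorSpec(shape=(None, None, N_FEAT), dtype=tf.float32),
        tf.TensorSpec(shape=(None, None), dtype=tf.int32),
    ),
)

# Regroup each element into (input, (target_1, target_2)).
nested_ds = flat_ds.map(lambda x, y1, y2: (x, (y1, y2)))

# Inspect the resulting structure: a 2-tuple whose second item is a 2-tuple.
print(nested_ds.element_spec)
```

The resulting nested_ds has the (input_data, target_data) layout described above, so it is the kind of object that could be passed to model.fit in a single call.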
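Use yield (x1, [y1, y2]) so that model.fit understands your generator's output. If the generator is wrapped with tf.data.Dataset.from_generator, a tuple (y1, y2) is the safer choice, because tf.data generally treats a list as a single tensor-like value rather than as a nested structure.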