How to “zip” a TensorFlow Dataset and train in Keras correctly?

I have a train_x.csv and a train_y.csv, and I'd like to train a model using the Dataset API and the Keras interface. This is what I'm trying to do:

import numpy as np
import pandas as pd
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x only; in TF 2.x eager execution is the default

N_FEATURES = 10
N_SAMPLES = 100
N_OUTPUTS = 2
BATCH_SIZE = 8
EPOCHS = 5

# prepare fake data
train_x = pd.DataFrame(np.random.rand(N_SAMPLES, N_FEATURES))
train_x.to_csv('train_x.csv', index=False)
train_y = pd.DataFrame(np.random.rand(N_SAMPLES, N_OUTPUTS))
train_y.to_csv('train_y.csv', index=False)

train_x = tf.data.experimental.CsvDataset('train_x.csv', [tf.float32] * N_FEATURES, header=True)
train_y = tf.data.experimental.CsvDataset('train_y.csv', [tf.float32] * N_OUTPUTS, header=True)
dataset = ...  # What to do here?

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(N_OUTPUTS, input_shape=(N_FEATURES,)),
    tf.keras.layers.Activation('linear'),
])
model.compile('sgd', 'mse')
model.fit(dataset, steps_per_epoch=N_SAMPLES/BATCH_SIZE, epochs=EPOCHS)

What's the right way to build this dataset?

I tried the Dataset.zip API, i.e. dataset = tf.data.Dataset.zip((train_x, train_y)), but it doesn't seem to work (code here and error here). I also read this answer; it works, but I'd like a way that doesn't require the functional model API.
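One way to see why the plain zip does not fit the model is to inspect a zipped element. The sketch below assumes TensorFlow 2.x (eager by default) and writes two tiny CSV files with the standard csv module; the helper name write_csv, the column names, and the file contents are made up for illustration.

```python
import csv
import os
import tempfile

import tensorflow as tf

N_FEATURES = 3
N_OUTPUTS = 2


def write_csv(path, n_cols, n_rows):
    """Hypothetical helper: a header row, then n_rows rows of floats."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([f"c{i}" for i in range(n_cols)])
        for r in range(n_rows):
            writer.writerow([float(r + c) for c in range(n_cols)])


tmp_dir = tempfile.mkdtemp()
x_path = os.path.join(tmp_dir, "train_x.csv")
y_path = os.path.join(tmp_dir, "train_y.csv")
write_csv(x_path, N_FEATURES, 4)
write_csv(y_path, N_OUTPUTS, 4)

train_x = tf.data.experimental.CsvDataset(x_path, [tf.float32] * N_FEATURES, header=True)
train_y = tf.data.experimental.CsvDataset(y_path, [tf.float32] * N_OUTPUTS, header=True)
zipped = tf.data.Dataset.zip((train_x, train_y))

# Each element is ((x0, x1, x2), (y0, y1)): one scalar tensor per CSV column,
# not a single feature vector of shape (N_FEATURES,).
first_x, first_y = next(iter(zipped))
print(len(first_x), [t.shape.rank for t in first_x])  # 3 [0, 0, 0]
print(len(first_y))                                   # 2
```

A Sequential model declared with input_shape=(N_FEATURES,) expects one tensor per example, so this tuple of scalars has to be reshaped before it reaches the model.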

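The problem is in the input shape of your dense layer. It should match the shape of a single input example (for one feature column, that is 1): tf.keras.layers.Dense(N_OUTPUTS, input_shape=(features_shape,)), where features_shape is the number of feature values per example. Note that CsvDataset yields each CSV column as a separate scalar tensor rather than one vector, so with N_FEATURES columns you need to stack them into a vector of shape (N_FEATURES,) before they reach the model.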
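To make the input match a model input of shape (N_FEATURES,), the scalar columns can be stacked into one vector per example. This is a minimal sketch assuming TensorFlow 2.x; it uses tf.data.Dataset.from_tensor_slices on a tuple of column arrays only to mimic the tuple-of-scalars structure a CsvDataset produces, without writing files. The model declares its input with tf.keras.Input, which is equivalent to input_shape on the first layer.

```python
import numpy as np
import tensorflow as tf

N_FEATURES = 10
N_OUTPUTS = 2

# A tuple of N_FEATURES column arrays: each dataset element is a tuple of
# scalars, the same structure a CsvDataset with N_FEATURES columns yields.
columns = tuple(np.random.rand(6).astype("float32") for _ in range(N_FEATURES))
ds = tf.data.Dataset.from_tensor_slices(columns)

# map() unpacks the tuple into positional arguments; stack them into one vector.
features = ds.map(lambda *cols: tf.stack(cols, axis=0))
print(features.element_spec.shape)  # (10,)

# A model whose input shape is (N_FEATURES,) now accepts the batched vectors.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(N_FEATURES,)),
    tf.keras.layers.Dense(N_OUTPUTS),
])
batch = next(iter(features.batch(6)))
out = model(batch)
print(batch.shape, out.shape)  # batch is (6, 10), output is (6, 2)
```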

You might also run into problems with the steps_per_epoch argument of model.fit(): it should be an int, but N_SAMPLES/BATCH_SIZE is a float (12.5) in Python 3. Convert it explicitly: model.fit(dataset, steps_per_epoch=int(N_SAMPLES/BATCH_SIZE), epochs=EPOCHS)
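The arithmetic behind this is worth spelling out. int() truncates, so with N_SAMPLES = 100 and BATCH_SIZE = 8 one "epoch" is 12 batches (96 samples). When the dataset is batched before it is repeated, each pass over the file yields 13 batches, the last one holding 4 samples, so rounding up keeps Keras epochs aligned with passes over the data. The helper steps_per_epoch_for below is hypothetical, a sketch of that calculation.

```python
import math


def steps_per_epoch_for(n_samples: int, batch_size: int, drop_remainder: bool = False) -> int:
    """Hypothetical helper: number of batches in one pass over the data."""
    if drop_remainder:
        # The final partial batch is discarded, so round down.
        return n_samples // batch_size
    # The final partial batch is kept, so round up.
    return math.ceil(n_samples / batch_size)


N_SAMPLES = 100
BATCH_SIZE = 8
print(N_SAMPLES / BATCH_SIZE)                            # 12.5 (a float in Python 3)
print(int(N_SAMPLES / BATCH_SIZE))                       # 12 (truncated)
print(steps_per_epoch_for(N_SAMPLES, BATCH_SIZE))        # 13
print(steps_per_epoch_for(N_SAMPLES, BATCH_SIZE, True))  # 12
```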

Edit 1: In case you need multiple labels (and multiple feature columns), stack each group of columns into a single tensor:

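def parse_f(data, labels):
    # CsvDataset yields one scalar tensor per column, so stack the feature
    # columns and the label columns into one vector each.
    return tf.stack(data, axis=0), tf.stack(labels, axis=0)

dataset = tf.data.Dataset.zip((train_x, train_y))  # elements: (features, labels)
dataset = dataset.map(parse_f)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat()  # repeat forever; steps_per_epoch defines an epoch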
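Putting the pieces together, here is an end-to-end sketch of the question's script, assuming TensorFlow 2.x (eager is the default, so there is no tf.enable_eager_execution() call). Choices specific to this sketch, not to the original answer: the fake CSV files are written with numpy.savetxt into a temporary directory instead of pandas, the input is declared with tf.keras.Input (newer Keras versions warn about input_shape on layers), and steps_per_epoch is rounded up.

```python
import math
import os
import tempfile

import numpy as np
import tensorflow as tf

N_FEATURES = 10
N_SAMPLES = 100
N_OUTPUTS = 2
BATCH_SIZE = 8
EPOCHS = 5

# Fake data with a header row, written to a temporary directory.
tmp_dir = tempfile.mkdtemp()
x_path = os.path.join(tmp_dir, "train_x.csv")
y_path = os.path.join(tmp_dir, "train_y.csv")
np.savetxt(x_path, np.random.rand(N_SAMPLES, N_FEATURES), fmt="%.6f", delimiter=",",
           header=",".join(str(i) for i in range(N_FEATURES)), comments="")
np.savetxt(y_path, np.random.rand(N_SAMPLES, N_OUTPUTS), fmt="%.6f", delimiter=",",
           header=",".join(str(i) for i in range(N_OUTPUTS)), comments="")

train_x = tf.data.experimental.CsvDataset(x_path, [tf.float32] * N_FEATURES, header=True)
train_y = tf.data.experimental.CsvDataset(y_path, [tf.float32] * N_OUTPUTS, header=True)


def parse_f(data, labels):
    # data: tuple of N_FEATURES scalars; labels: tuple of N_OUTPUTS scalars.
    return tf.stack(data, axis=0), tf.stack(labels, axis=0)


dataset = tf.data.Dataset.zip((train_x, train_y))
dataset = dataset.map(parse_f)       # -> ((N_FEATURES,), (N_OUTPUTS,))
dataset = dataset.batch(BATCH_SIZE)  # -> ((None, N_FEATURES), (None, N_OUTPUTS))
dataset = dataset.repeat()

model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(N_FEATURES,)),
    tf.keras.layers.Dense(N_OUTPUTS),
    tf.keras.layers.Activation("linear"),
])
model.compile("sgd", "mse")

steps = math.ceil(N_SAMPLES / BATCH_SIZE)  # 13 batches per pass, last one partial
history = model.fit(dataset, steps_per_epoch=steps, epochs=EPOCHS, verbose=0)
```

Because the model is still declared with Sequential, this keeps the non-functional style the question asks for; the only structural change is the map(parse_f) step between zip and batch.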
