Tensorflow model.fit ValueError

I'm trying to train a model using TensorFlow. I'm reading a huge CSV file using tf.data.experimental.make_csv_dataset. Here's my code:

Imports

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing

LABEL_COLUMN = 'venda_qtde'

Reading a CSV file into a tf.data.Dataset

def get_dataset(file_path, **kwargs):
  dataset = tf.data.experimental.make_csv_dataset(
      file_path,
      batch_size=4096, 
      na_value="?",
      label_name=LABEL_COLUMN,
      num_epochs=1,
      ignore_errors=False,
      shuffle=False,
      **kwargs)
  return dataset

Building a model instance

def build_model():
  model = None
  model = keras.Sequential([
    layers.Dense(520, activation='relu'),
    layers.Dense(520, activation='relu'),
    layers.Dense(520, activation='relu'),
    layers.Dense(1)
  ])

  model.compile(loss='mean_squared_error',
                optimizer='adam',
                metrics=['mae'])
  return model

Executing the functions:

ds_treino = get_dataset('data/processed/curva_a/curva_a_train.csv')
nn_model = build_model()
nn_model.fit(ds_treino, epochs=10)

But when fit() is called, I get this error:

ValueError: Layer sequential_5 expects 1 inputs, but it received 520 input tensors. Inputs received: ...

My dataset has 519 features, 1 label, and about 17M rows. Can anyone help me figure out what I'm doing wrong?

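make_csv_dataset returns a tf.data.Dataset whose elements are (features, label) pairs, where features is a dictionary (an OrderedDict) mapping each column name to its own tensor. Keras therefore sees one input per column, while your Sequential model expects exactly one input tensor, hence the error. Inspecting the dataset (called train_dataset here, ds_treino in your code) shows the dictionary: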

train_dataset
<PrefetchDataset shapes: (OrderedDict([... features ... ])
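To make that structure concrete without TensorFlow, here is a minimal stdlib-only sketch. The helper `mock_csv_batches` is hypothetical: it only mimics how make_csv_dataset groups a CSV into batches of (OrderedDict of column name → batch of values, batch of labels), with the label column removed from the dictionary. The column names and values are made up, and every value is assumed to be numeric.

```python
import csv
import io
from collections import OrderedDict


def mock_csv_batches(csv_text, label_name, batch_size):
    """Hypothetical stand-in for the element structure of make_csv_dataset.

    Yields (features, labels): features is an OrderedDict mapping each
    non-label column name to a list of up to batch_size values, and labels
    is a list of label values. Assumes all values parse as numbers.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    rows = list(reader)
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        # One entry per feature column: this is why Keras sees many inputs.
        features = OrderedDict(
            (name, [float(row[name]) for row in batch])
            for name in reader.fieldnames
            if name != label_name
        )
        labels = [float(row[label_name]) for row in batch]
        yield features, labels


# Hypothetical 3-feature CSV using the question's label column name.
CSV_TEXT = """f1,f2,f3,venda_qtde
1,2,3,10
4,5,6,20
7,8,9,30
"""

features, labels = next(mock_csv_batches(CSV_TEXT, "venda_qtde", batch_size=2))
print(list(features.keys()))  # ['f1', 'f2', 'f3']
print(labels)                 # [10.0, 20.0]
# Passed as-is, each dictionary entry would be treated as a separate input.
print(len(features))          # 3
```

With 519 or more feature columns, the same structure produces hundreds of separate input tensors, which a single-input Sequential model rejects.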

You need to combine the feature dictionary into a single tensor so that each element is a plain (features, labels) pair. You can use:

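def features_and_labels(features, labels):
  # Stack the per-column tensors into one (batch_size, num_features) tensor.
  # Casting first avoids a dtype mismatch when make_csv_dataset infers some
  # columns as int32 and others as float32 (assumes all columns are numeric).
  features = tf.stack([tf.cast(v, tf.float32) for v in features.values()], axis=1)
  return features, labels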

train_dataset = train_dataset.map(features_and_labels)
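The effect of tf.stack(..., axis=1) can be illustrated without TensorFlow. In this stdlib-only sketch (the helper name `stack_columns` is hypothetical), each dictionary value is one column of length batch_size; stacking along axis 1 turns the columns into rows, giving a batch_size × num_features matrix, i.e. the single 2-D input a Sequential model starting with a Dense layer expects. As with the cast in the function above, values are converted to float first.

```python
from collections import OrderedDict


def stack_columns(features):
    """Mimic tf.stack(list(features.values()), axis=1) on plain lists.

    features: OrderedDict of column name -> list of batch_size numbers.
    Returns a list of rows, i.e. shape (batch_size, num_features).
    """
    columns = [[float(v) for v in col] for col in features.values()]
    lengths = {len(col) for col in columns}
    if len(lengths) != 1:
        raise ValueError("all columns must have the same batch size")
    # zip(*columns) walks the columns in parallel, producing one row per example.
    return [list(row) for row in zip(*columns)]


def features_and_labels(features, labels):
    """Plain-Python analogue of the map function above."""
    return stack_columns(features), labels


batch = OrderedDict([("f1", [1, 4]), ("f2", [2, 5]), ("f3", [3.5, 6])])
x, y = features_and_labels(batch, [10, 20])
print(x)                  # [[1.0, 2.0, 3.5], [4.0, 5.0, 6.0]]
print(len(x), len(x[0]))  # 2 3, i.e. (batch_size, num_features)
```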

train_dataset
<MapDataset shapes: ((None, 10), (None,)), types: (tf.float32, tf.int32)>

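The shapes shown above come from a dataset with 10 feature columns; with your data, the second dimension of the features will equal your number of feature columns. After this mapping, you should be able to pass the dataset directly to fit().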