
ValueError while training with TensorFlow when running model.evaluate()

Code

All four columns are float64. I'm not sure what to do about this error. I've looked at similar Stack Overflow questions, and none of the suggestions about converting the data to a NumPy array and/or to float32 seem to solve it.

This code is based on a Google Colab notebook.

I just replaced the housing data with MLB pitching data and applied dropna() to the train and test DataFrames.
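You can check the "all columns are float64" claim column by column. Below is a minimal sketch that assumes pandas and NumPy and uses made-up column names (ERA, WHIP) in place of the real MLB data. It flags any column whose `np.array(...)` conversion is not a floating dtype, which is the same conversion `train_model()` performs. Note that `dropna()` removes rows with missing values but does not change a column's dtype, so a numeric-looking column stored as `object` stays `object`.

```python
import numpy as np
import pandas as pd

def find_non_float_columns(df: pd.DataFrame) -> list:
    """Return names of columns whose NumPy conversion is not a float dtype."""
    bad = []
    for name, value in df.items():
        arr = np.array(value)
        # Object-dtype arrays (e.g. floats stored as Python objects) are not floating.
        if not np.issubdtype(arr.dtype, np.floating):
            bad.append(name)
    return bad

# Hypothetical data: "ERA" holds numbers but is stored with object dtype.
train_df = pd.DataFrame({
    "ERA": pd.Series([3.21, 4.05, None], dtype=object),
    "WHIP": [1.10, 1.32, 1.25],
})
train_df = train_df.dropna()  # drops the row with None; dtypes are unchanged
print(train_df.dtypes)
print(find_non_float_columns(train_df))  # ['ERA']
```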

Thank you.

In the function train_model(), convert each array with .astype(float).

So the line you need to edit is:

features = {name:np.array(value) for name, value in dataset.items()}

Change it to:

features = {name:np.array(value).astype(float) for name, value in dataset.items()}

You'll need to do the same in the next block:

test_features = {name:np.array(value).astype(float) for name, value in cut_test_df_norm.items()}
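Why `.astype(float)` can help even when the values look like floats: if a pandas column has `object` dtype, `np.array(value)` produces an object array, not a float array. The sketch below uses a hypothetical column to show the difference with pandas and NumPy only (no TensorFlow). It then applies the same dict comprehension as the answer. If a column contains strings that are not numbers, `.astype(float)` raises a `ValueError`, so the fix only works for genuinely numeric data.

```python
import numpy as np
import pandas as pd

# Hypothetical column: numeric values stored with object dtype.
column = pd.Series([0.5, 1.25, 2.0], dtype=object)

raw = np.array(column)
print(raw.dtype)  # object

converted = np.array(column).astype(float)
print(converted.dtype)  # float64

# The same conversion applied to every column, as in the answer's dict comprehension.
dataset = pd.DataFrame({"x": column, "y": [1.0, 2.0, 3.0]})
features = {name: np.array(value).astype(float) for name, value in dataset.items()}
print({name: arr.dtype for name, arr in features.items()})
```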

Full function code:

import numpy as np
import pandas as pd

def train_model(model, dataset, epochs, label_name,
                batch_size=None):
  """Train the model by feeding it data."""

  # Split the dataset into features and label.
  # .astype(float) turns object-dtype columns into float64 arrays.
  features = {name:np.array(value).astype(float) for name, value in dataset.items()}
  label = np.array(features.pop(label_name))
  history = model.fit(x=features, y=label, batch_size=batch_size,
                      epochs=epochs, shuffle=True)

  print(features)

  # The list of epochs is stored separately from the rest of history.
  epochs = history.epoch

  # To track the progression of training, gather a snapshot
  # of the model's mean squared error at each epoch.
  hist = pd.DataFrame(history.history)
  mse = hist["mean_squared_error"]

  return epochs, mse

And:

# After building a model against the training set, test that model
# against the test set.
test_features = {name:np.array(value).astype(float) for name, value in cut_test_df_norm.items()}
test_label = np.array(test_features.pop(label_name)) # isolate the label
print("\n Evaluate the new model against the test set:")
my_model.evaluate(x = test_features, y = test_label, batch_size=batch_size)
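Because the same conversion is needed for both training and evaluation, one option is to put it in a single helper so the two code paths cannot drift apart. This is a minimal sketch. `split_features_and_label` is a hypothetical name that does not appear in the Colab, and the DataFrame below is made-up data standing in for the normalized train/test frames.

```python
import numpy as np
import pandas as pd

def split_features_and_label(df, label_name):
    """Hypothetical helper: build a float64 feature dict and a label array.

    Mirrors the two snippets above so train and test data get the same conversion.
    """
    features = {name: np.array(value).astype(float) for name, value in df.items()}
    label = features.pop(label_name)  # already a float64 array
    return features, label

# Made-up data; "feature_a" is numeric but stored with object dtype.
df = pd.DataFrame({"feature_a": pd.Series([1.0, 2.0], dtype=object),
                   "target": [0.1, 0.2]})
features, label = split_features_and_label(df, "target")
print(sorted(features))  # ['feature_a']
print(label.dtype)       # float64
```

The returned pair could then be passed as `x=features, y=label` to `model.fit()` or `model.evaluate()`, just as in the answer's code.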


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM