Am I missing something? Error with simple classifier input function in TensorFlow

I've been following along with the freeCodeCamp tutorial on TensorFlow, and I've tried to modify its basic classifier to handle a structured dataset of my own.

I have a training dataset and a testing dataset, each containing a mix of integer and string columns. I'm trying to predict the value in the Allocated column, but this error is thrown every time classifier.train is called:

UnimplementedError: Cast string to float is not supported
     [[{{node head/losses/Cast}}]]

During handling of the above exception, another exception occurred:

UnimplementedError                        Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
   1392                     '\nsession_config.graph_options.rewrite_options.'
   1393                     'disable_meta_optimizer = True')
-> 1394       raise type(e)(node_def, op, message)  # pylint: disable=no-value-for-parameter
   1395 
   1396   def _extend_graph(self):

UnimplementedError: Cast string to float is not supported
     [[node head/losses/Cast (defined at /usr/local/lib/python3.7/dist-packages/tensorflow_estimator/python/estimator/head/binary_class_head.py:255) ]]

I've tried converting the dataset so that all of the values are integers or floats, but I keep getting the same error. As far as I can tell, the classifier should be able to operate on different data types, so I can't see why that would be the issue, unless I need to define them somewhere?

I know the data is being read in correctly, because calling .head() shows it properly formatted. I've been stuck on this error for days and can't figure out what I'm missing. Any help would be greatly appreciated. My code is below.

%tensorflow_version 2.x 

from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from six.moves import urllib
import tensorflow.compat.v2.feature_column as fc
import tensorflow as tf

CSV_COLUMN_NAMES = ['GroupNumber', 'GroupUnit', 'GroupSkill1', 'GroupSkill2', 'GroupSkill3', 'GroupSkill4', 'GroupPreference1', 
                'GroupPreference2', 'GroupPreference3', 'ProjectNumber', 'ProjectUnit', 'ProjectSkill1', 'ProjectSkill2', 'ProjectSkill3', 'ProjectSkill4', 'ProjectPreference1', 'ProjectPreference2', 'ProjectPreference3', 'Allocated']
ALLOCATED = [0, 1]

train = pd.read_csv('https://raw.githubusercontent.com/nickjackson862/machine-learning/main/trainData40_10.csv', names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv('https://raw.githubusercontent.com/nickjackson862/machine-learning/main/testData40_10.csv', names=CSV_COLUMN_NAMES, header=0)

train_y = train.pop('Allocated')
test_y = test.pop('Allocated')
train.head()


def input_fn(features, labels, training=True, batch_size=100):   
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))    
    if training:
        dataset = dataset.shuffle(10).repeat()    
    return dataset.batch(batch_size)

my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[50, 20],
    n_classes=2)

classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=100)

eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

I found the problem. It's in the line where you create your feature columns:

my_feature_columns.append(tf.feature_column.numeric_column(key=key))

You are making every feature a numeric feature, but looking at your dataset, several of your fields are strings (by the way, that CSV file is public; you may want to remedy that).
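The .head() output can look perfectly formatted even when some columns hold text, because it shows values, not types. Checking the column dtypes makes the string fields obvious. A minimal sketch on a small hypothetical DataFrame standing in for the CSV (the column values are made up for illustration):

```python
import pandas as pd

# Hypothetical stand-in for the training CSV; the values are made up.
df = pd.DataFrame({
    "GroupNumber": [1, 2, 3],
    "GroupSkill1": ["python", "sql", "java"],
    "ProjectUnit": [10, 20, 10],
})

# dtypes shows what .head() does not: which columns are actually numeric.
print(df.dtypes)

# Columns that numeric_column() can handle, and the ones it cannot.
numeric_columns = df.select_dtypes(include="number").columns.tolist()
non_numeric_columns = [c for c in df.columns if c not in numeric_columns]
print(non_numeric_columns)  # ['GroupSkill1']
```

Selecting on "number" rather than on "object" keeps the check working whether pandas stores the text columns as object or as its newer string dtype. On the real training data, this is a quick way to confirm which columns need something other than numeric_column.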

> I've tried converting the dataset so that all of the values are integers or floats, but I keep getting the same error.

I believe you did this incorrectly. I just ran your code with all of the string-type columns removed, and it ran successfully with no errors. All I did was add the following lines after reading in the CSVs:

# String-valued columns that numeric_column() cannot parse
STRING_COLUMNS = ['GroupSkill1', 'GroupSkill2', 'GroupSkill3', 'GroupSkill4', 'ProjectSkill1', 'ProjectSkill2', 'ProjectSkill3', 'ProjectSkill4']
train.drop(columns=STRING_COLUMNS, inplace=True)
test.drop(columns=STRING_COLUMNS, inplace=True)
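Dropping the skill columns discards their information. If you would rather keep them as numbers, the string-to-integer conversion has to give each string the same code in both the training and test sets. One way it can go wrong is encoding each file separately (for example with pd.factorize on each one), which numbers values by order of appearance, so the same skill can get different codes in each file. A minimal sketch, assuming the text columns have no missing values; encode_consistently is a hypothetical helper, and the DataFrames are made-up examples:

```python
import pandas as pd


def encode_consistently(train, test, columns):
    """Return copies of train/test where each listed string column is replaced
    by integer codes drawn from one vocabulary shared by both DataFrames."""
    train, test = train.copy(), test.copy()
    for col in columns:
        # One sorted vocabulary built from both files, so a given string
        # always maps to the same code. Assumes no missing values.
        vocab = sorted(set(train[col]) | set(test[col]))
        train[col] = pd.Categorical(train[col], categories=vocab).codes
        test[col] = pd.Categorical(test[col], categories=vocab).codes
    return train, test


# Hypothetical example data.
train_demo = pd.DataFrame({"GroupSkill1": ["sql", "python", "java"], "GroupNumber": [1, 2, 3]})
test_demo = pd.DataFrame({"GroupSkill1": ["python", "rust"], "GroupNumber": [4, 5]})

train_enc, test_enc = encode_consistently(train_demo, test_demo, ["GroupSkill1"])
print(train_enc["GroupSkill1"].tolist())  # [3, 1, 0]
print(test_enc["GroupSkill1"].tolist())   # [1, 2]
```

Note that a single integer per skill, fed through numeric_column, implies an ordering (java < python < rust < sql) that has no real meaning; the one-hot feature columns described in the TensorFlow article avoid that.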

Check out this article for advice on creating feature columns for your non-numeric data: https://www.tensorflow.org/tutorials/structured_data/feature_columns
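Along the lines of that article, one way to replace the all-numeric_column loop in the question is to choose the column type from each pandas dtype: numeric columns stay numeric_column, while string columns become a categorical_column_with_vocabulary_list wrapped in an indicator_column (one-hot), since DNNClassifier only accepts dense columns. A minimal sketch, assuming the string columns have no missing values and taking the vocabulary from the training set only; split_columns and build_feature_columns are hypothetical helper names:

```python
import pandas as pd


def split_columns(df):
    """Return (numeric column names, {string column name: sorted vocabulary})."""
    numeric = df.select_dtypes(include="number").columns.tolist()
    vocabularies = {
        col: sorted(set(df[col]))  # assumes no missing values in text columns
        for col in df.columns
        if col not in numeric
    }
    return numeric, vocabularies


def build_feature_columns(df):
    """Numeric columns stay numeric; string columns become one-hot indicators."""
    # Imported here so split_columns() can be used without TensorFlow installed.
    import tensorflow as tf

    numeric, vocabularies = split_columns(df)
    columns = [tf.feature_column.numeric_column(key=col) for col in numeric]
    for col, vocab in vocabularies.items():
        categorical = tf.feature_column.categorical_column_with_vocabulary_list(
            key=col, vocabulary_list=vocab)
        # DNNClassifier only accepts dense columns, so wrap the categorical
        # column in a one-hot indicator column.
        columns.append(tf.feature_column.indicator_column(categorical))
    return columns
```

With this, the loop in the question would become my_feature_columns = build_feature_columns(train), called after train.pop('Allocated'), and input_fn can stay as it is. If the test file contains skills that never occur in training, you may want to build the vocabulary from both files or pass num_oov_buckets to categorical_column_with_vocabulary_list.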
