
NaN giving ValueError in OneHotEncoder in scikit-learn

Here is my code:

import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder

train = pd.DataFrame({
        'users':['John Johnson','John Smith','Mary Williams']
})
test = pd.DataFrame({
        'users':[None,np.nan,'John Smith','Mary Williams']
})

ohe = OneHotEncoder(sparse=False,handle_unknown='ignore')
ohe.fit(train)
train_transformed = ohe.fit_transform(train)

test_transformed = ohe.transform(test)
print(test_transformed)

I expected the OneHotEncoder to be able to handle the np.nan in the test dataset, since I set handle_unknown='ignore'.

But it raises a ValueError. It is able to handle the None value, though. Why is it failing? And how do I get around it (besides Imputer)?

From the documentation (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html), it seemed that this was what handle_unknown is for.

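You must impute missing values first (at least in scikit-learn versions before 0.24, where OneHotEncoder rejects NaN with a ValueError). handle_unknown='ignore' doesn't concern NaN values, only new categories that the encoder did not see during fit.

You can treat NaNs as a distinct category as follows: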

train = train.fillna("NaN")
test = test.fillna("NaN")
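A minimal runnable sketch of the fillna approach, assuming scikit-learn 0.20 or later (the first release whose OneHotEncoder accepts string categories). It calls .toarray() on the sparse result instead of passing sparse=False, because that keyword was renamed to sparse_output in 1.2 and later removed. The placeholder "NaN" is just an example label. Because it never appears in the training data, handle_unknown='ignore' encodes the two missing test rows as all zeros:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

train = pd.DataFrame({'users': ['John Johnson', 'John Smith', 'Mary Williams']})
test = pd.DataFrame({'users': [None, np.nan, 'John Smith', 'Mary Williams']})

# Replace both None and np.nan with an ordinary string label,
# so the encoder only ever sees strings.
train = train.fillna("NaN")
test = test.fillna("NaN")

ohe = OneHotEncoder(handle_unknown='ignore')
ohe.fit(train)

# "NaN" was not seen during fit, so handle_unknown='ignore'
# encodes those rows as all zeros instead of raising an error.
test_transformed = ohe.transform(test).toarray()
print(ohe.categories_)
print(test_transformed)
```

If the training data also contains missing values, "NaN" becomes a regular learned category with its own column.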

I don't know the purpose of the function, but running the code and looking at the result tells me the following:

The ValueError you are receiving suggests that the function expects numerical data rather than strings. (That was true of OneHotEncoder before scikit-learn 0.20; from 0.20 on it accepts string categories, and the error in the question comes from the missing-value check instead.)

Also note that the 'handle_unknown' flag does not mean that the function accepts None or NaN values; it controls how categories in the test data that were not present in the training data are handled (see example below).

The following code, which includes unknown categories, works:

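import pandas as pd
from sklearn.preprocessing import OneHotEncoder

train = pd.DataFrame({
        'users': [1, 2, 3, 4], 'users2': [1, 2, 3, 4]
})
# Test data uses the same column names as train (recent scikit-learn
# versions reject mismatched feature names). 0 and 10 in 'users' and
# 5 in 'users2' are categories that were never seen during fit.
test = pd.DataFrame({
        'users': [0, 1, 3, 4, 10], 'users2': [1, 2, 3, 4, 5]
})

ohe = OneHotEncoder(handle_unknown='ignore')
# fit_transform fits the encoder and encodes train in one step;
# .toarray() converts the sparse result to a dense array.
train_transformed = ohe.fit_transform(train).toarray()

# Unknown categories are encoded as all zeros instead of raising an error.
test_transformed = ohe.transform(test).toarray()
print(test_transformed)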

Hope that helps. Replacing the missing data works analogously to what the previous answer suggested.
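As a side note, assuming scikit-learn 0.24 or later: OneHotEncoder there accepts np.nan directly and learns it as a regular category during fit (placed last in categories_), so neither an imputer nor fillna is strictly required. A minimal sketch; the training data containing np.nan is illustrative, not from the question:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

# A missing value is present in the training data this time.
train = pd.DataFrame({'users': ['John Johnson', np.nan, 'Mary Williams']})
test = pd.DataFrame({'users': [np.nan, 'John Smith', 'Mary Williams']})

ohe = OneHotEncoder(handle_unknown='ignore')
ohe.fit(train)

# np.nan is learned as its own category, sorted after the strings.
print(ohe.categories_)

# 'John Smith' was not seen during fit, so its row is all zeros.
test_transformed = ohe.transform(test).toarray()
print(test_transformed)
```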
