
Scikit-learn pipeline returns list of zeroes

I don't understand why my pipeline produces the wrong output.

Pipeline code:

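from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

my_pipeline = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler())
])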

Real data:

real = [[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]

The pipeline output that I want:

want = [[-0.44228927, -0.4898311 , -1.37640684, -0.27288841, -0.34321545, 0.36524574, -0.33092752,  1.20235683, -1.0016859 ,  0.05733231, -1.21003475,  0.38110555, -0.57309194]]

But after running the code below:

getting = my_pipeline.fit_transform(real)

I am getting:

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

The problem

This is expected behavior: real is a nested list containing a single inner list, so scikit-learn treats it as one sample (row) with 13 features (columns).

After the first step of the pipeline, i.e. the SimpleImputer, the returned output is a NumPy array with shape (1, 13).

si = SimpleImputer()
si_out = si.fit_transform(real)

si_out.shape
# (1, 13)

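The returned (1, 13) array is the problem here. StandardScaler subtracts the mean of each column and divides by that column's standard deviation. It "sees" 13 columns with a single value each, so every column's mean equals that value and subtracting it leaves 0. (Each column's standard deviation is also 0; scikit-learn then uses a scale of 1 rather than dividing by zero.)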

sc = StandardScaler()
sc.fit_transform(si_out)

This returns:

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
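To make the cause visible, the fitted scaler's learned statistics can be inspected. This is a minimal sketch using the same real row as in the question; mean_, var_ and scale_ are the standard fitted attributes of StandardScaler, and the variable name out is only illustrative.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

real = [[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]

sc = StandardScaler()
out = sc.fit_transform(real)

# With a single row, each column's mean is simply that row's value.
print(np.allclose(sc.mean_, real[0]))

# Each column's variance is 0, so scikit-learn falls back to a scale of 1
# instead of dividing by zero.
print(sc.var_)
print(sc.scale_)

# (x - mean) / scale is therefore 0 in every column.
print(out)
```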

The solution

It seems that you have only one variable/feature named real. Just reshape it into a single column (one row per value) before fitting.

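import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# reshape(-1, 1) turns the 13 values into 13 samples of a single feature
real = np.array([[0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7, 6.2669, 1.0, 422.0, 15.9, 389.96, 8.65]]).reshape(-1,1)

my_pipeline = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler())
])
my_pipeline.fit_transform(real)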

array([[-0.48677709],
       [-0.4869504 ],
       [-0.47383804],
       [-0.4869504 ],
       [-0.48335664],
       [-0.44157747],
       [-0.07276633],
       [-0.44347217],
       [-0.48001264],
       [ 2.44078289],
       [-0.37664007],
       [ 2.21849716],
       [-0.4269388 ]])
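As a sanity check, the reshaped result can be reproduced by hand. The sketch below assumes only NumPy and StandardScaler; the names values, column, scaled and manual are illustrative. It relies on StandardScaler using the population standard deviation, which is also NumPy's default (ddof=0).

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

values = np.array([0.02498, 0.0, 1.89, 0.0, 0.518, 6.54, 59.7,
                   6.2669, 1.0, 422.0, 15.9, 389.96, 8.65])

# One column, 13 rows: the scaler now has 13 samples to compute statistics from.
column = values.reshape(-1, 1)
scaled = StandardScaler().fit_transform(column)

# Same computation written out: subtract the mean, divide by the std (ddof=0).
manual = (column - column.mean()) / column.std()

print(scaled.shape)
print(np.allclose(scaled, manual))
```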
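If real is instead a single new sample with 13 features (the desired output in the question is one row of 13 values, which suggests this), a common pattern is to fit the pipeline on the training data and then call transform, not fit_transform, on the new row. The sketch below is only an illustration under that assumption: X_train and new_row are hypothetical, randomly generated stand-ins, since the original training set is not shown.

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical stand-in for a real training set: 100 samples, 13 features.
rng = np.random.default_rng(0)
X_train = rng.normal(loc=10.0, scale=3.0, size=(100, 13))

my_pipeline = Pipeline(steps=[
    ('imputer', SimpleImputer(strategy='median')),
    ('std_scaler', StandardScaler())
])

# Learn the medians, means and standard deviations from the training data only.
my_pipeline.fit(X_train)

# Reuse those statistics on one new row; calling fit_transform here would
# refit on the single row and give all zeros again.
new_row = X_train[:1] + 1.0  # hypothetical new sample, shape (1, 13)
scaled_row = my_pipeline.transform(new_row)

print(scaled_row.shape)
```

With this pattern the new row is scaled relative to the training distribution, so the result is not forced to zero.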

