
Loading saved model using Pickle - getting error as fit_transform is done in loaded program

I created a first program to train the algorithm and save it.

Program 1

import numpy as np
import pandas as pd
import pickle
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.impute import SimpleImputer
from sklearn.tree import DecisionTreeRegressor  # import for Decision Tree Algorithm
from sklearn.preprocessing import StandardScaler

SourceData = pd.read_excel("ASML Stock Predict.xlsx")  # Load the data into a pandas DataFrame
SourceData["Nasdaq Category"] = pd.cut(SourceData["Adj Close Nasdaq 100"],
                                       bins=[0., 4500, 5500, 6500, 7500, 8500, 9500, 10500, np.inf],
                                       labels=[1, 2, 3, 4, 5, 6, 7, 8])

""" Split the data source into stratified train and test subsets """
split = StratifiedShuffleSplit(n_splits=1, test_size=0.01, random_state=42)
for train_index, test_index in split.split(SourceData, SourceData["Nasdaq Category"]):
    strat_train_set = SourceData.loc[train_index]  # stratified train set with all columns of the source data
    strat_test_set = SourceData.loc[test_index]    # stratified test set with all columns of the source data

""" Drop the helper Nasdaq Category column once the train and test subsets are prepared """
for set_ in (strat_train_set, strat_test_set):
    set_.drop("Nasdaq Category", axis=1, inplace=True)

DataSource_train_independent = strat_train_set.drop(["Date", "Adj Close ASML"], axis=1)  # drop the date and the dependent variable from the training set
DataSource_train_dependent = strat_train_set["Adj Close ASML"].copy()  # new Series holding only the dependent variable of the training set



imputer = SimpleImputer(strategy="median")  # declare an imputer that fills blank values with the median of each variable
imputer.fit(DataSource_train_independent)    # calculate the median of each independent variable (transform is never called, so blanks are not filled)

""" Scale the independent variables of the training set. The dependent variable does not need scaling """
sc_X = StandardScaler()
X = sc_X.fit_transform(DataSource_train_independent.values)  # learn the mean/std and scale the independent variables
##sc_y = StandardScaler()
y = DataSource_train_dependent  # scaling is not required for the dependent variable


"""Decision Tree Regressor """
tree_reg = DecisionTreeRegressor()
tree_reg.fit(X, y)

filename = 'DecisionTree_TrainedModel.sav'
with open(filename, 'wb') as f:
    pickle.dump(tree_reg, f)  # saves only the fitted tree, not sc_X

Program 2

import pickle
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor  # import for Decision Tree Algorithm

testdata=pd.read_excel("ASML Test  Stock Predict.xlsx") # Load the test data

sc_X = StandardScaler()
X_test=sc_X.transform(testdata.values) # scale the independent variables for test data



with open('DecisionTree_TrainedModel.sav', 'rb') as f:
    loaded_model = pickle.load(f)
decision_predictions = loaded_model.predict(X_test)  # Predict the value of the dependent variable
print("The prediction by Decision Tree model is ", decision_predictions)

Since I call "fit_transform" in Program 1 and save the model there, in Program 2 I only call "transform" on the independent variables after loading the model.

Running the second program fails with the error message "sklearn.exceptions.NotFittedError: This StandardScaler instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator."
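A minimal sketch that reproduces this error: a freshly constructed StandardScaler has not learned any mean or standard deviation yet, so calling transform on it raises NotFittedError. The small array below is hypothetical and only stands in for testdata.values.

```python
import numpy as np
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler

# Hypothetical test features standing in for testdata.values
X_test = np.array([[1.0, 2.0], [3.0, 4.0]])

sc_X = StandardScaler()  # brand-new instance: no mean_ / scale_ learned yet
try:
    sc_X.transform(X_test)
    error_raised = False
except NotFittedError as err:
    print("NotFittedError:", err)
    error_raised = True
```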

Please advise. As I understand it, I only need to transform the test independent variables, not fit them.
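Transforming without fitting only works if the scaler already holds statistics learned from the training data. As a sketch with hypothetical arrays: StandardScaler.fit stores a per-column mean_ and scale_ (the standard deviation), and transform then applies (x - mean_) / scale_ using those stored values.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical training and test features
X_train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
X_test = np.array([[4.0, 40.0]])

sc = StandardScaler().fit(X_train)  # learns mean_ and scale_ from the training data only

scaled = sc.transform(X_test)                 # uses the stored training statistics
manual = (X_test - sc.mean_) / sc.scale_      # the same computation written out by hand
```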

You have to pickle the trained StandardScaler as well:

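import pickle
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor

# train and pickle (Program 1)
sc = StandardScaler()
X = sc.fit_transform(DataSource_train_independent.values)

tree_reg = DecisionTreeRegressor()
tree_reg.fit(X, y)

with open('StandardScaler.pk', 'wb') as f:
    pickle.dump(sc, f)        # the fitted scaler keeps the training mean and std
with open('DecisionTree.pk', 'wb') as f:
    pickle.dump(tree_reg, f)

# load and predict (Program 2)
with open('StandardScaler.pk', 'rb') as f:
    sc = pickle.load(f)
with open('DecisionTree.pk', 'rb') as f:
    model = pickle.load(f)

X_test = sc.transform(testdata.values)  # transform only, using the training statistics
predictions = model.predict(X_test)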
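A self-contained sketch of the same round trip, assuming synthetic data in place of the Excel files and a temporary directory in place of the working directory. It checks that the reloaded scaler and model reproduce exactly what the in-memory objects would predict.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X_train = rng.normal(size=(50, 3))                              # hypothetical training features
y_train = X_train[:, 0] * 2.0 + rng.normal(scale=0.1, size=50)  # hypothetical target
X_new = rng.normal(size=(5, 3))                                 # hypothetical stand-in for testdata.values

# "Program 1": fit, then persist BOTH fitted objects
sc = StandardScaler()
X = sc.fit_transform(X_train)
tree_reg = DecisionTreeRegressor(random_state=42).fit(X, y_train)

out_dir = tempfile.mkdtemp()
scaler_path = os.path.join(out_dir, "StandardScaler.pk")
model_path = os.path.join(out_dir, "DecisionTree.pk")
with open(scaler_path, "wb") as f:
    pickle.dump(sc, f)
with open(model_path, "wb") as f:
    pickle.dump(tree_reg, f)

# "Program 2": reload and call transform only (no fit)
with open(scaler_path, "rb") as f:
    loaded_sc = pickle.load(f)
with open(model_path, "rb") as f:
    loaded_model = pickle.load(f)

predictions = loaded_model.predict(loaded_sc.transform(X_new))
```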

A better approach is to wrap all the steps in a single pipeline:

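from sklearn.pipeline import Pipeline

pipeline = Pipeline(steps=[('sc', StandardScaler()),
                           ('tree_reg', DecisionTreeRegressor())])

pipeline.fit(DataSource_train_independent.values, y)  # raw features; the pipeline scales them itself
pipeline.predict(testdata.values)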
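The pipeline can then be pickled as one object, so Program 2 loads a single file and never refits anything. The sketch below also puts the question's SimpleImputer in front as an extra step, which the answer itself does not include; the synthetic data, the file name DecisionTree_Pipeline.pk, and random_state are hypothetical.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(1)
X_train = rng.normal(size=(40, 3))  # hypothetical raw training features
X_train[::7, 1] = np.nan            # a few blanks for the imputer to fill
y_train = rng.normal(size=40)
X_new = rng.normal(size=(4, 3))     # hypothetical raw test features
X_new[0, 2] = np.nan

pipeline = Pipeline(steps=[
    ("imputer", SimpleImputer(strategy="median")),  # medians learned from training data only
    ("sc", StandardScaler()),                       # mean/std learned from imputed training data
    ("tree_reg", DecisionTreeRegressor(random_state=42)),
])
pipeline.fit(X_train, y_train)  # pass raw (unscaled) features

path = os.path.join(tempfile.mkdtemp(), "DecisionTree_Pipeline.pk")
with open(path, "wb") as f:
    pickle.dump(pipeline, f)  # one file holds the medians, the scaler statistics and the tree

with open(path, "rb") as f:
    loaded = pickle.load(f)
predictions = loaded.predict(X_new)  # imputes, scales, then predicts, with no refitting
```

Because fit and predict run through the same steps in the same order, the imputer medians and scaler statistics used at prediction time are always the ones learned from the training data.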

