繁体   English   中英

Scikit-Learn SVR预测始终具有相同的价值

[英]Scikit-Learn SVR Prediction Always Gives the Same Value

我即将使用Scikit-Learn中的支持向量回归来预测IMDB得分(电影费率)。 问题是它总是为每个输入提供相同的预测结果。

当我预测使用数据训练时,它会给出各种结果。 但是在使用数据测试时,它总是给出相同的值。

数据培训预测:

数据培训预测http://image.prntscr.com/image/647c2a2db7c3419dbad364b340e2a49c.png

数据测试预测:

数据测试预测http://image.prntscr.com/image/d4e2c8ff5d6447cfb73e888e79a897a5.png

以下是数据集的链接: IMDB 5000 Movie Dataset

我的代码:

import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import numpy as np
import seaborn as sb
from sklearn import metrics as met


df = pd.read_csv("movie_metadata.csv")
df.head()


original = df.shape[0]
df = df.drop_duplicates(["movie_title"])
notDuplicated = df.shape[0]
df.reset_index(drop = True, inplace = True)
print(original, notDuplicated)


df["num_critic_for_reviews"].fillna(0, inplace = True)
df["num_critic_for_reviews"] = df["num_critic_for_reviews"].astype("int")

df["director_facebook_likes"].fillna(0, inplace = True)
df["director_facebook_likes"] = df["director_facebook_likes"].astype("int")

df["actor_3_facebook_likes"].fillna(0, inplace = True)
df["actor_3_facebook_likes"] = df["actor_3_facebook_likes"].astype(np.int64)

df["actor_2_facebook_likes"].fillna(0, inplace = True)
df["actor_2_facebook_likes"] = df["actor_2_facebook_likes"].astype(np.int64)

df["actor_1_facebook_likes"].fillna(0, inplace = True)
df["actor_1_facebook_likes"] = df["actor_1_facebook_likes"].astype(np.int64)

df["movie_facebook_likes"].fillna(0, inplace = True)
df["movie_facebook_likes"] = df["movie_facebook_likes"].astype(np.int64)

df["content_rating"].fillna("Not Rated", inplace = True)
df["content_rating"].replace('-', "Not Rated", inplace = True)
df["content_rating"] = df["content_rating"].astype("str")

df["imdb_score"].fillna(0.0, inplace = True)

df["title_year"].fillna(0, inplace = True)
df["title_year"].replace("NA", 0, inplace = True)
df["title_year"] = df["title_year"].astype("int")

df["genres"].fillna("", inplace = True)
df["genres"] = df["genres"].astype("str")


df2 = df[df["title_year"] >= 1980]
df2.reset_index(drop = True, inplace = True)

nRow = len(df2)
print("Number of data:", nRow)
nTrain = np.int64(np.floor(0.7 * nRow))
nTest = nRow - nTrain
print("Number of data training (70%):", nTrain, "\nNumber of data testing (30%):", nTest)

dataTraining = df2[0:nTrain]
dataTesting = df2[nTrain:nRow]
dataTraining.reset_index(drop = True, inplace = True)
dataTesting.reset_index(drop = True, inplace = True)


xTrain = dataTraining[["num_critic_for_reviews", "director_facebook_likes", "actor_3_facebook_likes", "actor_2_facebook_likes", "actor_1_facebook_likes", "movie_facebook_likes"]]
yTrain = dataTraining["imdb_score"]

xTest = dataTesting[["num_critic_for_reviews", "director_facebook_likes", "actor_3_facebook_likes", "actor_2_facebook_likes", "actor_1_facebook_likes", "movie_facebook_likes"]]
yTest = dataTesting["imdb_score"]

movieTitle = dataTesting["movie_title"].reset_index(drop = True)


from sklearn.svm import SVR

svrModel = SVR(kernel = "rbf", C = 1e3, gamma = 0.1, epsilon = 0.1)
svrModel.fit(xTrain,yTrain)


predicted = svrModel.predict(xTest)
[print(movieTitle[i], ":", predicted[i]) for i in range(10)]

gamma0.1更改为1e-8同时保持其他所有内容相同。

当gamma设置为0.1时,唯一预测的数量为8,它们都接近6.37。 当gamma设置为1e-8时,输出1366个唯一预测(xTest包含1368个总样本)。

为什么伽玛很重要?

直观地,伽马参数定义了单个训练样例的影响达到了多远,低值意味着“远”,高值意味着“接近”。 伽马参数可以被视为由模型选择的样本作为支持向量的影响半径的倒数。

RBF SVM参数有更深入的解释和示例。

这里也有类似的解释: 多类分类中Scikit SVM的输出总是给出相同的标签

就个人而言,我会在脚本的底部使用GridSearchCV。 以下是查找理想gamma值和C值的示例:

from sklearn.svm import SVR
from sklearn.model_selection import GridSearchCV

#svrModel = SVR(kernel = "rbf", C = 1e3, gamma = 1e-8, epsilon = 0.1)
#svrModel.fit(xTrain,yTrain)


#predicted = svrModel.predict(xTest)
#[print(movieTitle[i], ":", predicted[i]) for i in range(10)]

#print('Unique predictions:', np.unique(predicted))

parameters = {
    "kernel": ["rbf"],
    "C": [1,10,10,100,1000],
    "gamma": [1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
    }

grid = GridSearchCV(SVR(), parameters, cv=5, verbose=2)
grid.fit(xTrain, yTrain)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM