
Linear Regression model scatter plot

I have the following dataset in a pandas DataFrame:

This is the code I use to train a model on the dataset:

import matplotlib.pyplot as plt
from sklearn import linear_model
from sklearn.model_selection import train_test_split

# df: pandas DataFrame holding the dataset (loaded earlier, not shown)
# Replace the weather labels with codes (written as the strings '1' and '2')
for index in df.index:
    if df.loc[index, 'Weather_Condition'] == 'Fog':
        df.loc[index, 'Weather_Condition'] = '1'
    elif df.loc[index, 'Weather_Condition'] == 'Fair':
        df.loc[index, 'Weather_Condition'] = '2'

# Keep only the rows that were recoded
df2 = df.loc[(df["Weather_Condition"] == '1') | (df["Weather_Condition"] == '2')]

X = df2[['Weather_Condition']]
Y = df2[['Severity']]

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.20, random_state=40)

regr = linear_model.LinearRegression()
regr.fit(X_train, Y_train)
predicted = regr.predict(X_test)

plt.scatter(X_train, Y_train, color='red')
plt.plot(X_train, predicted, color='blue')

plt.show()

First, I take the weather conditions and replace their values with numeric codes ('1' for Fog and '2' for Fair, stored as strings).
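The loop writes the codes as the strings '1' and '2', so the column stays non-numeric (object dtype). That is why matplotlib later goes down its category-handling path (category.py in the traceback). Below is a minimal vectorized sketch using `Series.map` that keeps the codes as integers. The sample rows, the extra "Rain" label and the new column name `Weather_Code` are hypothetical.

```python
import pandas as pd

# Hypothetical sample rows; the real df has more rows and columns
df = pd.DataFrame({
    "Weather_Condition": ["Fog", "Fair", "Rain", "Fair", "Fog"],
    "Severity": [2, 3, 2, 2, 4],
})

# Map each label to an integer code; labels missing from the dict become NaN
codes = {"Fog": 1, "Fair": 2}
df["Weather_Code"] = df["Weather_Condition"].map(codes)

# Keep only the rows that received a code, then store the codes as integers
df2 = df.dropna(subset=["Weather_Code"]).astype({"Weather_Code": int})

print(df2)
print(df2.dtypes)
```

Because `map` returns NaN for unmapped labels, the same step both encodes and filters. This replaces the separate `df.loc[...]` filter.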

When I plot, I get:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-54-b6586f8d0c2f> in <module>
     18 
     19 
---> 20 plt.scatter(X_train, Y_train, color = 'red')
     21 plt.plot(X_train, predicted, color = 'blue')
     22 

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/pyplot.py in scatter(x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, plotnonfinite, data, **kwargs)
   2888         verts=cbook.deprecation._deprecated_parameter,
   2889         edgecolors=None, *, plotnonfinite=False, data=None, **kwargs):
-> 2890     __ret = gca().scatter(
   2891         x, y, s=s, c=c, marker=marker, cmap=cmap, norm=norm,
   2892         vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths,

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
   1445     def inner(ax, *args, data=None, **kwargs):
   1446         if data is None:
-> 1447             return func(ax, *map(sanitize_sequence, args), **kwargs)
   1448 
   1449         bound = new_sig.bind(ax, *args, **kwargs)

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
    409                          else deprecation_addendum,
    410                 **kwargs)
--> 411         return func(*inner_args, **inner_kwargs)
    412 
    413     return wrapper

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/axes/_axes.py in scatter(self, x, y, s, c, marker, cmap, norm, vmin, vmax, alpha, linewidths, verts, edgecolors, plotnonfinite, **kwargs)
   4430         # Process **kwargs to handle aliases, conflicts with explicit kwargs:
   4431 
-> 4432         self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
   4433         x = self.convert_xunits(x)
   4434         y = self.convert_yunits(y)

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/axes/_base.py in _process_unit_info(self, xdata, ydata, kwargs)
   2187             return kwargs
   2188 
-> 2189         kwargs = _process_single_axis(xdata, self.xaxis, 'xunits', kwargs)
   2190         kwargs = _process_single_axis(ydata, self.yaxis, 'yunits', kwargs)
   2191         return kwargs

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/axes/_base.py in _process_single_axis(data, axis, unit_name, kwargs)
   2170                 # We only need to update if there is nothing set yet.
   2171                 if not axis.have_units():
-> 2172                     axis.update_units(data)
   2173 
   2174             # Check for units in the kwargs, and if present update axis

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/axis.py in update_units(self, data)
   1464         neednew = self.converter != converter
   1465         self.converter = converter
-> 1466         default = self.converter.default_units(data, self)
   1467         if default is not None and self.units is None:
   1468             self.set_units(default)

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/category.py in default_units(data, axis)
    105         # the conversion call stack is default_units -> axis_info -> convert
    106         if axis.units is None:
--> 107             axis.set_units(UnitData(data))
    108         else:
    109             axis.units.update(data)

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/category.py in __init__(self, data)
    174         self._counter = itertools.count()
    175         if data is not None:
--> 176             self.update(data)
    177 
    178     @staticmethod

/opt/anaconda3/lib/python3.8/site-packages/matplotlib/category.py in update(self, data)
    207         # check if convertible to number:
    208         convertible = True
--> 209         for val in OrderedDict.fromkeys(data):
    210             # OrderedDict just iterates over unique values in data.
    211             cbook._check_isinstance((str, bytes), value=val)

TypeError: unhashable type: 'numpy.ndarray'

Any suggestions on how to fix the error? I checked X.shape and Y.shape, and both are the same. I also tried removing the double brackets in the Y = df2[['Severity']] step to change the shape, but the same error persists.

Edit: Added pictures of X, Y and their shapes.
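The last frames of the traceback point at the cause. matplotlib's category converter (category.py) is handling the x data, most likely because the values are the strings '1' and '2'. It then runs `OrderedDict.fromkeys(data)` while the data is still two-dimensional. The failing call is `_process_single_axis(xdata, ...)`, so reshaping Severity (the y data) cannot help.

The sketch below reproduces just that step with plain NumPy, without matplotlib. The array `x_2d` is a hypothetical stand-in for `X_train`.

```python
from collections import OrderedDict

import numpy as np

# String-valued column with shape (n, 1), like df2[['Weather_Condition']]
x_2d = np.array([["1"], ["2"], ["1"]], dtype=object)

# Iterating over a 2-D array yields its rows, each a 1-D ndarray
first_row = next(iter(x_2d))
print(type(first_row), first_row.shape)

# OrderedDict.fromkeys hashes each item, and ndarrays are unhashable
try:
    OrderedDict.fromkeys(x_2d)
except TypeError as err:
    print("TypeError:", err)

# After flattening to shape (n,), the items are plain strings and hash fine
x_1d = x_2d.reshape(-1)
print(list(OrderedDict.fromkeys(x_1d)))
```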

I assume you have DataFrames like the following, each with a shape of the form (n, 1):

import numpy as np
import pandas as pd

# Random stand-in data: integer codes 0-2, shape (10, 1)
X = pd.DataFrame(index=np.arange(10), data=np.random.randint(0, 3, size=10))
print(X.shape)
X


Y = pd.DataFrame(index=np.arange(10), data=np.random.normal(0,1, size=10))
print(Y.shape)
Y


To plot them, you need to do the following:

(1) Transform them into one-dimensional NumPy arrays:

X = X.to_numpy().reshape(-1)
Y = Y.to_numpy().reshape(-1)
print(X)
print(Y)

Output:

[0 1 2 0 1 0 0 0 2 1]
[ 0.55487511  0.98492395  0.40978917  0.71796476  1.49174604  0.23563932
  0.36504642 -0.6218809   0.90840444  0.1369467 ]

(2) Now you can plot them:

plt.scatter(X, Y)
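The question's goal is a scatter plot with the fitted regression line. Here is a minimal sketch of that last step, under three assumptions:

- `df2` is replaced by hypothetical random data with integer weather codes (1 = Fog, 2 = Fair).
- The line is fitted with `numpy.polyfit` at degree 1, which is ordinary least squares for one feature. It stands in for scikit-learn's `LinearRegression`.
- All rows are used instead of a train/test split.

In the question's code, `predicted` is computed from `X_test`, so it has to be plotted against `X_test` rather than `X_train`. This sketch avoids that mismatch by evaluating the line on its own x grid.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical stand-in for df2: integer weather codes and severities
rng = np.random.default_rng(0)
df2 = pd.DataFrame({
    "Weather_Condition": rng.integers(1, 3, size=50),  # 1 = Fog, 2 = Fair
    "Severity": rng.integers(1, 5, size=50),
})

# Select single columns as 1-D arrays instead of (n, 1) DataFrames
x = df2["Weather_Condition"].to_numpy(dtype=float)
y = df2["Severity"].to_numpy(dtype=float)

# Least-squares line y = slope * x + intercept (degree-1 polynomial fit)
slope, intercept = np.polyfit(x, y, deg=1)

# Evaluate the line at the two ends of the x range so plt.plot draws one segment
x_line = np.linspace(x.min(), x.max(), 2)
y_line = slope * x_line + intercept

plt.scatter(x, y, color="red")
plt.plot(x_line, y_line, color="blue")
plt.show()
```

The weather code takes only the values 1 and 2, so the least-squares line passes through the mean Severity of each group. The slope is therefore the difference between those two means.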

