sklearn mean_squared_error ignores the squared argument with multioutput='raw_values'
The documentation page for sklearn's mean squared error function gives some examples of how to use it, including how to apply it to multioutput data and how to calculate the RMSE. The problem is that this does not work when calculating the RMSE on multiple outputs.
Here is the code I used:
from sklearn.metrics import mean_squared_error
y_true = [[0.5, 1],[-1, 1],[7, -6]]
y_pred = [[0, 2],[-1, 2],[8, -5]]
mean_squared_error(y_true, y_pred) # This returns the MSE
#out: 0.7083333333333334
mean_squared_error(y_true, y_pred, squared=False) # And the RMSE works too
#out: 0.8416254115301732
mean_squared_error(y_true, y_pred, multioutput='raw_values') # I can use the MSE for multiple outputs
#out: array([0.41666667, 1. ])
mean_squared_error(y_true, y_pred, multioutput='raw_values', squared=False) # But not the RMSE
#out: array([0.41666667, 1. ])
# However
import numpy as np
np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values')) # Numpy gives the correct results
#out: array([0.64549722, 1. ])
Some specifications:
Python 3.6.8 (default, Oct 7 2019, 12:59:55)
[GCC 8.3.0] on linux
sklearn.__version__
'0.22'
np.__version__
'1.17.4'
I looked at the source code, but I don't see why this does not work.
This is a known, now-closed issue that does not occur in sklearn 0.23.2, the current version as of this answer.
This is not reproducible with numpy 1.19.1 and sklearn 0.23.2:
mean_squared_error(y_true, y_pred, multioutput='raw_values', squared=False)
and
np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values'))
return the same value.
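One way to check a given installation without relying on the squared argument at all is to compute the per-output RMSE directly with NumPy and compare it against what sklearn returns. This is only a minimal sketch; `rmse_per_output` is a hypothetical helper name, not part of sklearn, and it ignores sample weights.

```python
import numpy as np

def rmse_per_output(y_true, y_pred):
    """Root mean squared error computed separately for each output column."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Average the squared errors over samples (axis 0), then take the root per column.
    return np.sqrt(np.mean((y_true - y_pred) ** 2, axis=0))

# Data from the question
y_true = [[0.5, 1], [-1, 1], [7, -6]]
y_pred = [[0, 2], [-1, 2], [8, -5]]
rmse = rmse_per_output(y_true, y_pred)
print(rmse)
```

On the question's data this should agree with the np.sqrt(...) result shown there, so any mismatch with squared=False points at the installed sklearn version.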
The resolution is to upgrade.
If upgrading is not an option:
return output_errors
→ return output_errors if squared else np.sqrt(output_errors)
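The one-line change above suggests why squared was ignored: with multioutput='raw_values' the function appears to return the per-output errors before squared is ever consulted. The following is a simplified illustration of that control flow, not the actual sklearn source. The name `mse_sketch` and its `patched` flag are hypothetical, and input validation and sample weights are left out.

```python
import numpy as np

def mse_sketch(y_true, y_pred, multioutput="uniform_average",
               squared=True, patched=True):
    """Simplified stand-in for mean_squared_error (illustration only)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Per-output MSE: average squared error over the samples.
    output_errors = np.mean((y_true - y_pred) ** 2, axis=0)
    if multioutput == "raw_values":
        if not patched:
            # Assumed old behaviour: early return, `squared` is never looked at.
            return output_errors
        # Behaviour after the one-line change.
        return output_errors if squared else np.sqrt(output_errors)
    # Uniform average over outputs; `squared` is honoured on this path.
    mse = np.mean(output_errors)
    return mse if squared else np.sqrt(mse)

y_true = [[0.5, 1], [-1, 1], [7, -6]]
y_pred = [[0, 2], [-1, 2], [8, -5]]
buggy = mse_sketch(y_true, y_pred, "raw_values", squared=False, patched=False)
fixed = mse_sketch(y_true, y_pred, "raw_values", squared=False, patched=True)
print(buggy, fixed)
```

With `patched=False` the sketch reproduces the MSE values reported in the question for squared=False, and with `patched=True` it matches the np.sqrt(...) result.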