繁体   English   中英

使 ML model scikit-learn 兼容

[英]Making a ML model scikit-learn compatible

我想让这个 ML model scikit-learn 兼容: https://github.com/manifoldai/merf

为此,我按照此处的说明进行操作:https://danielhnyk.cz/creating-your-own-estimator-scikit-learn/from sklearn.base import BaseEstimator, RegressorMixin并像这样从它们继承: class MERF(BaseEstimator, RegressorMixin):

但是,当我检查 scikit-learn 兼容性时:

from sklearn.utils.estimator_checks import check_estimator

import merf
check_estimator(merf)

我收到此错误:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 500, in check_estimator
    for estimator, check in checks_generator:
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in _generate_instance_checks
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in <genexpr>
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 232, in _yield_all_checks
    tags = estimator._get_tags()
AttributeError: module 'merf' has no attribute '_get_tags'

如何使这个 model scikit-learn 兼容?

文档中, check_estimator用于“检查估计器是否符合 scikit-learn 约定。”

此估算器将运行一个广泛的测试套件,用于输入验证、形状等,确保估算器符合滚动您自己的估算器中详述的 scikit-learn 约定。 如果 Estimator class 继承自 sklearn.base 的相应 mixin,则将运行分类器、回归器、聚类或转换器的附加测试。

所以check_estimator不仅仅是一个兼容性检查,它还检查你是否遵循所有的约定等。

您可以阅读滚动您自己的估算器以确保您遵循约定。

然后你需要传递你的估计器 class 的实例来检查像check_estimator(MERF())这样的估计器。 要真正让它遵循所有约定,您必须解决它抛出的每个错误并一个一个地修复它们。

例如,一项这样的检查是__init__方法只设置它接受作为参数的那些属性。

MERF class 违反:

    def __init__(
        self,
        fixed_effects_model=RandomForestRegressor(n_estimators=300, n_jobs=-1),
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

        self.cluster_counts = None
        # Note fixed_effects_model must already be instantiated when passed in.
        self.fe_model = fixed_effects_model
        self.trained_fe_model = None
        self.trained_b = None

        self.b_hat_history = []
        self.sigma2_hat_history = []
        self.D_hat_history = []
        self.gll_history = []
        self.val_loss_history = []

它正在设置诸如self.b_hat_history类的属性,即使它们不是参数。

还有很多其他类似的检查。

我个人的建议是,除非必要,否则不要检查所有这些条件,只需继承 Mixins 和 Base 类,实现所需的方法并使用 model。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM