Making an ML model scikit-learn compatible

I want to make this ML model scikit-learn compatible: https://github.com/manifoldai/merf

To do that, I followed the instructions at https://danielhnyk.cz/creating-your-own-estimator-scikit-learn/, imported the base classes and inherited from them like so:

    from sklearn.base import BaseEstimator, RegressorMixin

    class MERF(BaseEstimator, RegressorMixin):
        ...
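For reference, here is roughly what the two base classes contribute, sketched on a hypothetical toy regressor (LstsqRegressor is not part of merf or scikit-learn). BaseEstimator derives get_params and set_params from the __init__ signature, and RegressorMixin adds score (R²) and marks the class as a regressor, which is what makes check_estimator run its regressor-specific tests. The sketch lists the mixin first, because recent scikit-learn documentation recommends placing mixins to the left of BaseEstimator (the MERF declaration above uses the opposite order).

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, is_regressor
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class LstsqRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical toy regressor: ordinary least squares via numpy.linalg.lstsq."""

    def __init__(self, fit_intercept=True):
        # __init__ only stores its arguments, unchanged and under the same names.
        self.fit_intercept = fit_intercept

    def _design(self, X):
        # Prepend a column of ones when an intercept is requested.
        if self.fit_intercept:
            return np.hstack([np.ones((X.shape[0], 1)), X])
        return X

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        # Learned state is created in fit() and ends with an underscore.
        self.n_features_in_ = X.shape[1]
        self.coef_, *_ = np.linalg.lstsq(self._design(X), y, rcond=None)
        return self

    def predict(self, X):
        check_is_fitted(self)  # raises NotFittedError if called before fit()
        return self._design(check_array(X)) @ self.coef_


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * X.ravel() + 1.0

model = LstsqRegressor().fit(X, y)
print(model.get_params())   # {'fit_intercept': True}, provided by BaseEstimator
print(model.score(X, y))    # R^2 provided by RegressorMixin, ~1.0 on this exact line
print(is_regressor(model))  # True
```

None of the parameter plumbing is written by hand: it all comes from following the __init__ convention and inheriting the two classes.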

However, when I check for scikit-learn compatibility:

    from sklearn.utils.estimator_checks import check_estimator

    import merf
    check_estimator(merf)

I get this error:

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 500, in check_estimator
        for estimator, check in checks_generator:
      File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in _generate_instance_checks
        yield from ((estimator, partial(check, name))
      File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in <genexpr>
        yield from ((estimator, partial(check, name))
      File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 232, in _yield_all_checks
        tags = estimator._get_tags()
    AttributeError: module 'merf' has no attribute '_get_tags'

How do I make this model scikit-learn compatible?

From the docs, check_estimator is used to "Check if estimator adheres to scikit-learn conventions."

This estimator will run an extensive test-suite for input validation, shapes, etc, making sure that the estimator complies with scikit-learn conventions as detailed in Rolling your own estimator. Additional tests for classifiers, regressors, clustering or transformers will be run if the Estimator class inherits from the corresponding mixin from sklearn.base.

So check_estimator is more than a compatibility check: it also verifies that you follow all of scikit-learn's conventions.

You can read the "Rolling your own estimator" section of the scikit-learn docs to make sure you follow those conventions.

You also need to pass an instance of your estimator class, not the module: check_estimator(MERF()) rather than check_estimator(merf). That is what the traceback above shows: merf is a module, so it has no _get_tags method. To actually make the estimator follow all the conventions, you then have to fix every error check_estimator raises, one by one.
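A small wrapper makes that fix-one-error-at-a-time loop concrete. By default check_estimator raises on the first failing check, so catching the exception tells you what to fix next. The helper name first_failure is hypothetical; LinearRegression is used only as a known-good reference, since scikit-learn's own estimators are expected to pass their checks.

```python
from sklearn.linear_model import LinearRegression
from sklearn.utils.estimator_checks import check_estimator


def first_failure(estimator):
    """Hypothetical helper: run scikit-learn's checks on an estimator *instance*
    and return a short description of the first failure, or None if all pass."""
    try:
        check_estimator(estimator)
    except Exception as exc:  # failing checks raise AssertionError, ValueError, etc.
        return f"{type(exc).__name__}: {exc}"
    return None


# Pass an instance such as MERF(), never the merf module or the MERF class itself.
print(first_failure(LinearRegression()))  # a compliant estimator yields None
```

Calling first_failure(MERF()) after each fix would then surface the next violation.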

For example, one such check is that the __init__ method should only set the attributes it accepts as parameters.
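The idea behind this check can be illustrated with a simplified helper. attributes_set_in_init is hypothetical, not a scikit-learn function; the real check (check_no_attributes_set_in_init in sklearn.utils.estimator_checks) handles more cases. The sketch constructs the class with default arguments and reports every public attribute that is not a constructor parameter.

```python
import inspect


def attributes_set_in_init(estimator_cls):
    """Hypothetical, simplified illustration: return public attributes that
    __init__ sets but that are not constructor parameters."""
    params = set(inspect.signature(estimator_cls.__init__).parameters) - {"self"}
    instance = estimator_cls()  # construct with default arguments
    public_attrs = {name for name in vars(instance) if not name.startswith("_")}
    return public_attrs - params


class Compliant:
    def __init__(self, max_iterations=20):
        self.max_iterations = max_iterations  # a parameter, stored as-is


class Violating:
    def __init__(self, max_iterations=20):
        self.max_iterations = max_iterations
        self.gll_history = []  # state that belongs in fit(), not in __init__


print(attributes_set_in_init(Compliant))  # set()
print(attributes_set_in_init(Violating))  # {'gll_history'}
```

The Violating class mirrors the pattern in MERF's constructor.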

The MERF class violates this:

    def __init__(
        self,
        fixed_effects_model=RandomForestRegressor(n_estimators=300, n_jobs=-1),
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

        self.cluster_counts = None
        # Note fixed_effects_model must already be instantiated when passed in.
        self.fe_model = fixed_effects_model
        self.trained_fe_model = None
        self.trained_b = None

        self.b_hat_history = []
        self.sigma2_hat_history = []
        self.D_hat_history = []
        self.gll_history = []
        self.val_loss_history = []

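It sets attributes such as self.b_hat_history even though they are not parameters. It also stores the fixed_effects_model parameter under a different name (self.fe_model), which breaks the related convention that each parameter is stored unchanged under its own name, so BaseEstimator.get_params() cannot find it.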
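One way to restructure such a constructor, sketched on a hypothetical, heavily simplified MERFLike class (not the merf package's code; it keeps only the fixed-effects model and omits the EM iterations): store every parameter unchanged under its own name, and use None as the default instead of a pre-built forest, because a default like RandomForestRegressor(...) in the signature is created once, when the function is defined, and shared by every instance relying on it. All learned state is created in fit with a trailing underscore, and clone is used so fitting never mutates the model object the user passed in.

```python
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.ensemble import RandomForestRegressor
from sklearn.utils.validation import check_is_fitted


class MERFLike(RegressorMixin, BaseEstimator):
    """Hypothetical simplified sketch, not the merf package's implementation."""

    def __init__(self, fixed_effects_model=None, gll_early_stop_threshold=None,
                 max_iterations=20):
        # Store every parameter unchanged, under exactly its own name.
        self.fixed_effects_model = fixed_effects_model
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

    def fit(self, X, y):
        # Build the default model here, so instances never share one forest.
        if self.fixed_effects_model is None:
            base = RandomForestRegressor(n_estimators=300, n_jobs=-1)
        else:
            base = self.fixed_effects_model
        # Learned state is created in fit() and ends with an underscore;
        # clone() leaves the user's fixed_effects_model object untouched.
        self.trained_fe_model_ = clone(base).fit(X, y)
        self.gll_history_ = []  # a real EM loop would append to this per iteration
        return self

    def predict(self, X):
        check_is_fitted(self)
        return self.trained_fe_model_.predict(X)
```

With this layout, get_params() returns fixed_effects_model directly, and a freshly constructed instance carries nothing but its three parameters.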

There are lots of other checks like this.

My personal advice: unless you really need full compliance, don't try to satisfy every one of these checks. Just inherit from BaseEstimator and the appropriate mixin, implement the required methods (such as fit and predict), and use the model.
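To illustrate that advice: even without passing every check, an estimator that inherits the base classes and implements fit and predict already works with tools such as cross_val_score and GridSearchCV, which rely on get_params, set_params and clone. ShrinkageRegressor below is a hypothetical closed-form ridge regressor (no intercept, no input validation) written only for this sketch, and the data is synthetic.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import GridSearchCV, cross_val_score


class ShrinkageRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical closed-form ridge regression without intercept."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def fit(self, X, y):
        X, y = np.asarray(X, dtype=float), np.asarray(y, dtype=float)
        n_features = X.shape[1]
        # Solve (X^T X + alpha * I) w = X^T y for the coefficients w.
        self.coef_ = np.linalg.solve(X.T @ X + self.alpha * np.eye(n_features), X.T @ y)
        return self

    def predict(self, X):
        return np.asarray(X, dtype=float) @ self.coef_


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = X @ np.array([1.5, -2.0, 0.5]) + 0.1 * rng.normal(size=60)

# cross_val_score clones the estimator for each fold via get_params().
scores = cross_val_score(ShrinkageRegressor(alpha=0.1), X, y, cv=3)
print(scores)

# GridSearchCV sets alpha on clones via set_params() and ranks them by R^2 (score()).
search = GridSearchCV(ShrinkageRegressor(), {"alpha": [0.01, 1.0, 100.0]}, cv=3)
search.fit(X, y)
print(search.best_params_)
```

If you later need full check_estimator compliance, the same class is the starting point; the extra work is mostly input validation and edge-case error handling.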
