
Making an ML model scikit-learn compatible

I want to make this ML model scikit-learn compatible: https://github.com/manifoldai/merf

To do that, I followed the instructions at https://danielhnyk.cz/creating-your-own-estimator-scikit-learn/, imported the base classes with from sklearn.base import BaseEstimator, RegressorMixin, and inherited from them like so: class MERF(BaseEstimator, RegressorMixin):

However, when I check for scikit-learn compatibility:

from sklearn.utils.estimator_checks import check_estimator

import merf
check_estimator(merf)

I get this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 500, in check_estimator
    for estimator, check in checks_generator:
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in _generate_instance_checks
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in <genexpr>
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 232, in _yield_all_checks
    tags = estimator._get_tags()
AttributeError: module 'merf' has no attribute '_get_tags'

How do I make this model scikit-learn compatible?

From the docs, check_estimator is used to "Check if estimator adheres to scikit-learn conventions."

This estimator will run an extensive test-suite for input validation, shapes, etc, making sure that the estimator complies with scikit-learn conventions as detailed in Rolling your own estimator. Additional tests for classifiers, regressors, clustering or transformers will be run if the Estimator class inherits from the corresponding mixin from sklearn.base.

So check_estimator is more than a compatibility check: it also verifies that your estimator follows all of scikit-learn's conventions.

Read up on "Rolling your own estimator" in the scikit-learn docs to make sure you follow these conventions.

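Also, check_estimator expects an instance of your estimator class, not the module: call it like check_estimator(MERF()). Passing the merf module is what causes AttributeError: module 'merf' has no attribute '_get_tags', because a module has none of the estimator methods the checks rely on. To actually make the estimator follow all the conventions, you then have to fix every error the checks raise, one by one.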
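A quick way to see the difference between a module and an estimator instance. This is a sketch: ToyRegressor and fake_module are hypothetical stand-ins for MERF and the merf module, not part of either library.

```python
import types

from sklearn.base import BaseEstimator, RegressorMixin, is_regressor


class ToyRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical stand-in for an estimator class such as MERF."""

    def __init__(self, alpha=1.0):
        # Store the parameter unchanged, under the same name.
        self.alpha = alpha


# A module object (like the result of `import merf`) is not an estimator:
# it has none of the methods the checks rely on, such as get_params.
fake_module = types.ModuleType("fake_module")
print(hasattr(fake_module, "get_params"))  # False

# An *instance* of the class inherits get_params/set_params from
# BaseEstimator and the regressor estimator type from RegressorMixin.
est = ToyRegressor()
print(est.get_params())   # {'alpha': 1.0}
print(is_regressor(est))  # True
```

Here the mixin is listed before BaseEstimator, which is the inheritance order the scikit-learn developer guide recommends; the instance est is the kind of object check_estimator expects.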

For example, one such check is that the __init__ method should only set the attributes it accepts as parameters.

The MERF class violates this:

    def __init__(
        self,
        fixed_effects_model=RandomForestRegressor(n_estimators=300, n_jobs=-1),
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

        self.cluster_counts = None
        # Note fixed_effects_model must already be instantiated when passed in.
        self.fe_model = fixed_effects_model
        self.trained_fe_model = None
        self.trained_b = None

        self.b_hat_history = []
        self.sigma2_hat_history = []
        self.D_hat_history = []
        self.gll_history = []
        self.val_loss_history = []

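It sets attributes such as self.b_hat_history even though they are not parameters. It also stores the fixed_effects_model parameter under a different name (self.fe_model), whereas scikit-learn expects every __init__ parameter to be stored as an attribute with the same name so that get_params() and clone() can find it.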
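In recent scikit-learn versions the individual check behind this rule is check_no_attributes_set_in_init in sklearn.utils.estimator_checks. It is importable but not part of the documented public API, so treat this as a version-dependent sketch. BadInit and GoodInit are hypothetical toy classes that mirror the MERF pattern; GoodInit moves the history list into fit and gives it a trailing underscore, the usual convention for attributes set during fitting.

```python
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init


class BadInit(RegressorMixin, BaseEstimator):
    # Mirrors the MERF pattern: a non-parameter attribute is set in __init__.
    def __init__(self, max_iterations=20):
        self.max_iterations = max_iterations
        self.gll_history = []  # not a constructor parameter


class GoodInit(RegressorMixin, BaseEstimator):
    # __init__ only stores its parameters, unchanged and under the same names.
    def __init__(self, max_iterations=20):
        self.max_iterations = max_iterations

    def fit(self, X, y):
        # State that changes during training is created in fit, with a
        # trailing underscore.
        self.gll_history_ = []
        return self


def passes(estimator):
    """Return True if the single __init__ check passes for this instance."""
    try:
        check_no_attributes_set_in_init(type(estimator).__name__, estimator)
    except AssertionError:
        return False
    return True


print(passes(BadInit()))   # False
print(passes(GoodInit()))  # True
```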

There are lots of other checks like this.

My personal advice: unless you really need full compliance, don't try to satisfy every one of these checks. Just inherit from the mixins and the base class, implement the required methods (fit and predict for a regressor), and use the model.
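A minimal sketch of that approach. ShrunkMeanRegressor and its shrinkage parameter are hypothetical, chosen only to show the structure. Because it inherits from BaseEstimator and RegressorMixin, it gets get_params/set_params and an R² score method, which is enough for tools like GridSearchCV even though it would not necessarily pass every check in check_estimator.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class ShrunkMeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical toy regressor: predicts a shrunken mean of y.

    It validates X but otherwise ignores its values.
    """

    def __init__(self, shrinkage=0.0):
        # Only store parameters here; no validation or derived state.
        self.shrinkage = shrinkage

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        self.n_features_in_ = X.shape[1]
        self.mean_ = (1.0 - self.shrinkage) * float(np.mean(y))
        return self

    def predict(self, X):
        check_is_fitted(self)
        # A full check_estimator run would likely also expect a check that
        # X has n_features_in_ columns; omitted to keep the sketch short.
        X = check_array(X)
        return np.full(X.shape[0], self.mean_)


# Usage: GridSearchCV relies on get_params/set_params (from BaseEstimator)
# and score (R², from RegressorMixin).
X = np.arange(20, dtype=float).reshape(10, 2)
y = np.arange(10, dtype=float) + 100.0
search = GridSearchCV(ShrunkMeanRegressor(), {"shrinkage": [0.0, 0.9]}, cv=2)
search.fit(X, y)
print(search.best_params_)  # {'shrinkage': 0.0}
```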

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM