繁体   English   中英

测试 sklearn 模型是否已安装的最佳方法是什么?

[英]What's the best way to test whether an sklearn model has been fitted?

检查 sklearn 模型是否已安装的最优雅方法是什么? 即它的fit()函数是否在实例化后被调用。

您可以执行以下操作:

from sklearn.exceptions import NotFittedError

for model in models:
    try:
        model.predict(some_test_data)
    except NotFittedError as e:
        print(repr(e))

理想情况下,您会根据预期结果检查model.predict结果,但如果您只想知道模型是否拟合就足够了。

更新

一些评论者建议使用check_is_fitted 我认为check_is_fitted是一个内部方法 大多数算法会在他们的 predict 方法中调用check_is_fitted ,如果需要,这反过来可能会引发NotFittedError 直接使用check_is_fitted的问题在于它是特定于模型的,即您需要根据您的算法知道要检查哪些成员。 例如:

╔════════════════╦════════════════════════════════════════════╗
║ Tree models    ║ check_is_fitted(self, 'tree_')             ║
║ Linear models  ║ check_is_fitted(self, 'coefs_')            ║
║ KMeans         ║ check_is_fitted(self, 'cluster_centers_')  ║
║ SVM            ║ check_is_fitted(self, 'support_')          ║
╚════════════════╩════════════════════════════════════════════╝

等等。 所以总的来说,我建议调用model.predict()并让特定算法处理检查它是否已经安装的最佳方法。

我为分类器这样做:

def check_fitted(clf): 
    return hasattr(clf, "classes_")

这是一种贪婪的方法,但如果不是所有模型,它应该适用于大多数模型。 唯一一次这可能不起作用是对于在拟合之前设置以下划线结尾的属性的模型,我很确定这会违反 scikit-learn 约定,所以这应该没问题。

import inspect

def is_fitted(model):
        """Checks if model object has any attributes ending with an underscore"""
        return 0 < len( [k for k,v in inspect.getmembers(model) if k.endswith('_') and not k.startswith('__')] )

直接从 scikit-learn 源代码中check_is_fitted函数(类似于@david-marx 的逻辑,但更简单一点):

def is_fitted(model):
    '''
    Checks if a scikit-learn estimator/transformer has already been fit.
    
    
    Parameters
    ----------
    model: scikit-learn estimator (e.g. RandomForestClassifier) 
        or transformer (e.g. MinMaxScaler) object
        
    
    Returns
    -------
    Boolean that indicates if ``model`` has already been fit (True) or not (False).
    '''
    
    attrs = [v for v in vars(model)
             if v.endswith("_") and not v.startswith("__")]
    
    return len(attrs) != 0

此函数通过将其与模型的新实例进行比较来检查 scikit-learn 模型是否适合。

def is_fitted(model):
    return not len(dir(model)) == len(dir(type(model)()))

model = OneHotEncoder()
print(is_fitted(model)) #False
model.fit_transform(data)
print(is_fitted(model)) #True


 

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM