
How to check if a classifier belongs to sklearn.tree?

Suppose I have a trained model and would like to check whether it is a tree-based classifier. What is the best way to determine this?

For example, I'm looking for something like the following:

import sklearn.tree
from imaginarypackage import listmodules  # hypothetical helper, does not exist

if type(clf).__name__ in listmodules(sklearn.tree):
    ...

I have tried:

>>> import pkgutil
>>> 'DecisionTreeClassifier' in pkgutil.iter_modules(["sklearn.tree"])
False

I understand that not all tree-based models (e.g. RandomForest) live under sklearn.tree, so a generic solution would be very helpful.

Thanks in advance!

Check type of tree model

As @Alexander Santos suggests, you can use the method from this answer to check which module your class belongs to. As far as I can tell, the tree-based models are defined in either sklearn.tree or sklearn.ensemble._forest (the private module that holds RandomForest and ExtraTrees in recent scikit-learn releases; print type(clf).__module__ to confirm on your version).

# Method 1: read the __module__ attribute of the model's class
module = getattr(type(clf), '__module__', '')

if module.startswith('sklearn.tree') or module.startswith('sklearn.ensemble._forest'):
    print("clf is a tree model")

Alternatively, a less Pythonic method is to convert the type to a string and perform the same comparison.

# Method 2: convert type to string
type_ = str(type(clf))

if "sklearn.tree" in type_ or "sklearn.ensemble._forest" in type_:
    print("clf is probably a tree model")

You can obviously rewrite this more efficiently if you need to test against many more than just two modules.
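One way to scale the module check to more prefixes, as a minimal sketch: str.startswith accepts a tuple, so the prefixes can sit in one place. The helper name is_tree_model and the prefix list are assumptions for illustration; the private module names in particular depend on your scikit-learn version, so check type(clf).__module__ on a known model first.

```python
# Module prefixes treated as "tree-based". These are assumptions based on
# recent scikit-learn layouts; verify them against your installed version.
TREE_MODULE_PREFIXES = (
    "sklearn.tree",
    "sklearn.ensemble._forest",
)


def is_tree_model(clf, prefixes=TREE_MODULE_PREFIXES):
    """Return True if clf's class is defined in a module under one of the prefixes."""
    module = getattr(type(clf), "__module__", "")
    # str.startswith accepts a tuple and returns True if any prefix matches.
    return module.startswith(tuple(prefixes))
```

Calling is_tree_model(clf) then replaces the chained or-conditions, and supporting another family of models only means adding one string to the tuple.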

Alternative 'hack'

Inspecting the DecisionTree, RandomForest and ExtraTrees regressors and classifiers with dir(clf) shows that all the models you want to test for share hyperparameter attributes such as:

  • min_samples_leaf
  • min_weight_fraction_leaf
  • max_leaf_nodes

So if you really need a single check to validate your model type, you can inspect the model's attributes:

attributes = dir(clf)

check_for_list = ['min_samples_leaf', 'min_weight_fraction_leaf', 'max_leaf_nodes']

# True if at least one of the tree-specific hyperparameters is present
verdict = any(check in attributes for check in check_for_list)

if verdict:
    print("clf is probably a tree-based model.")
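The check above passes as soon as any one of the names is present. If that turns out to be too permissive, a stricter variant can require all of them. The sketch below wraps both behaviours in one function; looks_like_tree_model and its require_all flag are hypothetical names introduced here, not part of scikit-learn.

```python
# Hyperparameter names taken from the list above.
TREE_HYPERPARAMETERS = (
    "min_samples_leaf",
    "min_weight_fraction_leaf",
    "max_leaf_nodes",
)


def looks_like_tree_model(clf, names=TREE_HYPERPARAMETERS, require_all=True):
    """Duck-typing check: does clf expose the tree-specific hyperparameters?"""
    # all() demands every name; any() is satisfied by a single match.
    combine = all if require_all else any
    return combine(hasattr(clf, name) for name in names)
```

This is still a heuristic: any object that happens to define these attributes will pass, so it is best treated as a hint rather than proof.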
