How to check if a classifier belongs to sklearn.tree?
Suppose I have a trained model and I would like to check whether it is a tree-based classifier. What is the best way to determine this?
For example, I'm looking for something like the following:
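```python
import sklearn.tree
from imaginarypackage import listmodules  # hypothetical helper; no such package exists

if type(clf).__name__ in listmodules(sklearn.tree):
    ...  # clf is a tree-based classifier
```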
I have tried:
```python
>>> import pkgutil
>>> 'DecisionTreeClassifier' in pkgutil.iter_modules(["sklearn.tree"])
False
```
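The attempt above returns False because pkgutil.iter_modules expects a list of filesystem directories (not a dotted package name) and yields ModuleInfo named tuples describing submodules, so a plain string can never be a member of that iterator. The sketch below uses the standard-library json package as a stand-in for sklearn.tree to show what iter_modules actually yields; the helper name submodule_names is hypothetical.

```python
import json
import pkgutil


def submodule_names(package):
    """Return the names of the submodules found on a package's __path__."""
    # iter_modules takes a list of directories; a package's __path__ is exactly that.
    # Each item is a pkgutil.ModuleInfo(module_finder, name, ispkg) named tuple.
    return [info.name for info in pkgutil.iter_modules(package.__path__)]


names = submodule_names(json)
print(names)  # submodules such as 'decoder', 'encoder', 'scanner', 'tool'

# A str is never equal to a ModuleInfo tuple, so this membership test is always False.
print('decoder' in pkgutil.iter_modules(["json"]))  # False
```

Even with the correct path, submodule_names(sklearn.tree) would list private submodules such as _classes rather than estimator class names, so iter_modules does not answer this question directly.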
I understand that not all tree-based models (e.g. RandomForest) are under sklearn.tree. Hence, a generic solution would be very helpful.
Thanks in advance!
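For reference, something close to the imaginary listmodules helper can be built from the standard library: inspect.getmembers with inspect.isclass returns the classes reachable as attributes of a module. The sketch below names it list_classes (a hypothetical helper) and uses collections as a stand-in module so it runs without scikit-learn installed.

```python
import collections
import inspect


def list_classes(module):
    """Names of all classes reachable as attributes of `module`."""
    # getmembers returns sorted (name, value) pairs; keep only the classes.
    return [name for name, _ in inspect.getmembers(module, inspect.isclass)]


# collections stands in for sklearn.tree here.
class_names = list_classes(collections)
print("OrderedDict" in class_names)  # True
print("namedtuple" in class_names)   # False: namedtuple is a function, not a class
```

With scikit-learn installed, `type(clf).__name__ in list_classes(sklearn.tree)` would express the check from the question, but it matches by class name only and still misses ensembles such as RandomForest, which are defined outside sklearn.tree.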
As @Alexander Santos suggests, you can use the method from this answer to check which module your class belongs to. As far as I can tell, the tree-based models are part of either the sklearn.tree module or, for ensembles such as RandomForest and ExtraTrees, the sklearn.ensemble._forest module (in recent scikit-learn versions).

```python
# Method 1: check the module in which the estimator's class is defined
# (looking up __module__ on an instance falls through to its class)
module = getattr(clf, '__module__', '')
if module.startswith('sklearn.tree') or module.startswith('sklearn.ensemble._forest'):
    print("clf is a tree model")
```

Alternatively, a less Pythonic method is to convert the type to a string and perform the same comparison.

```python
# Method 2: convert the type to a string,
# e.g. "<class 'sklearn.tree._classes.DecisionTreeClassifier'>"
type_ = str(type(clf))
if "sklearn.tree" in type_ or "sklearn.ensemble._forest" in type_:
    print("clf is probably a tree model")
```
If you need to test against more than two modules, you can rewrite this more concisely, since str.startswith also accepts a tuple of prefixes.
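A minimal sketch of the module-prefix check with a tuple of prefixes. The prefixes are an assumption: in recent scikit-learn versions the DecisionTree* and ExtraTree* classes are defined under sklearn.tree and the RandomForest*/ExtraTrees* ensembles under sklearn.ensemble._forest. The function name is_tree_model and the Fake* classes are hypothetical stand-ins so the sketch runs without scikit-learn.

```python
# Assumed module prefixes; extend the tuple for other tree-based estimators.
TREE_MODULE_PREFIXES = ("sklearn.tree", "sklearn.ensemble._forest")


def is_tree_model(estimator) -> bool:
    """True if the estimator's class is defined in one of the tree modules."""
    module = type(estimator).__module__
    # str.startswith accepts a tuple and returns True if any prefix matches.
    return module.startswith(TREE_MODULE_PREFIXES)


# Hypothetical stand-in classes whose __module__ mimics scikit-learn's layout.
FakeTree = type("FakeTree", (), {"__module__": "sklearn.tree._classes"})
FakeForest = type("FakeForest", (), {"__module__": "sklearn.ensemble._forest"})
FakeLinear = type("FakeLinear", (), {"__module__": "sklearn.linear_model._logistic"})

print([is_tree_model(cls()) for cls in (FakeTree, FakeForest, FakeLinear)])
# [True, True, False]
```

With scikit-learn installed, is_tree_model(DecisionTreeClassifier()) and is_tree_model(RandomForestClassifier()) should both return True. Gradient-boosting models are defined in other sklearn.ensemble submodules, so add their prefixes if you need them.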
By inspecting the attributes of the DecisionTree, RandomForest and ExtraTrees regressors and classifiers with dir(clf), it appears that all the models you want to test for have attributes such as:
min_samples_leaf
min_weight_fraction_leaf
max_leaf_nodes
So if you really need a single check to validate your model type, you can inspect the model's attributes:
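```python
# Hyperparameters shared by the DecisionTree, RandomForest and ExtraTrees estimators
check_for_list = ['min_samples_leaf', 'min_weight_fraction_leaf', 'max_leaf_nodes']
attributes = dir(clf)

# The model passes if it exposes at least one of these attributes
verdict = any(check in attributes for check in check_for_list)

if verdict:
    print("clf is probably a tree-based model.")
```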
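Because scikit-learn estimators report their constructor hyperparameters through get_params(), the same heuristic can look there instead of at everything dir() returns (methods, fitted attributes and so on). This is a sketch: the function name looks_like_tree_model and the Fake* classes are hypothetical stand-ins that mimic the get_params() interface so it runs without scikit-learn. Like the dir()-based check above, it is duck typing, so any object declaring a matching parameter name passes.

```python
TREE_HYPERPARAMS = ("min_samples_leaf", "min_weight_fraction_leaf", "max_leaf_nodes")


def looks_like_tree_model(estimator, markers=TREE_HYPERPARAMS) -> bool:
    """Heuristic: does the estimator declare any tree-specific hyperparameter?"""
    get_params = getattr(estimator, "get_params", None)
    if callable(get_params):
        # scikit-learn estimators return a dict of hyperparameter names to values.
        names = set(get_params())
    else:
        # Fall back to the dir()-based check for objects without get_params().
        names = set(dir(estimator))
    return any(marker in names for marker in markers)


# Hypothetical stand-ins mimicking scikit-learn's get_params() interface.
class FakeTree:
    def get_params(self, deep=True):
        return {"max_depth": None, "min_samples_leaf": 1}


class FakeLinear:
    def get_params(self, deep=True):
        return {"C": 1.0, "penalty": "l2"}


print(looks_like_tree_model(FakeTree()), looks_like_tree_model(FakeLinear()))
# True False
```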