简体   繁体   English

我如何迭代估计器?

[英]How do I iterate over estimators?

Given an ensemble estimator , I would like to iterate over the contents of its estimators_ field.给定一个ensemble estimator ,我想遍历它的estimators_字段的内容。

The problem is that the field can have a very different structure.问题是该字段可能具有非常不同的结构。

Eg, for a GradientBoostingClassifier it is a rank-2 numpy.ndarray (so I can use nditer ) while for a RandomForestClassifier it is a simple list .例如,对于GradientBoostingClassifier它是 rank-2 numpy.ndarray (所以我可以使用nditer )而对于RandomForestClassifier它是一个简单的list

Can I do better than this:我能做得比这更好吗:

import numpy as np
def iter_estimators(estimators):
    if isinstance(estimators, np.ndarray):
        return map(lambda x: x[()], np.nditer(estimators, flags=["refs_ok"]))
    return iter(estimators)

I suppose you could use np.asarray to conveniently ensure the object is an ndarray.我想你可以使用np.asarray来方便地确保 object 是一个 ndarray。 Then use ndarray.flat to get an iterator over the flattened array.然后使用ndarray.flat在展平数组上获取迭代器。

>>> estimators = model.estimators_
>>> array = np.asarray(estimators)
>>> iterator = array.flat
>>> iterator
<numpy.flatiter at 0x7f84f48f8e00>

A numpy -agnostic solution isnumpy无关的解决方案是

def iter_nested(obj):
    """Iterate over all iterable sub-objects.
    https://stackoverflow.com/q/58615038/850781"""
    try:
        for o1 in obj:
            for o2 in iter_nested(o1):
                yield o2
    except TypeError:           # ... object is not iterable
        yield obj

See also也可以看看

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM