[英]How do I iterate over estimators?
Given an ensemble estimator , I would like to iterate over the contents of its estimators_
field.给定一个ensemble estimator ,我想遍历它的estimators_
字段的内容。
The problem is that the field can have a very different structure.问题是该字段可能具有非常不同的结构。
Eg, for a GradientBoostingClassifier
it is a rank-2 numpy.ndarray
(so I can use nditer
) while for a RandomForestClassifier
it is a simple list
.例如,对于GradientBoostingClassifier
它是 rank-2 numpy.ndarray
(所以我可以使用nditer
)而对于RandomForestClassifier
它是一个简单的list
。
Can I do better than this:我能做得比这更好吗:
import numpy as np
def iter_estimators(estimators):
if isinstance(estimators, np.ndarray):
return map(lambda x: x[()], np.nditer(estimators, flags=["refs_ok"]))
return iter(estimators)
I suppose you could use np.asarray
to conveniently ensure the object is an ndarray.我想你可以使用np.asarray
来方便地确保 object 是一个 ndarray。 Then use ndarray.flat
to get an iterator over the flattened array.然后使用ndarray.flat
在展平数组上获取迭代器。
>>> estimators = model.estimators_
>>> array = np.asarray(estimators)
>>> iterator = array.flat
>>> iterator
<numpy.flatiter at 0x7f84f48f8e00>
A numpy
-agnostic solution is与numpy
无关的解决方案是
def iter_nested(obj):
"""Iterate over all iterable sub-objects.
https://stackoverflow.com/q/58615038/850781"""
try:
for o1 in obj:
for o2 in iter_nested(o1):
yield o2
except TypeError: # ... object is not iterable
yield obj
See also也可以看看
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.