
Problems with an unbalanced dataset with scikit-learn Random forest?

I have an unbalanced text dataset. The class distribution looks like this:

label | number of documents
------|--------------------
5     | 1190
4     | 839
3     | 239
1     | 204
2     | 127
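To make "balancing" concrete, the sketch below applies the heuristic that scikit-learn documents for class_weight="balanced", n_samples / (n_classes * count_of_class), to the counts in the table. It is plain Python with the counts copied from the table. It does not claim to reproduce what the deprecated balance_weights helper computed.

```python
# Class counts copied from the table above (label -> number of documents)
counts = {5: 1190, 4: 839, 3: 239, 1: 204, 2: 127}

n_samples = sum(counts.values())  # 2599 documents in total
n_classes = len(counts)           # 5 labels

# "balanced" heuristic: rare classes get proportionally larger weights
class_weight = {label: n_samples / (n_classes * c) for label, c in counts.items()}

for label in sorted(class_weight):
    print(label, round(class_weight[label], 3))
```

With these weights every class contributes the same total weight (n_samples / n_classes, about 519.8), so label 2 (127 documents) is weighted roughly 9 times more heavily than label 5 (1190 documents).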

I tried to use the sample_weight argument of fit(X, y[, sample_weight]), but I could not work out from the documentation what it expects. I tried the following:

from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import balance_weights

classifier=RandomForestClassifier(n_estimators=10,criterion='entropy')
classifier.fit(X_train, y_train,sample_weight = balance_weights(y))
prediction = classifier.predict(X_test)

But I got this warning and exception:

/usr/local/lib/python2.7/site-packages/sklearn/utils/__init__.py:93: DeprecationWarning: Function balance_weights is deprecated; balance_weights is an internal function and will be removed in 0.16
  warnings.warn(msg, category=DeprecationWarning)
Traceback (most recent call last):
  File "/Users/user/RF_classification.py", line 34, in <module>
    classifier.fit(X_train, y_train,sample_weight = balance_weights(y))
  File "/usr/local/lib/python2.7/site-packages/sklearn/ensemble/forest.py", line 279, in fit
    for i in range(n_jobs))
  File "/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 653, in __call__
    self.dispatch(function, args, kwargs)
  File "/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 400, in dispatch
    job = ImmediateApply(func, args, kwargs)
  File "/usr/local/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 138, in __init__
    self.results = func(*args, **kwargs)
  File "/usr/local/lib/python2.7/site-packages/sklearn/ensemble/forest.py", line 85, in _parallel_build_trees
    curr_sample_weight *= sample_counts
ValueError: operands could not be broadcast together with shapes (2599,) (1741,) (2599,) 
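The last line of the traceback is ordinary NumPy broadcasting: an in-place multiplication between a 1-D array of length 2599 and one of length 1741 fails. The minimal sketch below reproduces that message. The variable names are borrowed from the traceback line (curr_sample_weight *= sample_counts), but the arrays are placeholder ones, not the forest's real internals.

```python
import numpy as np

# Placeholder arrays with the two lengths reported in the error
curr_sample_weight = np.ones(2599)  # one weight per label in the full y
sample_counts = np.ones(1741)       # one entry per row of X_train

error_message = None
try:
    # In-place multiply requires shapes that broadcast; (2599,) and (1741,) do not
    curr_sample_weight *= sample_counts
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```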

How can I make this estimator handle the unbalanced data?

Update to 0.16-dev. Random forests now support class_weight="auto", which essentially rebalances the classes for you automatically.
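A minimal sketch of that approach in current scikit-learn, where the option is spelled class_weight="balanced" ("auto" was deprecated in 0.17 in favor of it). The data here is synthetic: make_classification with class weights proportional to the question's counts stands in for the real document features. n_estimators=100, test_size=0.33 and random_state=0 are arbitrary choices, not values from the thread.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-in data: 5 classes with proportions taken from the table
counts = [1190, 839, 239, 204, 127]
X, y = make_classification(
    n_samples=sum(counts),
    n_classes=5,
    n_informative=5,
    weights=[c / sum(counts) for c in counts],
    random_state=0,
)

# Split first, then fit; stratify keeps class proportions in both halves
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, stratify=y, random_state=0
)

# class_weight="balanced" reweights classes inversely to their frequency in y_train
classifier = RandomForestClassifier(
    n_estimators=100, criterion="entropy", class_weight="balanced", random_state=0
)
classifier.fit(X_train, y_train)
prediction = classifier.predict(X_test)
```

RandomForestClassifier also accepts class_weight="balanced_subsample", which computes the same weights per bootstrap sample instead of once for the whole training set.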

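I think the problem is that you called balance_weights on the full label vector y, before splitting it into training and test sets. The weights therefore have one entry per document in the whole dataset (2599) while X_train has only 1741 rows, which are the two shapes in the error. Try: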

classifier.fit(X_train, y_train,sample_weight = balance_weights(y_train))
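balance_weights has since been removed. A sketch of the same fix with sklearn.utils.class_weight.compute_sample_weight, which returns one weight per sample of the array it is given, is shown below. The random features are a hypothetical stand-in for the question's document vectors; only the label counts come from the table, and the weights are not claimed to match what balance_weights produced.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_sample_weight

# Hypothetical stand-in data: labels with the question's counts, random features
counts = {5: 1190, 4: 839, 3: 239, 1: 204, 2: 127}
y = np.repeat(list(counts), list(counts.values()))
X = np.random.default_rng(0).normal(size=(y.size, 10))

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, stratify=y, random_state=0
)

# Compute weights from y_train (not y), so there is one weight per training row
sample_weight = compute_sample_weight("balanced", y_train)

classifier = RandomForestClassifier(n_estimators=10, criterion="entropy", random_state=0)
classifier.fit(X_train, y_train, sample_weight=sample_weight)
prediction = classifier.predict(X_test)
```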

