简体   繁体   English

用于多类对象检测的分层 K 折?

[英]Stratified K-Fold For Multi-Class Object Detection?

Updated更新

I've uploaded a dummy data set, link here .我上传了一个虚拟数据集,链接在这里 The df.head() : df.head()

在此处输入图片说明

It has 4 class in total and df.object.value_counts() :它总共有4 个类df.object.value_counts()

human    23
car      13
cat       5
dog       3

I want to do properly K-Fold validation splits over a multi-class object detection data set.我想在多类对象检测数据集上正确地进行K-Fold验证拆分。

Initial Approach初始方法

To achieve proper k-fold validation splits, I took the object counts and the number of bounding box into account.为了实现正确的 k 折验证拆分,我考虑了object countsbounding box的数量。 I understand, the K-fold splitting strategies mostly depends on the data set (meta information).据我了解, K-fold拆分策略主要取决于数据集(元信息)。 But for now with these dataset, I've tried something like as follows:但是现在有了这些数据集,我已经尝试了如下:

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=101)
df_folds = main_df[['image_id']].copy()

df_folds.loc[:, 'bbox_count'] = 1
df_folds = df_folds.groupby('image_id').count()
df_folds.loc[:, 'object_count'] = main_df.groupby('image_id')['object'].nunique()

df_folds.loc[:, 'stratify_group'] = np.char.add(
    df_folds['object_count'].values.astype(str),
    df_folds['bbox_count'].apply(lambda x: f'_{x // 15}').values.astype(str)
)

df_folds.loc[:, 'fold'] = 0
for fold_number, (train_index, val_index) in enumerate(skf.split(X=df_folds.index, y=df_folds['stratify_group'])):
    df_folds.loc[df_folds.iloc[val_index].index, 'fold'] = fold_number

After the splitting, I've checked to ensure if it's working.拆分后,我检查了它是否正常工作。 And it seems Ok so far.到目前为止似乎还可以。

在此处输入图片说明

All the folds contain stratified k-fold samples, len(df_folds[df_folds['fold'] == fold_number].index) and no intersection to each other, set(A).intersection(B) where A and B are the index value ( image_id ) of two folds.所有的折叠都包含分层的k-fold样本, len(df_folds[df_folds['fold'] == fold_number].index)并且彼此没有交集, set(A).intersection(B)其中AB是索引值 ( image_id ) 的两倍。 But the issue seems like:但问题似乎是:

Fold 0 has total: 18 + 2 + 3 = 23 bbox
Fold 1 has total: 2 + 11 = 13 bbox
Fold 2 has total: 5 + 3 = 8 bbox

Concern忧虑

However, I couldn't ensure whether it's the proper way for this type of task in general.但是,我无法确定这是否是此类任务的正确方法。 I want some advice.我想要一些建议。 Is the above approach OK?上面的方法可以吗? or any issue?或任何问题? or there is some better approach!或者有一些更好的方法! Any sorts of suggestions would be appreciated.任何类型的建议将不胜感激。 Thanks.谢谢。

When creating a cross-validation split, we care about creating folds which have a good distribution of the various "cases" encountered in the data.在创建交叉验证拆分时,我们关心的是创建折叠,这些折叠对数据中遇到的各种“案例”具有良好的分布。

In your case, you decided to base your folds on the number of cars and the number of bounding boxes which is a good but limited choice.在您的情况下,您决定根据汽车的数量和边界框的数量进行折叠,这是一个不错但有限的选择。 So, if you can identify specific cases using your data/metadata, you might try to create smarter folds using it.因此,如果您可以使用数据/元数据识别特定情况,则可以尝试使用它创建更智能的折叠。

The most obvious choice is to balance object types (classes) in your folds, but you could go further.最明显的选择是平衡折叠中的对象类型(类),但您可以更进一步。

Here is the main idea, let's say you have images with cars encountered mostly in France, and others with cars encountered mostly in the US, it could be used to create good folds with a balanced number of french and us cars in each fold.这是主要思想,假设您的图像主要在法国遇到汽车,而其他汽车主要在美国遇到,它可以用来创建良好的折叠,每个折叠中都有平衡数量的法国和美国汽车。 Same could be done with weather conditions etc. Thus, each fold will contain representative data to learn from so that your network won't be biased for your task.天气条件等也可以这样做。因此,每个折叠都将包含可供学习的代表性数据,以便您的网络不会对您的任务产生偏见。 As a result, your model will be more robust to such potential real life changes in the data.因此,您的模型将对数据中此类潜在的现实生活变化更加稳健。

So, can you add some metadata to your cross-validation strategy to create a better CV?那么,您能否在交叉验证策略中添加一些元数据以创建更好的简历? If it's not the case, can you get information about potential corner cases using the x, y, w, h columns of your dataset?如果不是这种情况,您能否使用数据集的 x、y、w、h 列获取有关潜在极端情况的信息?

Then you should try to have balanced folds in terms of samples so that your scores are evaluated on the same sample size which will reduce variance and provide a better evaluation at the end.然后,您应该尝试在样本方面进行平衡的折叠,以便在相同的样本量上评估您的分数,这将减少方差并在最后提供更好的评​​估。

You can use StratifiedKFold() or StratifiedShuffleSplit() directly to split your data set using stratified sampling based on some categorical column.您可以直接使用 StratifiedKFold() 或 StratifiedShuffleSplit() 使用基于某些分类列的分层抽样来拆分数据集。

Dummy Data:虚拟数据:

import pandas as pd
import numpy as np

np.random.seed(43)
df = pd.DataFrame({'ID': (1,1,2,2,3,3),
               'Object': ('bus', 'car', 'bus', 'bus', 'bus', 'car'),
               'X' : np.random.randint(0, 10, 6),
               'Y' : np.random.randn(6)

})


df

Using StratifiedKFold()使用 StratifiedKFold()

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=2)

for train_index, test_index in skf.split(df, df["Object"]):
        strat_train_set_1 = df.loc[test_index]
        strat_test_set_1 = df.loc[test_index]

print('train_set :', strat_train_set_1, '\n' , 'test_set :', strat_test_set_1)

Similarly, if you choose to use StratifiedShuffleSplit(), you can have同样,如果你选择使用 StratifiedShuffleSplit(),你可以有

from sklearn.model_selection import StratifiedShuffleSplit

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
# n_splits = Number of re-shuffling & splitting iterations.

for train_index, test_index in sss.split(df, df["Object"]):
 # split(X, y[, groups]) Generates indices to split data into training and test set.

        strat_train_set = df.loc[train_index]
        strat_test_set = df.loc[test_index]

print('train_set :', strat_train_set, '\n' , 'test_set :', strat_test_set)

I'd do this simply using KFold method of scikit-learn of python我会简单地使用KFold的 scikit-learn 的KFold方法来做到这一点

from numpy import array
from sklearn.model_selection import KFold
data = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
kfold = KFold(3, True, 1)
for train, test in kfold.split(data):
    print('train: %s, test: %s' % (data[train], data[test]))

and please see if this might be helpful请看看是否有帮助

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM