[英]Shuffle and split 2 numpy arrays so as to maintain their ordering with respect to each other
I have 2 numpy arrays X and Y, with shape X: [4750, 224, 224, 3] and Y: [4750,1]. 我有2个numpy数组X和Y,形状为X:[4750、224、224、3]和Y:[4750,1]。
X is the training dataset and Y is the correct output label for each entry. X是训练数据集,Y是每个条目的正确输出标签。
I want to split the data into train and test so as to validate my machine learning model. 我想将数据分为训练和测试,以验证我的机器学习模型。 Therefore, I want to split them randomly so that they both have the correct ordering after random split is applied on X and Y. ie- every row of X is correctly has its corresponding label unchanged after the split. 因此,我想随机分割它们,以便在对X和Y进行随机分割后它们都具有正确的顺序。即-X的每一行在分割后正确地保持了其对应的标签不变。
How can I achieve the above objective ? 我如何实现上述目标?
This is how I would do it 这就是我会做的
def split(x, y, train_ratio=0.7):
x_size = x.shape[0]
train_size = int(x_size * train_ratio)
test_size = x_size - train_size
train_indices = np.random.choice(x_size, size=train_size, replace=False)
mask = np.zeros(x_size, dtype=bool)
mask[train_indices] = True
x_train, y_train = x[mask], y[mask]
x_test, y_test = x[~mask], y[~mask]
return (x_train, y_train), (x_test, y_test)
I simply choose the required number of indices I need (randomly) for my train set, remaining will be for the test set. 我只需(随机地)为火车组选择所需的索引数量,其余的将用于测试组。
Then use a mask to select the train and test samples. 然后使用遮罩选择训练和测试样本。
You can also use the scikit-learn train_test_split
to split your data using just 2 lines of code : 您还可以使用scikit-learn train_test_split
仅使用两行代码来拆分数据:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.33)
sklearn.model_selection.train_test_split
is a good choice! sklearn.model_selection.train_test_split
是一个不错的选择!
But to craft one of your own 但是要自己做一个
import numpy as np
def my_train_test_split(X, Y, train_ratio=0.8):
"""return X_train, Y_train, X_test, Y_test"""
n = X.shape[0]
split = int(n * train_ratio)
index = np.arange(n)
np.random.shuffle(index)
return X[index[:split]], Y[index[:split]], X[index[split:]], Y[index[split:]]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.