[英]100% RAM usage in google colab
我目前正在將 google colab 用於我的手語識別深度學習項目之一 model,我正在加載我從 Google Drive 創建的自定義數據集。 我的數據集包含不同的字母文件夾,其中包含各個字母的符號。
這只是我用來創建訓練數據的代碼的一部分
training_data = []
def create_training_data():
for category in CATEGORIES:
path = os.path.join(DATADIR,category) # create path to image of respective alphabet
class_num = CATEGORIES.index(category) # get the classification for each alphabet A : 0, C : 1, D : 2,...
for img in tqdm(os.listdir(path)): # iterate over each image
img_array = cv2.imread(os.path.join(path,img) ,cv2.IMREAD_GRAYSCALE) # convert to array
training_data.append([img_array, class_num]) # add this to our training_data
create_training_data()
X = []
y = []
for features,label in training_data:
X.append(np.array(features))
y.append(label)
但是只是這個過程占用了所有可用的 RAM,那么我有什么辦法可以最大限度地減少 RAM 的使用?
我無法復制您的訓練數據集,因此請謹慎對待。 如果您使用生成器生成訓練數據而不是將其構建為列表,那么您應該消除一半的 memory 使用量。 您仍需支付 X 和 Y 的 memory 成本,因此該技術可能不足以解決您的問題。
def iter_training_data():
for category in CATEGORIES:
path = os.path.join(DATADIR,category) # create path to image of respective alphabet
class_num = CATEGORIES.index(category) # get the classification for each alphabet A : 0, C : 1, D : 2,...
for img in tqdm(os.listdir(path)): # iterate over each image
img_array = cv2.imread(os.path.join(path,img) ,cv2.IMREAD_GRAYSCALE) # convert to array
yield [img_array, class_num]
X = []
y = []
for features,label in iter_training_data():
X.append(np.array(features))
y.append(label)
它占用了所有可用的 RAM,因為您只需將所有數據復制到其中。
使用 PyTorch 中的DataLoader
並定義批處理的大小(因為不一次使用所有數據)可能更容易。
import torch
import torchvision
from torchvision import transforms
train_transforms = transforms.Compose([
# transforms.Resize((256, 256)), # might also help in some way, if resize is allowed in your task
transforms.ToTensor() ])
train_dir = '/path/to/train/data/'
train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transforms)
batch_size = 32
train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=batch_size )
然后,在訓練階段,您可以執行以下操作:
# ...
for inputs, labels in tqdm(train_dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
# ...
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.