簡體   English   中英

如何使用 pytorch 的 cat function 進行 K 折驗證(即將 pytorch 塊的列表連接在一起)

[英]How to use pytorch's cat function for K-Fold Validation (i.e. concatenate a list of pytorch chunks together)

我正在為 k 折交叉驗證拆分數據集,但在使用 Pytorch 的 stack/cat 函數連接張量列表時遇到問題。

首先,我使用 .chunk 方法將訓練集和測試集分成塊,如下所示

  x_train_folds = torch.chunk(x_train, num_folds)
  y_train_folds = torch.chunk(y_train, num_folds)

其中 x_train 是 torch.Size([5000, 3, 32, 32]) 的張量,y_train 是 torch.Size([5000]) 的張量

x_train_folds 和 y_train_folds 現在是 num_folds 張量的元組

然后,我需要設置一系列嵌套循環來遍歷 K 的不同值和各種折疊,同時始終從訓練集中排除一個折疊以在測試/驗證時使用:

  for k in k_choices:
    k_to_accuracies[k] = [] # create empty space to append for a given k-value
    for fold in range(num_folds):
      # create training sets by excluding the current loop index fold and using that as the test set
      x_train_cross_val = torch.cat((x_train_folds[:fold], x_train_folds[fold+1:]), 0)
      y_train_cross_val = torch.cat((y_train_folds[:fold], y_train_folds[fold+1:]), 0)
      classifier = KnnClassifier(x_train_cross_val, y_train_cross_val)
      k_to_accuracies[k].append(classifier.check_accuracy(x_train_folds[fold], y_train_folds[fold], k=k))

如您所見,我總是從原始訓練集中跳過一倍以用於驗證。 這是標准的 K 折交叉驗證。

不幸的是,我收到了以下我似乎無法弄清楚的錯誤: TypeError: expected Tensor as element 0 in argument 0, but got tuple

正如您在 API 列表中看到的那樣,.cat 似乎需要一個張量元組,這就是我所擁有的。 https://pytorch.org/docs/stable/torch.html#torch.cat

有沒有人有什么建議?

非常感謝 - 德魯

嘗試:

x_train_cross_val = torch.cat((*x_train_folds[:fold], *x_train_folds[fold+1:]), 0)
y_train_cross_val = torch.cat((*y_train_folds[:fold], *y_train_folds[fold+1:]), 0)

torch.cat接收一個元組,其元素是torch.Tensor類型。 但是,您的元組x_train_folds[:fold]中的元素仍然是tuple 因此,您需要刪除張量的tuple “裝飾器”。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM