
Creating and Using a PyTorch DataLoader

I am trying to create a PyTorch Dataset and DataLoader object using some sample data.

This is the tab-separated dataset:

1 0  0.171429  1 0 0  0.966805  0
0 1  0.085714  0 1 0  0.188797  1
1 0  0.000000  0 0 1  0.690871  2
1 0  0.057143  0 1 0  1.000000  1
0 1  1.000000  0 0 1  0.016598  2
1 0  0.171429  1 0 0  0.802905  0
0 1  0.171429  1 0 0  0.966805  1
1 0  0.257143  0 1 0  0.329876  0
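Before changing any code, it can help to confirm which separator the file really contains. The following is a minimal sketch using only the standard library; `count_separators` is a hypothetical helper, and the sample string is the first row as it appears on this page, so the real file may differ.

```python
def count_separators(line: str) -> dict:
    """Count tab and space characters in one line of a data file."""
    return {"tabs": line.count("\t"), "spaces": line.count(" ")}

# Hypothetical usage against the question's file:
# with open("people_train.txt") as f:
#     first = f.readline()
#     print(repr(first))              # real tabs show up as \t in the repr
#     print(count_separators(first))

# First row of the sample data as shown on this page.
sample = "1 0  0.171429  1 0 0  0.966805  0"
print(count_separators(sample))
```

A row that is truly tab-separated should report seven tabs; a count of zero means `delimiter="\t"` cannot split it.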

This is the code to create a Dataset and DataLoader object from the data above:

import numpy as np
import torch as T
device = T.device("cpu")  # to Tensor or Module

# ---------------------------------------------------

# predictors and label in same file
# data has been normalized and encoded like:
#   sex     age      region   income    politic
#   [0]     [2]       [3]      [6]       [7]
#   1 0   0.057143   0 1 0    0.690871    2

class PeopleDataset(T.utils.data.Dataset):

  def __init__(self, src_file, num_rows=None):
    x_tmp = np.loadtxt(src_file, max_rows=num_rows,
      usecols=range(0,7), delimiter="\t",
      skiprows=0, dtype=np.float32)
    y_tmp = np.loadtxt(src_file, max_rows=num_rows,
      usecols=7, delimiter="\t", skiprows=0,
      dtype=np.int64)  # explicit type instead of the deprecated np.long alias

    self.x_data = T.tensor(x_tmp,
      dtype=T.float32).to(device)
    self.y_data = T.tensor(y_tmp,
      dtype=T.long).to(device)

  def __len__(self):
    return len(self.x_data)  # required

  def __getitem__(self, idx):
    if T.is_tensor(idx):
      idx = idx.tolist()
    preds = self.x_data[idx, 0:7]
    pol = self.y_data[idx]
    sample = \
      { 'predictors' : preds, 'political' : pol }
    return sample

# ---------------------------------------------------

def main():
  print("\nBegin PyTorch DataLoader demo ")

  # 0. miscellaneous prep
  T.manual_seed(0)
  np.random.seed(0)

  print("\nSource data looks like: ")
  print("1 0  0.171429  1 0 0  0.966805  0")
  print("0 1  0.085714  0 1 0  0.188797  1")
  print(" . . . ")

  # 1. create Dataset and DataLoader object
  print("\nCreating Dataset and DataLoader ")

  train_file = "people_train.txt"
  train_ds = PeopleDataset(train_file, num_rows=8)

  bat_size = 3
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True)

  # 2. iterate thru training data twice
  for epoch in range(2):
    print("\n==============================\n")
    print("Epoch = " + str(epoch))
    for (batch_idx, batch) in enumerate(train_ldr):
      print("\nBatch = " + str(batch_idx))
      X = batch['predictors']  # [3,7]
      # Y = T.flatten(batch['political'])  # 
      Y = batch['political']   # [3]
      print(X)
      print(Y)
  print("\n==============================")

  print("\nEnd demo ")

if __name__ == "__main__":
  main()

The code is saved with the filename "demo.py". It should run successfully when the command "python demo.py" is executed at a command prompt. I use Anaconda Prompt, which has PyTorch (v1.10) installed.

I have tried numerous ways to get this working, but I always get the following error:

Source data looks like: 
1 0  0.171429  1 0 0  0.966805  0
0 1  0.085714  0 1 0  0.188797  1
 . . . 

Creating Dataset and DataLoader 

---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-8-cfb1177991f2> in <module>()
     81 
     82 if __name__ == "__main__":
---> 83   main()

4 frames

<ipython-input-8-cfb1177991f2> in main()
     59 
     60   train_file = "people_train.txt"
---> 61   train_ds = PeopleDataset(train_file, num_rows=8)
     62 
     63   bat_size = 3

<ipython-input-8-cfb1177991f2> in __init__(self, src_file, num_rows)
     20     x_tmp = np.loadtxt(src_file, max_rows=num_rows,
     21       usecols=range(0,7), delimiter="\t",
---> 22       skiprows=0, dtype=np.float32)
     23     y_tmp = np.loadtxt(src_file, max_rows=num_rows,
     24       usecols=7, delimiter="\t", skiprows=0,

/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in loadtxt(fname, dtype, comments, delimiter, converters, skiprows, usecols, unpack, ndmin, encoding, max_rows)
   1137         # converting the data
   1138         X = None
-> 1139         for x in read_data(_loadtxt_chunksize):
   1140             if X is None:
   1141                 X = np.array(x, dtype)

/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in read_data(chunk_size)
   1058                 continue
   1059             if usecols:
-> 1060                 vals = [vals[j] for j in usecols]
   1061             if len(vals) != N:
   1062                 line_num = i + skiprows + 1

/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in <listcomp>(.0)
   1058                 continue
   1059             if usecols:
-> 1060                 vals = [vals[j] for j in usecols]
   1061             if len(vals) != N:
   1062                 line_num = i + skiprows + 1

IndexError: list index out of range

I can't see which index is wrong, because the indexing looks correct to me. Can someone please help?

Your data seems to be space-separated, not tab-separated. So when you specify delimiter="\t", each entire row is read as a single column. But because of usecols=range(0,7), NumPy expects at least seven columns, and the list comprehension shown in the traceback (vals = [vals[j] for j in usecols]) fails as soon as it asks for column 1 of a row that has only one, which raises the IndexError.
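The failure can be reproduced without NumPy. This is a simplified pure-Python approximation of the list comprehension in the traceback, not NumPy's actual parser; the row is copied from the question.

```python
line = "1 0  0.171429  1 0 0  0.966805  0"

# Splitting on a tab that is not there leaves the whole row as one field.
vals = line.split("\t")
print(len(vals))  # 1

# Picking columns 0..6 from a one-element list fails at column 1,
# mirroring `vals = [vals[j] for j in usecols]` in the traceback.
try:
    picked = [vals[j] for j in range(0, 7)]
except IndexError as exc:
    print("IndexError:", exc)  # IndexError: list index out of range

# Splitting on any run of whitespace yields the expected 8 fields.
print(len(line.split()))  # 8
```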

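To fix this, either replace the spaces in your data file with tabs, or change the delimiter argument in both np.loadtxt calls to delimiter=" ". If some columns are separated by more than one space, as in the rows shown above, delimiter=" " can produce empty fields; leaving delimiter at its default of None splits on any run of whitespace and avoids that.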
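A minimal sketch of the whitespace-agnostic approach, assuming the file's columns are separated by spaces as they appear here. An `io.StringIO` holding the question's eight rows stands in for "people_train.txt", and `np.int64` replaces the deprecated `np.long`.

```python
import io

import numpy as np

# The question's sample rows; io.StringIO stands in for people_train.txt.
text = """1 0  0.171429  1 0 0  0.966805  0
0 1  0.085714  0 1 0  0.188797  1
1 0  0.000000  0 0 1  0.690871  2
1 0  0.057143  0 1 0  1.000000  1
0 1  1.000000  0 0 1  0.016598  2
1 0  0.171429  1 0 0  0.802905  0
0 1  0.171429  1 0 0  0.966805  1
1 0  0.257143  0 1 0  0.329876  0
"""

# No delimiter argument: the default (None) splits on any run of whitespace,
# so single and double spaces between columns are both handled.
x = np.loadtxt(io.StringIO(text), usecols=range(0, 7), dtype=np.float32)
y = np.loadtxt(io.StringIO(text), usecols=7, dtype=np.int64)

print(x.shape)  # (8, 7)
print(y)        # [0 1 2 1 2 0 1 0]
```

In the question's PeopleDataset, the same change amounts to dropping delimiter="\t" from both np.loadtxt calls.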
