繁体   English   中英

从 Pandas 数据帧为 pytorch lstm 准备数据的最有效方法

[英]Most efficient way to prep data for pytorch lstm from pandas dataframe

我正在尝试将数据输入我的 lstm。 我有多个 csv 的数据,所以我创建了一个生成器来加载它。但是,我在为我的 lstm 准备数据时遇到了一些问题。

我理解这段代码(我从 pytorch 文档中得到它)

seq_len = 5
batch_size= 3
cols_num = 10
hidden_size=20
num_layers = 2

rnn = nn.LSTM(input_size=cols_num, hidden_size=hidden_size, num_layers=num_layers)

data = torch.randn(seq_len, batch_size, cols_num)
h0 = torch.randn(batch_size, seq_len, hidden_size)
c0 = torch.randn(batch_size, seq_len, hidden_size)
output, (hn, cn) = rnn(data)

但是,我认为我的脱节是使用实际数据而不是 torch.randn()。

这是我目前的发电机:

def data_loader(batch_size, fp, dropcol, seq_len):
    while True:
        for f in fp:
            gc.collect()
            df=pd.read_csv(f)
            df=df.replace(np.nan, 0)
            df=df.drop(dropcol,1)
            df['minute'] = df['minute'].apply(lambda x: min_idx(x))
            row_count, col_count = df.shape
            encoder_input = []
            prev = 0
            for idx, b in enumerate(range(1, row_count)):
                end = prev + batch_size
                window = df.iloc[prev:end]
                prev = end - 1
                w = np.array(window, dtype='float64')
                if w.shape[0] != batch_size:  break
                encoder_input.append(w)
                if idx == seq_len:
                    w0 = encoder_input
                    encoder_input = []
                    yield w0

但是当我运行这个时出现错误:

loader = data_loader(batch_size=batch_size, fp=<list of csvs>, dropcol=idcol, seq_len=2)
lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first)


for batch in loader:
    b = torch.tensor(batch)
    output, hidden = lstm(b)

错误: RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'

我的想法有什么错误? 另外,我应该如何从数据中格式化 h0 或 c0?

错误不在于你如何思考,错误在于 Pytorch 模型如何接受输入。 在 Pytorch 张量中创建的默认数据类型是torch.float64 ,其中模型接受的默认(并且可能唯一的)数据类型是torch.float32

要修复此用途:

b = torch.tensor(batch, dtype=torch.float32)

这会将您的输入转换为torch.float32

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM