簡體   English   中英

Pytorch 如何增加批量大小

[英]Pytorch how to increase batch size

我目前有一個 torch.Size([1, 3, 256, 224]) 的張量,但我需要它作為輸入形狀 [32, 3, 256, 224]。 我正在實時捕獲數據,因此數據加載器似乎不是一個好的選擇。 有沒有簡單的方法來獲取 32 個大小的 torch.Size([1, 3, 256, 224]) 並將它們組合起來以創建 1 個大小為 [32, 3, 256, 224] 的張量?

您可能會使用 jit model,並且批次大小必須與訓練 model 的批次大小完全相同。

t = torch.rand(1, 3, 256, 224)
t.size() # torch.Size([1, 3, 256, 224])
t2= t.expand(32, -1,-1,-1)
t2.size() # torch.Size([32, 3, 256, 224])

擴展張量不會分配新的 memory,而只會在現有張量上創建一個新視圖,然后您就可以得到所需的東西。 只有張量步幅發生了變化。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM