[英]Deep Q-Learning : torch.nn.functional.softmax crash
我正在學習一個教程,當我使用它時,函數softmax會崩潰。
newSignals = [0.5, 0., 0., -0.7911, 0.7911]
newState = torch.Tensor(newSignals).float().unsqueeze(0)
probs = F.softmax(self.model(newState), dim=1)
self.model
是一個神經網絡( torch.nn.module
),返回Tensor之類的
tensor([[ 0.2699, -0.2176, 0.0333]], grad_fn=<AddmmBackward>)
因此,線probs = F.softmax(self.model(newState), dim=1)
使程序崩潰但是當dim=0
它可以工作,但它不好。
免責聲明:對不起,這可能應該是評論,但我不能在評論中寫下以下所有內容。
你確定這是問題嗎? 下面的片段對我來說很有用。
import torch
a = torch.tensor([[ 0.2699, -0.2176, 0.0333]])
a.softmax(dim=1)
> tensor([[0.4161, 0.2555, 0.3284]])
a.softmax(dim=0)
> tensor([[1., 1., 1.]])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.