[英]How to debug if weight keep increasing. Pytorch program
我在练习 Pytorch 程序时有些疑问。
我有像 y = m1x1 + m2x2 + c 这样的函数(这里只需要学习 2 个权重)。 权重的期望值应该是 16,-14,偏差应该是 36。但是在每个 epoch 中,学习到的权重都会变得非常大。 谁能帮我调试和理解这20行代码,这里出了什么问题。
import torch
x = torch.randint(size = (1,2), high = 10)
w = torch.Tensor([16,-14])
b = 36
#Compute Ground Truth
y = w * x + b
#Find weights by program
epoch = 20
learning_rate = 30
#initialize random
w1 = torch.rand(size= (1,2), requires_grad= True)
b1 = torch.ones(size = [1], requires_grad= True)
for i in range(epoch):
y1 = w1 * x + b1
#loss function RMSQ
loss = torch.sum((y1-y)**2)
#Find gradient
loss.backward()
with torch.no_grad():
#update parameters
w1 -= (learning_rate * w1.grad)
b1 -= (learning_rate * b1.grad)
w1.grad.zero_()
b1.grad.zero_()
print("B ", b1)
print("W ", w1)
谢谢,加内什
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.