繁体   English   中英

Pytorch NN和类之间的通信

[英]Pytorch NN and communication between classes

我是python和pytorch的新手,在理解其工作方式时遇到问题。

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    class Net(nn.Module):
        def __init__(self):
            ..
        def forward(self, x):
            ..
            return x
    net = Net()


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

在此处输入图片说明

因此,这就是代码,我从图片中的代码中得出了我所了解的内容。 我有一些疑问:

A)为什么我不能在代码中直接使用nn.CrossEntropy而不是“ criterion”? 如果将其分配给变量,会有什么区别? 我收到此错误:Tensor的布尔值具有多个值是不明确的

B)为什么当Net类获得一个对象(nn)(我假设使用'as'时,会创建一个对象),那么Net类随后可以简单地向后使用吗? 它应该是nn的一部分,而不是Net。 你能帮我理解一下吗?

C)虽然优化是一个不同的对象,但是优化优化的参数如何影响nn? 我不明白他们如何传递变量并互相更新?

A)通过在一个位置将其设置为变量,可以帮助您更轻松地在一个位置更改损失函数,而不必在许多地方键入nn.MSELoss,因为代码的大小和复杂性不断增加。 基本上不会出错。
至于错误,将需要更多信息来解决该布尔错误。 输入的内容在哪一行等等。那里的信息很少。

B)Net(nn.Module)从nn.Module继承,它将向后添加到您添加到该类的所有操作中。 有关更多信息,请参阅文档

C)“网”是一个对象。 net.parameters()是一个迭代器,它迭代网络对象中的所有参数。 因此,它通过引用传递,而不是通过值传递参数。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM