繁体   English   中英

打印张量的所有内容

[英]Printing all the contents of a tensor

我遇到了这个PyTorch 教程(在 neural_networks_tutorial.py 中),他们在其中构建了一个简单的神经网络并运行了推理。 我想打印整个输入张量的内容以进行调试。 当我尝试打印张量时,我得到的是这样的,而不是整个张量:

在此处输入图片说明

我看到了一个类似的 numpy 链接,但不确定什么对 PyTorch 有用。 我可以将它转换为 numpy 并可以查看它,但希望避免额外的开销。 有没有办法打印整个张量?

为了避免截断并控制打印多少张量数据,请使用与 numpy 的numpy.set_printoptions(threshold=10_000)相同的 API。

例子:

x = torch.rand(1000, 2, 2)
print(x) # prints the truncated tensor
torch.set_printoptions(threshold=10_000)
print(x) # prints the whole tensor

如果您的张量非常大,请将threshold值调整为更高的数字。

另一种选择是:

torch.set_printoptions(profile="full")
print(x) # prints the whole tensor
torch.set_printoptions(profile="default") # reset
print(x) # prints the truncated tensor

此处记录所有可用的set_printoptions参数。

虽然我不建议这样做,如果你愿意,那么

In [18]: torch.set_printoptions(edgeitems=1)

In [19]: a
Out[19]:
tensor([[-0.7698,  ..., -0.1949],
        ...,
        [-0.7321,  ...,  0.8537]])

In [20]: torch.set_printoptions(edgeitems=3)

In [21]: a
Out[21]:
tensor([[-0.7698,  1.3383,  0.5649,  ...,  1.3567,  0.6896, -0.1949],
        [-0.5761, -0.9789, -0.2058,  ..., -0.5843,  2.6311, -0.0008],
        [ 1.3152,  1.8851, -0.9761,  ...,  0.8639, -0.6237,  0.5646],
        ...,
        [ 0.2851,  0.5504, -0.9471,  ...,  0.0688, -0.7777,  0.1661],
        [ 2.9616, -0.8685, -1.5467,  ..., -1.4646,  1.1098, -1.0873],
        [-0.7321,  0.7610,  0.3182,  ...,  2.5859, -0.9709,  0.8537]])

我来到这里实际上是在寻找如何在控制台的一行中打印整行张量的答案,所以我想我会添加这个。

tensor([[1.1573e+04, 6.0693e+02, 1.2436e+03, 2.7277e+04, 1.6673e+08, 2.0462e+00, 9.8891e-01],
    [2.0237e+04, 5.9074e+02, 1.7208e+03, 2.7449e+04, 2.1301e+08, 2.0678e+00, 1.0011e+00],
    [2.7456e+04, 6.1106e+02, 1.4897e+03, 2.7332e+04, 1.7310e+08, 2.0448e+00, 9.6041e-01],
    [1.7732e+04, 6.0232e+02, 1.2608e+03, 2.7371e+04, 1.8106e+08, 1.9594e+00, 1.0040e+00],
    ...,
    [1.1167e+04, 5.9867e+02, 1.3440e+03, 2.7263e+04, 2.3160e+08, 2.0190e+00, 1.0075e+00],
    [1.6003e+04, 5.9590e+02, 1.2319e+03, 2.7368e+04, 1.7155e+08, 2.0171e+00, 1.0202e+00],
    [1.5499e+04, 6.1471e+02, 9.4877e+02, 2.7395e+04, 1.8146e+08, 1.9016e+00, 9.5884e-01],
    [3.3886e+04, 6.0689e+02, 1.0777e+03, 2.7259e+04, 2.1599e+08, 2.0179e+00, 1.0201e+00]], dtype=torch.float64)

我这样做了

torch.set_printoptions(linewidth=200)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM