
How to convert a pytorch tensor into a numpy array?

I have a torch tensor

a = torch.randn(1, 2, 3, 4, 5)

How can I get it in numpy?

Something like

b = a.tonumpy()

The output should be the same as if I did

b = np.random.randn(1, 2, 3, 4, 5)

Copied from the PyTorch docs:

import torch

a = torch.ones(5)
print(a)
# tensor([1., 1., 1., 1., 1.])

b = a.numpy()
print(b)
# [1. 1. 1. 1. 1.]
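Applied to the tensor from the question, a minimal sketch comparing the result with `np.random.randn`. One difference worth knowing: `torch.randn` produces float32 by default, while `np.random.randn` returns float64, so an explicit `astype` (which copies) is needed if the dtypes must match exactly.

```python
import numpy as np
import torch

a = torch.randn(1, 2, 3, 4, 5)   # float32 by default
b = a.numpy()                    # NumPy array backed by the same data

c = np.random.randn(1, 2, 3, 4, 5)  # float64 by default

print(b.shape == c.shape)  # True
print(b.dtype, c.dtype)    # float32 float64

# Copy into float64 if the dtype must match NumPy's default.
b64 = b.astype(np.float64)
print(b64.dtype)           # float64
```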

Following the discussion with @John in the comments:

In case the tensor is (or can be) on the GPU, or in case it requires (or may require) grad, one can use

t.detach().cpu().numpy()

I recommend uglifying your code only as much as required.
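A minimal sketch of covering both cases (GPU and requires grad) with the `detach().cpu().numpy()` chain inside one small helper. The name `to_numpy` is hypothetical, and the `torch.cuda.is_available()` check is only there so the example also runs on machines without a GPU.

```python
import numpy as np
import torch

def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: works for CPU or GPU tensors, with or without grad."""
    # detach(): drop the autograd history so .numpy() is allowed
    # cpu(): copy to host memory if the tensor lives on a GPU (no-op on CPU)
    return t.detach().cpu().numpy()

device = "cuda" if torch.cuda.is_available() else "cpu"
w = torch.randn(3, device=device, requires_grad=True)

arr = to_numpy(w)
print(type(arr), arr.shape)
```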

You can try the following ways:

1. torch.Tensor().numpy()
2. torch.Tensor().cpu().data.numpy()
3. torch.Tensor().cpu().detach().numpy()
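On a tensor that lives on the CPU and does not require grad, all three variants return the same values. A small check under that assumption; it also shows that none of them copies the data, since `.cpu()` returns the original tensor when it is already on the CPU, and both `.data` and `.detach()` return tensors backed by the same storage. The variants only differ for GPU tensors or tensors that require grad, where the first one raises an error.

```python
import numpy as np
import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)

v1 = t.numpy()
v2 = t.cpu().data.numpy()
v3 = t.cpu().detach().numpy()

# Same values...
print(np.array_equal(v1, v2) and np.array_equal(v1, v3))  # True
# ...and the same underlying buffer (no copies were made).
print(np.shares_memory(v1, v2) and np.shares_memory(v1, v3))  # True
```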

Another useful way:

a = torch.tensor(0.1, device='cuda')



array(0.1, dtype=float32)
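For a 0-dimensional tensor the result is a 0-dimensional NumPy array, matching the `array(0.1, dtype=float32)` output. A sketch that assumes a CUDA device may be absent and falls back to the CPU (the fallback is an addition so the example runs anywhere); any of the `.cpu()` variants from the list would do the conversion.

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.tensor(0.1, device=device)  # 0-d float32 tensor

# Move to host memory first; .numpy() only works on CPU tensors.
arr = a.cpu().numpy()

print(repr(arr))  # array(0.1, dtype=float32)
```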

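This is a function from fastai core:

# fastai helper that applies a function recursively over (nested) collections
from fastai.torch_core import apply

def to_np(x):
    "Convert a tensor to a numpy array."
    return apply(lambda o: o.data.cpu().numpy(), x)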
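If fastai is not installed, the same idea can be sketched without it. The recursion below is a hypothetical stand-in for fastai's `apply` that only handles lists, tuples and dicts; it is not fastai's implementation. `.data` is kept to mirror the fastai version; `.detach()` would behave the same here.

```python
import numpy as np
import torch

def to_np(x):
    """Convert a tensor, or a nested list/tuple/dict of tensors, to NumPy."""
    if isinstance(x, torch.Tensor):
        return x.data.cpu().numpy()
    if isinstance(x, list):
        return [to_np(o) for o in x]
    if isinstance(x, tuple):
        return tuple(to_np(o) for o in x)
    if isinstance(x, dict):
        return {k: to_np(v) for k, v in x.items()}
    return x  # leave anything else unchanged

batch = {"x": torch.ones(2), "y": [torch.zeros(1), torch.tensor(3)]}
out = to_np(batch)
print(out)
```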

Using a ready-made function like this from a library built on PyTorch can be a nice choice.

If you look inside PyTorch Transformers, you will find this code:

preds = logits.detach().cpu().numpy()

You may ask why the detach() method is needed. It detaches the tensor from the autograd (AD) computational graph; without it, calling .numpy() on a tensor that requires grad raises an error.
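A small demonstration of that error and the fix. The toy computation `x * 2` is an assumption standing in for a model's logits.

```python
import torch

x = torch.randn(4, requires_grad=True)
logits = x * 2  # part of the autograd graph, so it requires grad

raised = False
try:
    logits.numpy()
except RuntimeError as err:
    raised = True
    print("Error:", err)

preds = logits.detach().cpu().numpy()  # works: detached from the graph
print(preds.shape)  # (4,)
```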

Note, however, that the CPU tensor and the NumPy array are connected: they share the same storage.

import torch

tensor = torch.zeros(2)
numpy_array = tensor.numpy()
print('Before edit:')
print(tensor)
print(numpy_array)

tensor[0] = 10

print('After edit:')
print('Tensor:', tensor)
print('Numpy array:', numpy_array)


Before edit:
tensor([0., 0.])
[0. 0.]

After edit:
Tensor: tensor([10.,  0.])
Numpy array: [10.  0.]

The value of the first element is shared by the tensor and the numpy array. Changing it to 10 in the tensor changed it in the numpy array as well.

This is why we need to be careful: altering the NumPy array may alter the CPU tensor as well.
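The sharing works in both directions. A minimal sketch, assuming a CPU tensor: writing through the NumPy array changes the tensor, while taking a copy (NumPy's `.copy()`, or PyTorch's `.clone()` before converting) breaks the link.

```python
import numpy as np
import torch

tensor = torch.zeros(2)
shared = tensor.numpy()               # shares storage with `tensor`
independent = tensor.numpy().copy()   # owns its own buffer

shared[1] = 5.0
print(tensor)       # tensor([0., 5.])
print(independent)  # [0. 0.]

# Alternatively, clone on the PyTorch side before converting.
also_independent = tensor.clone().numpy()
also_independent[0] = -1.0
print(tensor)       # tensor([0., 5.])
```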

You may find the following two functions useful.

  1. torch.Tensor.numpy()
  2. torch.from_numpy()
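A sketch of the reverse direction with illustrative data. `torch.from_numpy()` also shares memory with the source array, whereas `torch.tensor()` copies it; the dtype is preserved, so a float64 array becomes a float64 tensor.

```python
import numpy as np
import torch

arr = np.arange(3, dtype=np.float64)

shared = torch.from_numpy(arr)   # shares memory with `arr`
copied = torch.tensor(arr)       # independent copy

arr[0] = 42.0
print(shared)   # first element is now 42
print(copied)   # still starts with 0

back = shared.numpy()               # round trip: still the same buffer
print(np.shares_memory(back, arr))  # True
```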

If the tensor requires grad (for example, a loss computed during training), you first have to call .detach() before calling .numpy():

loss = loss_fn(preds, labels)
print(loss.detach().numpy())
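A self-contained sketch of that situation. The model, data and `loss_fn` are assumptions made for illustration: a single `torch.nn.Linear` layer trained with `torch.nn.MSELoss` on random inputs.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
loss_fn = torch.nn.MSELoss()

inputs = torch.randn(8, 3)
labels = torch.randn(8, 1)

preds = model(inputs)
loss = loss_fn(preds, labels)  # 0-d tensor that requires grad

loss.backward()                     # gradients still flow normally
loss_value = loss.detach().numpy()  # detach first, then convert
print(loss_value)
```

Detaching only affects the NumPy copy of the value; the `backward()` call above is unaffected by it.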
