
How to convert a PyTorch tensor of ints to a tensor of booleans?

I would like to cast a tensor of ints to a tensor of booleans.

Specifically, I would like a function that transforms tensor([0,10,0,16]) into tensor([0,1,0,1]).

This is trivial in TensorFlow: just use tf.cast(x, tf.bool).

I want the cast to change all ints greater than 0 to a 1 and all ints equal to 0 to a 0. This is the equivalent of !! in most languages.

Since PyTorch does not seem to have a dedicated boolean type to cast to, what is the best approach here?

Edit: I am looking for a vectorized solution, as opposed to looping through each element.
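One subtlety worth checking before picking an answer: "greater than 0" and !! (truthiness) only agree when the tensor has no negative entries, because !! maps every nonzero value to 1. A minimal sketch, assuming PyTorch 1.2 or later (where comparisons return torch.bool) and a hypothetical input containing a negative value:

```python
import torch

# Hypothetical input that includes a negative value to show where the two rules differ
t = torch.tensor([0, 10, 0, 16, -3])

positive_mask = t > 0   # "greater than 0" rule: -3 -> False
nonzero_mask = t != 0   # "!!" rule: any nonzero value, including -3, -> True

print(positive_mask.int())  # tensor([0, 1, 0, 1, 0], dtype=torch.int32)
print(nonzero_mask.int())   # tensor([0, 1, 0, 1, 1], dtype=torch.int32)
```

Both lines are fully vectorized; choose the comparison that matches the semantics you actually need.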

What you're looking for is a boolean mask for the given integer tensor. You can get one by checking the condition "are the values in the tensor greater than 0?" with the simple comparison operator (>) or with torch.gt(), which gives the desired result.

# input tensor
In [76]: t   
Out[76]: tensor([ 0, 10,  0, 16])

# generate the needed boolean mask
In [78]: t > 0      
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# sanity check
In [93]: mask = t > 0      

In [94]: mask.type()      
Out[94]: 'torch.ByteTensor'

Note: In PyTorch 1.2 and later, the above operation returns a 'torch.BoolTensor' instead:

In [9]: t > 0  
Out[9]: tensor([False,  True, False,  True])

# alternatively, use `torch.gt()` API
In [11]: torch.gt(t, 0)
Out[11]: tensor([False,  True, False,  True])
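Since the question asks for a function, here is a minimal sketch that wraps the comparison from this answer. The name to_binary is hypothetical, and casting the mask back to the input's dtype is one design choice among several (returning the bool mask itself, or calling .int(), would work just as well).

```python
import torch

def to_binary(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 1 where t > 0, 0 elsewhere, in t's own dtype."""
    # torch.gt is the functional form of the > operator and returns a mask
    # (torch.bool on PyTorch 1.2+, torch.uint8 before); .to() casts it back.
    return torch.gt(t, 0).to(t.dtype)

print(to_binary(torch.tensor([0, 10, 0, 16])))  # tensor([0, 1, 0, 1])
```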

If you really want 0s and 1s rather than booleans, cast the mask to an integer type:

In [14]: (t > 0).type(torch.uint8)   
Out[14]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# alternatively, use `torch.gt()` API
In [15]: torch.gt(t, 0).int()
Out[15]: tensor([0, 1, 0, 1], dtype=torch.int32)
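Whichever integer type you cast to, each element still occupies at least one byte, so none of these is a packed bit representation. A minimal sketch, assuming PyTorch 1.2 or later, that compares the values and per-element storage of the mask and a few integer casts:

```python
import torch

t = torch.tensor([0, 10, 0, 16])
mask = t > 0  # torch.bool on PyTorch 1.2+

# Same 0/1 values in different dtypes; element_size() reports bytes per element
casts = [mask, mask.type(torch.uint8), mask.int(), mask.long()]
for x in casts:
    print(x.dtype, x.tolist(), x.element_size(), "byte(s) per element")
```

On a recent PyTorch this reports 1, 1, 4 and 8 bytes per element for bool, uint8, int32 and int64 respectively, so bool or uint8 is the compact choice when memory matters.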

The reason for this change has been discussed in this feature-request issue: issues/4764 - Introduce torch.BoolTensor ...


TL;DR: a simple one-liner

t.bool().int()
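A quick way to convince yourself of what the one-liner does: .bool() maps every nonzero element to True, so the result matches (t != 0).int(). A minimal sketch, assuming PyTorch 1.2 or later:

```python
import torch

t = torch.tensor([0, 10, 0, 16])

# .bool() turns nonzero -> True and zero -> False; .int() then gives 0/1 as int32
one_liner = t.bool().int()
print(one_liner)                                 # tensor([0, 1, 0, 1], dtype=torch.int32)
print(torch.equal(one_liner, (t != 0).int()))    # True
```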

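You can use comparisons as shown below. Note that a == 0 marks the zeros, which is the inverse of the requested mask; a != 0 marks the nonzero entries:

 >>> from torch import tensor
 >>> a = tensor([0,10,0,16])
 >>> a == 0
 tensor([ True, False,  True, False])
 >>> a != 0
 tensor([False,  True, False,  True])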
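If you already have the zero mask a == 0, you can invert it instead of recomputing the comparison. A minimal sketch, assuming PyTorch 1.2 or later, where ~ on a bool tensor is elementwise logical NOT and torch.logical_not does the same:

```python
import torch

a = torch.tensor([0, 10, 0, 16])
is_zero = a == 0

# Both expressions flip every element of the bool mask
print(~is_zero)                        # tensor([False,  True, False,  True])
print(torch.logical_not(is_zero))      # tensor([False,  True, False,  True])
print(torch.equal(~is_zero, a != 0))   # True
```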

To convert the boolean result to a numeric value:

 import torch
 a = torch.tensor([0, 4, 0, 0, 5, 0.12, 0.34, 0, 0])
 print(a.gt(0))                    # output in boolean dtype
 # output: tensor([False, True, False, False, True, True, True, False, False])
 print(a.gt(0).to(torch.float32))  # output in float32 dtype
 # output: tensor([0., 1., 0., 0., 1., 1., 1., 0., 0.])
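One reason to convert the mask to float32 is that the 0/1 values can then take part in ordinary arithmetic. A minimal sketch reusing the same input, assuming you want the count and the fraction of positive entries:

```python
import torch

a = torch.tensor([0, 4, 0, 0, 5, 0.12, 0.34, 0, 0])
mask = a.gt(0)

# Once converted to float32, the mask can be summed or averaged like any tensor
as_float = mask.to(torch.float32)
print(as_float.sum().item())   # 4.0 (number of positive entries)
print(as_float.mean().item())  # fraction of positive entries (4 of 9)
```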

Another option would be to simply do:

import torch
temp = torch.tensor([0,10,0,16])
temp.bool()
# Returns
tensor([False,  True, False,  True])

PyTorch tensors have convenient named shorthands for to(dtype); for example, t.bool() is equivalent to t.to(torch.bool). You can simply call bool():

>>> t.bool()
tensor([False,  True, False,  True])
>>> t.bool().int()
tensor([0, 1, 0, 1], dtype=torch.int32)
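The same shorthand pattern covers the other common dtypes. A minimal sketch, assuming PyTorch 1.2 or later, that checks a few named methods against their explicit to(dtype) forms:

```python
import torch

t = torch.tensor([0, 10, 0, 16])

# Each named method should produce the same dtype and values as the explicit .to() call
pairs = [
    (t.bool(), t.to(torch.bool)),
    (t.int(), t.to(torch.int32)),
    (t.long(), t.to(torch.int64)),
    (t.float(), t.to(torch.float32)),
]
for short, explicit in pairs:
    assert short.dtype == explicit.dtype and torch.equal(short, explicit)
    print(short.dtype, short.tolist())
```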


 