Is there an equivalent of PyTorch's BoolTensor in TensorFlow? I have the following PyTorch code that I want to migrate to TensorFlow:
done_mask = torch.BoolTensor(dones.values).to(device)
next_state_values[done_mask] = 0.0
What is dones? Assuming it's a 0/1 tensor, you can convert it to a bool tensor like this:
tf.cast(dones,tf.bool)
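A minimal sketch of that conversion, assuming dones holds 0/1 flags (the values below are hypothetical). tf.cast maps every nonzero entry to True. There is no separate .to(device) step: TensorFlow places tensors on an available GPU automatically, and tf.device can pin them explicitly if needed.

```python
import tensorflow as tf

# Hypothetical 0/1 "done" flags for a batch of four transitions
dones = tf.constant([0, 1, 0, 1])

# Nonzero -> True, zero -> False: the TensorFlow counterpart of torch.BoolTensor
done_mask = tf.cast(dones, tf.bool)
print(done_mask.numpy().tolist())  # [False, True, False, True]
```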
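However, you can't use such a mask to assign values to a tensor: TensorFlow tensors are immutable and do not support item assignment, so next_state_values[done_mask] = 0.0 cannot be ported literally.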
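To see the limitation concretely, here is a small sketch with hypothetical values: the PyTorch-style masked assignment raises a TypeError, because a tf.Tensor has no item assignment.

```python
import tensorflow as tf

# Hypothetical values for illustration
next_state_values = tf.constant([1.5, 2.0, 3.5, 4.0])
done_mask = tf.constant([False, True, False, True])

try:
    # PyTorch-style in-place masked assignment
    next_state_values[done_mask] = 0.0
    assignment_failed = False
except TypeError as err:
    # tf.Tensor objects are immutable, so Python rejects the item assignment
    assignment_failed = True
    print(err)
```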
One way, which I recommend, is to multiply by a 0/1 mask that is 0 wherever the episode is done:
next_state_values *= tf.cast(dones!=1,next_state_values.dtype)
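A self-contained sketch of the mask approach, using hypothetical values for next_state_values and dones. Entries whose dones flag is 1 are multiplied by 0, and the rest are multiplied by 1 and kept unchanged.

```python
import tensorflow as tf

# Hypothetical batch of next-state values and 0/1 done flags
next_state_values = tf.constant([1.5, 2.0, 3.5, 4.0])
dones = tf.constant([0, 1, 0, 1])

# 1.0 where the episode continues, 0.0 where it ended,
# cast to the same dtype as next_state_values
keep_mask = tf.cast(dones != 1, next_state_values.dtype)

# *= rebinds the name to a new tensor; nothing is modified in place
next_state_values *= keep_mask
print(next_state_values.numpy().tolist())  # [1.5, 0.0, 3.5, 0.0]
```

Because the mask is just a constant factor, the gradient through the zeroed entries is simply zero, and the rest of the gradient passes through unchanged.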
Another way, which I don't recommend because it can cause issues when computing gradients, is to use tf.tensor_scatter_nd_update. For your case, that would be:
indices = tf.where(dones == 1)
next_state_values = tf.tensor_scatter_nd_update(
    next_state_values, indices,
    tf.zeros(tf.shape(indices)[:1], dtype=next_state_values.dtype))
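A runnable sketch of the scatter approach with the same hypothetical values. tf.where returns the positions of the finished episodes as an int64 tensor of shape [num_done, 1], and the updates tensor must provide one value per position in the same dtype as next_state_values.

```python
import tensorflow as tf

# Hypothetical batch of next-state values and 0/1 done flags
next_state_values = tf.constant([1.5, 2.0, 3.5, 4.0])
dones = tf.constant([0, 1, 0, 1])

# Positions of finished episodes, shape [num_done, 1]
indices = tf.where(dones == 1)

# One zero per selected position, matching the target tensor's dtype
updates = tf.zeros(tf.shape(indices)[:1], dtype=next_state_values.dtype)

# Returns a new tensor with the selected positions overwritten
next_state_values = tf.tensor_scatter_nd_update(next_state_values, indices, updates)
print(next_state_values.numpy().tolist())  # [1.5, 0.0, 3.5, 0.0]
```

The result matches the mask approach; the difference is that the zeros are written at explicit index positions instead of being produced by a multiplication.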