I tried to check whether an instance of the NamedTuple "Transition" is equal to any object in the list "self.memory".
Here is the code I tried to run:
from typing import NamedTuple
import random
import torch as t

Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)

class ReplayMemory:
    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity
        self.position = 0

    def store(self, *args):
        print(self.memory == Transition(*args))
        if Transition(*args) in self.memory:
            return
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        ...
And here is the output:
False
False
And the error I got:
...
if Transition(*args) in self.memory:
RuntimeError: bool value of Tensor with more than one value is ambiguous
This seems odd to me, because the print shows that the "==" comparison returns a boolean.
How could this be done correctly?
Thank you
Edit:
*args is a tuple consisting of:
state: tensor of torch.Size([16, 12])
action: int
reward: int
next_state: tensor of torch.Size([16, 12])
done: int
hidden: tensor of torch.Size([4])
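The behaviour can be reproduced without PyTorch. This sketch uses a hypothetical FakeTensor class (an assumption for illustration, not a torch type) whose "==" returns an element-wise result that, like a multi-element tensor, refuses to be converted to bool. Comparing the whole list with a tuple returns False without looking at any fields, because a list is never equal to a tuple. The "in" operator, however, compares the new tuple with each stored tuple, and tuple equality calls bool() on each field comparison.

```python
from typing import NamedTuple


class FakeTensor:
    """Hypothetical stand-in for a multi-element torch.Tensor."""

    def __init__(self, values):
        self.values = list(values)

    def __eq__(self, other):
        if not isinstance(other, FakeTensor):
            return NotImplemented
        # Element-wise comparison that returns another array-like object
        # instead of a plain bool, as torch.Tensor.__eq__ does.
        return FakeTensor(a == b for a, b in zip(self.values, other.values))

    def __bool__(self):
        # Mirrors the error torch raises for tensors with more than one value.
        if len(self.values) != 1:
            raise RuntimeError("bool value of Tensor with more than one value is ambiguous")
        return bool(self.values[0])


class Transition(NamedTuple):
    state: FakeTensor
    action: int


memory = [Transition(FakeTensor([1, 2]), 0)]
new = Transition(FakeTensor([1, 2]), 0)

# A list never equals a tuple, so this is False without comparing any fields.
print(memory == new)  # prints False

# "in" compares new with each stored tuple; tuple equality calls bool() on
# the result of state == state, which raises.
error = None
try:
    found = new in memory
except RuntimeError as err:
    error = err
    print("RuntimeError:", err)
```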
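I believe you should define equality explicitly. "x in some_list" compares x with each element using "==", and for tuples (including NamedTuples) that comparison goes field by field, calling bool() on each result. For a tensor field, "==" returns an element-wise tensor rather than a bool, which is where the RuntimeError comes from. The print shows False only because a list is never equal to a tuple, so no fields are compared at all.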
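from typing import NamedTuple
import torch as t

class Sample(NamedTuple):
    state: t.Tensor
    action: int

    def __eq__(self, other):
        # Only compare against other Samples; let Python handle anything else.
        if not isinstance(other, Sample):
            return NotImplemented
        # t.all() reduces the element-wise comparison (same shapes assumed) to a
        # single value, and bool() turns it into a plain bool that "in" can use.
        return bool(t.all(self.state == other.state)) and self.action == other.action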
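An alternative to overriding __eq__, sketched under the assumption that the six-field Transition from the question is kept as a plain NamedTuple: compare fields with torch.equal, which returns a plain Python bool (True only when both tensors have the same size and elements), and scan the list with any() instead of "in". The helper names same_transition and contains are hypothetical.

```python
from typing import NamedTuple

import torch as t


class Transition(NamedTuple):
    state: t.Tensor
    action: int
    reward: int
    next_state: t.Tensor
    done: int
    hidden: t.Tensor


def same_transition(a: Transition, b: Transition) -> bool:
    """Field-by-field equality that never calls bool() on a multi-element tensor."""
    for x, y in zip(a, b):
        if isinstance(x, t.Tensor) or isinstance(y, t.Tensor):
            # t.equal returns a plain bool: True only if sizes and values match.
            if not (isinstance(x, t.Tensor) and isinstance(y, t.Tensor) and t.equal(x, y)):
                return False
        elif x != y:
            return False
    return True


def contains(memory, transition: Transition) -> bool:
    """Linear scan with the explicit predicate, replacing "transition in memory"."""
    return any(same_transition(transition, m) for m in memory)


s = t.zeros(16, 12)
h = t.zeros(4)
memory = [Transition(s, 1, 0, s.clone(), 0, h)]
print(contains(memory, Transition(s.clone(), 1, 0, s.clone(), 0, h.clone())))  # prints True
print(contains(memory, Transition(s + 1, 1, 0, s, 0, h)))  # prints False
```

Inside store(), the membership check would then read "if contains(self.memory, Transition(*args)): return". Like "in", this is a linear scan over the whole memory; unlike the "==" plus t.all() approach, torch.equal simply returns False for tensors of different shapes instead of broadcasting or raising.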