繁体   English   中英

如何检查 NamedTuple 是否在列表中?

[英]How to check if NamedTuple is in list?

我试图检查 NamedTuple "Transition" 的 inctance 是否等于列表 "self.memory" 中的任何对象。

这是我尝试运行的代码:

from typing import NamedTuple
import random
import torch as t

Transition = NamedTuple('Transition', state=t.Tensor, action=int, reward=int, next_state=t.Tensor, done=int, hidden=t.Tensor)


class ReplayMemory:

    def __init__(self, capacity):
        self.memory = []
        self.capacity = capacity
        self.position = 0

    def store(self, *args):
        print(self.memory == Transition(*args))
        if Transition(*args) in self.memory:
            return
    if len(self.memory) < self.capacity:
        self.memory.append(None)
    self.memory[self.position] = Transition(*args)
    ...

这是输出:

False
False

我得到的错误是:

   ...
        if Transition(*args) in self.memory:
    RuntimeError: bool value of Tensor with more than one value is ambiguous

这对我来说似乎很奇怪,因为打印告诉我“==”操作返回一个布尔值。

这怎么能正确完成?

谢谢

编辑:

*args 是一个元组,由

torch.Size([16, 12])
int
int
torch.Size([16, 12])
int
torch.Size([4])

我相信你应该明确定义平等。

from typing import NamedTuple
import random
import torch as t


class Sample(NamedTuple):
    state: t.Tensor
    action: int

    def __eq__(self, other):
        return bool(t.all(self.state == other.state)) and self.action == other.action

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM