简体   繁体   English

加快自定义消息丢失的 pytorch 操作

[英]Speeding up pytorch operations for custom message dropout

I am trying to implement message dropout in my custom MessagePassing convolution in PyTorch Geometric.我正在尝试在 PyTorch Geometric 中的自定义 MessagePassing 卷积中实现消息丢失。 Message dropout consists of randomly ignoring p% of the edges in the graph.消息丢失包括随机忽略图中 p% 的边。 My idea was to randomly remove p% of them from the input edge_index in forward() .我的想法是从forward()的输入edge_index中随机删除其中的 p%。

The edge_index is a tensor of shape (2, num_edges) where the 1st dimension is the "from" node ID and the 2nd is the "to" node ID". So what I thought I could do is select a random sample of range(N) and then use it to mask out the rest of the indices: edge_index是形状为(2, num_edges)的张量,其中第一维是“从”节点 ID,第二维是“到”节点 ID。所以我认为我可以做的是 select range(N)然后用它来掩盖索引的 rest:

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # TODO: this is way too slow (4-5 times slower than without it)
            # message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
            random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
            edge_index_to_use = edge_index[:, random_keep_inx]
            edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

However, it is way too slow, it makes an epoch go for 5' instead of 1' without (5 times slower).然而,它太慢了,它使一个纪元 go 成为 5' 而不是 1' 没有(慢 5 倍)。 Is there a faster way to do this in PyTorch?在 PyTorch 中有更快的方法吗?

Edit: The bottleneck seems to be the random.sample() call, not the masking.编辑:瓶颈似乎是random.sample()调用,而不是屏蔽。 So I guess what I should be asking is for faster alternatives to that.所以我想我应该问的是更快的替代方案。

I managed to create a boolean mask using PyTorch's Dropout from Functional which is much faster.我设法使用 PyTorch 的函数式 Dropout 创建了一个 boolean 掩码,速度要快得多。 Now an epoch takes ~1' again.现在一个纪元又需要 ~1'。 Better than other solutions with permutations that I found elsewhere.比我在其他地方找到的具有排列的其他解决方案更好。

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # message dropout -> randomly ignore p % of edges in the graph
            mask = F.dropout(torch.ones(edge_index.shape[1]), self.message_dropout, self.training) > 0
            edge_index_to_use = edge_index[:, mask]
            edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM