簡體   English   中英

加快自定義消息丟失的 pytorch 操作

[英]Speeding up pytorch operations for custom message dropout

我正在嘗試在 PyTorch Geometric 中的自定義 MessagePassing 卷積中實現消息丟失。 消息丟失包括隨機忽略圖中 p% 的邊。 我的想法是從forward()的輸入edge_index中隨機刪除其中的 p%。

edge_index是形狀為(2, num_edges)的張量,其中第一維是“從”節點 ID,第二維是“到”節點 ID。所以我認為我可以做的是 select range(N)然后用它來掩蓋索引的 rest:

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # TODO: this is way too slow (4-5 times slower than without it)
            # message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
            random_keep_inx = random.sample(range(edge_index.shape[1]), int((1.0 - self.message_dropout) * edge_index.shape[1]))
            edge_index_to_use = edge_index[:, random_keep_inx]
            edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

然而,它太慢了,它使一個紀元 go 成為 5' 而不是 1' 沒有(慢 5 倍)。 在 PyTorch 中有更快的方法嗎?

編輯:瓶頸似乎是random.sample()調用,而不是屏蔽。 所以我想我應該問的是更快的替代方案。

我設法使用 PyTorch 的函數式 Dropout 創建了一個 boolean 掩碼,速度要快得多。 現在一個紀元又需要 ~1'。 比我在其他地方找到的具有排列的其他解決方案更好。

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # message dropout -> randomly ignore p % of edges in the graph
            mask = F.dropout(torch.ones(edge_index.shape[1]), self.message_dropout, self.training) > 0
            edge_index_to_use = edge_index[:, mask]
            edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM