簡體   English   中英

Pytorch 的 share_memory_() 與內置 Python 的 shared_memory:為什么在 Pytorch 中我們不需要訪問共享內存塊?

[英]Pytorch's share_memory_() vs built-in Python's shared_memory: Why in Pytorch we don't need to access the shared memory-block?

在嘗試了解內置multiprocessingPytorch 的multiprocessing包時,我觀察到兩者之間存在不同的行為。 我覺得這很奇怪,因為Pytorch 的 package 與內置的 package 完全兼容

具體來說,我指的是進程之間共享變量的方式。 在 Pytorch 中,張量通過就地操作share_memory_()移動到 shared_memory。 另一方面,我們可以通過使用shared_memory模塊獲得與內置 package 相同的結果。

我很難理解的兩者之間的區別在於,對於內置版本,我們必須在啟動的進程中明確訪問共享內存塊。 但是,我們不需要對 Pytorch 版本執行此操作。

這是一個Pytorch的玩具示例,展示了這一點:

import time

import torch
# the same behavior happens when importing:
# import multiprocessing as mp
import torch.multiprocessing as mp


def get_time(s):
    return round(time.time() - s, 1)


def foo(a):
    # wait ~1sec to print the value of the tensor.
    time.sleep(1.0)
    with lock:
        #-------------------------------------------------------------------
        # WITHOUT explicitely accessing the shared memory block, we can observe
        # that the tensor has changed:
        #-------------------------------------------------------------------
        print(f"{__name__}\t{get_time(s)}\t\t{a}")


# global variables.
lock = mp.Lock()
s = time.time()


if __name__ == '__main__':
    print("Module\t\tTime\t\tValue")
    print("-"*50)

    # create tensor and assign it to shared memory.
    a = torch.zeros(2).share_memory_()
    print(f"{__name__}\t{get_time(s)}\t\t{a}")

    # start child process.
    p0 = mp.Process(target=foo, args=(a,))
    p0.start()

    # modify the value of the tensor after ~0.5sec.
    time.sleep(0.5)
    with lock:
        a[0] = 1.0

    print(f"{__name__}\t{get_time(s)}\t\t{a}")
    time.sleep(1.5)

    p0.join()

輸出(如預期):

Module          Time            Value
--------------------------------------------------
__main__        0.0             tensor([0., 0.])
__main__        0.5             tensor([1., 0.])
__mp_main__     1.0             tensor([1., 0.])

這是一個帶有內置package 的玩具示例:

import time
import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np


def get_time(s):
    return round(time.time() - s, 1)


def foo(shm_name, shape, type_):
    #-------------------------------------------------------------------
    # WE NEED TO explicitely access the shared memory block to observe
    # that the array has changed:
    #-------------------------------------------------------------------
    existing_shm = shared_memory.SharedMemory(name=shm_name)
    a = np.ndarray(shape, type_, buffer=existing_shm.buf)

    # wait ~1sec to print the value.
    time.sleep(1.0)
    with lock:
        print(f"{__name__}\t{get_time(s)}\t\t{a}")


# global variables.
lock = mp.Lock()
s = time.time()


if __name__ == '__main__':
    print("Module\t\tTime\t\tValue")
    print("-"*35)

    # create numpy array and shared memory block.
    a = np.zeros(2,)
    shm = shared_memory.SharedMemory(create=True, size=a.nbytes)
    a_shared = np.ndarray(a.shape, a.dtype, buffer=shm.buf)
    a_shared[:] = a[:]
    print(f"{__name__}\t{get_time(s)}\t\t{a_shared}")

    # start child process.
    p0 = mp.Process(target=foo, args=(shm.name, a.shape, a.dtype))
    p0.start()

    # modify the value of the vaue after ~0.5sec.
    time.sleep(0.5)
    with lock:
        a_shared[0] = 1.0

    print(f"{__name__}\t{get_time(s)}\t\t{a_shared}")
    time.sleep(1.5)

    p0.join()

正如預期的那樣,它等效地輸出:

Module          Time            Value
-----------------------------------
__main__        0.0             [0. 0.]
__main__        0.5             [1. 0.]
__mp_main__     1.0             [1. 0.]

所以我想要理解的是,為什么我們不需要在內置版本和 Pytorch 版本中都遵循相同的步驟,即 Pytorch 如何能夠避免顯式訪問共享內存塊的需要?

PS 我使用的是 Windows 操作系統和 Python 3.9

您正在給 pytorch 作者寫一封情書。 也就是說,您是在拍拍他們的背,祝賀他們的包裝工作“做得很好”。 這是一個可愛的圖書館。

讓我們退后一步,使用一個非常簡單的數據結構,字典d 如果 parent 用一些值初始化d ,然后啟動一對 worker children,每個 child 都有一個d的副本。

那是怎么發生的? multiprocessing模塊從 worker 中分離出來,查看包含d的已定義變量集,並將那些(鍵,值)對從父級向下序列化到子級。

所以此時我們有 3 個獨立的d副本 如果父母或任何一個孩子修改d ,則其他 2 個副本完全不受影響。

現在切換到 pytorch 包裝器。 您提供了一些漂亮簡潔的代碼來演示 little.SharedMemory() 如果我們想要 3 個對相同共享結構的引用而不是 3 個獨立副本,則應用程序需要執行此操作。 pytorch 包裝器序列化對公共數據結構的引用,而不是生成副本。 在引擎蓋下,它正在做你所做的舞蹈。 但是在應用程序級別沒有重復的冗長,因為細節已經被很好地抽象掉了,FTW!

為什么在 Pytorch 中我們不需要訪問共享內存塊?

tl;dr:我們確實需要訪問它。 但是圖書館承擔了擔心細節的負擔,所以我們不必擔心。

pytorch 有一個圍繞共享 memory 的簡單包裝器,python 的共享 memory 模塊只是對底層操作系統相關函數的包裝器。

可以做到的方法是您不序列化數組或共享的 memory 本身,而僅使用文檔中的__getstate____setstate__方法序列化創建它們所需的內容,以便您的 object 充當代理和同時一個容器。

以下bar class 可以通過這種方式兼作代理和容器,如果用戶不必擔心共享的 memory 部分,這將很有用。

import time
import multiprocessing as mp
from multiprocessing import shared_memory
import numpy as np

class bar:
    def __init__(self):
        self._size = 10
        self._type = np.uint8
        self.shm = shared_memory.SharedMemory(create=True, size=self._size)
        self._mem_name = self.shm.name
        self.arr = np.ndarray([self._size], self._type, buffer=self.shm.buf)

    def __getstate__(self):
        """Return state values to be pickled."""
        return (self._mem_name, self._size, self._type)

    def __setstate__(self, state):
        """Restore state from the unpickled state values."""
        self._mem_name, self._size, self._type = state
        self.shm = shared_memory.SharedMemory(self._mem_name)
        self.arr = np.ndarray([self._size], self._type, buffer=self.shm.buf)

def get_time(s):
    return round(time.time() - s, 1)

def foo(shm, lock):
    # -------------------------------------------------------------------
    # without explicitely access the shared memory block we observe
    # that the array has changed:
    # -------------------------------------------------------------------
    a = shm

    # wait ~1sec to print the value.
    time.sleep(1.0)
    with lock:
        print(f"{__name__}\t{get_time(s)}\t\t{a.arr}")

# global variables.
s = time.time()

if __name__ == '__main__':
    lock = mp.Lock()  # to work on windows/mac.

    print("Module\t\tTime\t\tValue")
    print("-" * 35)

    # create numpy array and shared memory block.
    a = bar()
    print(f"{__name__}\t{get_time(s)}\t\t{a.arr}")

    # start child process.
    p0 = mp.Process(target=foo, args=(a, lock))
    p0.start()

    # modify the value of the vaue after ~0.5sec.
    time.sleep(0.5)
    with lock:
        a.arr[0] = 1.0

    print(f"{__name__}\t{get_time(s)}\t\t{a.arr}")
    time.sleep(1.5)

    p0.join()

python 只是讓在 class 中隱藏這些細節變得更加容易,而不會用這些細節打擾用戶。

編輯:我希望他們能讓鎖不可繼承,這樣你的代碼就可以在鎖上引發錯誤,相反你有一天會發現它實際上並沒有鎖......在它使你的應用程序在生產中崩潰之后。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM