簡體   English   中英

使用帶有 numba njit 函數的字典

[英]Using Dictionaries with numba njit function

當輸入和返回是字典時,如何使用 numba 加速功能?

我熟悉將 numba 用於接受數字和返回數組的函數,如下所示:

@numba.jit('float64[:](int32,int32)',nopython=True)
def f(a, b):
    # returns array 1d array

現在我有一個接受並返回字典的函數。 我如何在這里申請 numba?

    def collocation(aeolus_data,val_data):

      ...

      return sample_aeolus, sample_valdata

現在 Numba 版本43.0添加了對 Dictionary 的支持。 雖然它非常有限(不支持列表和設置為鍵/值)。 但是,您可以在此處閱讀更新的文檔以獲取更多信息 這是一個例子

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# First create a dictionary using Dict.empty()
# Specify the data types for both key and value pairs

# Dict with key as strings and values of type float array
dict_param1 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64[:],
)

# Dict with keys as string and values of type float
dict_param2 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)

# Type-expressions are currently not supported inside jit functions.
float_array = types.float64[:]

@njit
def add_values(d_param1, d_param2):
    # Make a result dictionary to store results
    # Dict with keys as string and values of type float array
    result_dict = Dict.empty(
        key_type=types.unicode_type,
        value_type=float_array,
    )

    for key in d_param1.keys():
      result_dict[key] = d_param1[key] + d_param2[key]

    return result_dict

dict_param1["hello"]  = np.asarray([1.5, 2.5, 3.5], dtype='f8')
dict_param1["world"]  = np.asarray([10.5, 20.5, 30.5], dtype='f8')

dict_param2["hello"]  = 1.5
dict_param2["world"]  = 10

final_dict = add_values(dict_param1, dict_param2)

print(final_dict)
# Output : {hello: [3. 4. 5.], world: [20.5 30.5 40.5]}

鏈接到 Google colab 筆記本

參考:
- https://github.com/numba/numba/issues/3644
- https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#dict

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM