繁体   English   中英

dict.get 的向量化等价物

[英]Vectorized equivalent of dict.get

我正在寻找像这样运行的功能

lookup_dict = {5:1.0, 12:2.0, 39:2.0...}
# this is the missing magic:
lookup = vectorized_dict(lookup_dict)

x = numpy.array([5.0, 59.39, 39.49...])

xbins = numpy.trunc(x).astype(numpy.int_)  
y = lookup.get(xbins, 0.0)

# the idea is that we get this as the postcondition:    
for (result, input) in zip(y, xbins):
     assert(result==lookup_dict.get(input, 0.0))

在 numpy(或 scipy)中是否有一些稀疏数组的味道可以获得这种功能?

完整的上下文是我正在合并一维特征的一些样本。

使用np.select在数组上创建布尔掩码( [xbins == k for k in lookup_dict] )、dict 中的值( lookup_dict.values() )和默认值0

y = np.select(
    [xbins == k for k in lookup_dict], 
    lookup_dict.values(), 
    0.0
)
# In [17]: y
# Out[17]: array([1., 0., 2.])

这假设字典已排序,我不确定 python 3.6 以下的行为会是什么。

或对熊猫过度杀伤:

import pandas as pd
s = pd.Series(xbins)
s = s.map(lookup_dict).fillna(0)

据我所知,numpy 不支持相同数组结构中的不同数据类型,但如果您愿意将键与值分开并按排序顺序维护键(和相应的值),则可以获得类似的结果:

import numpy as np

keys   = np.array([5,12,39])
values = np.array([1.0, 2.0, 2.0])

valueOf5 = values[keys.searchsorted(5)] # 2.0


k = np.array([5,5,12,39,12]) 

values[keys.searchsorted(k)] # array([1., 1., 2., 2., 2.])

这可能不如散列键有效,但它确实支持从具有任意维数的数组传播间接。

请注意,这假设您的键始终存在于键数组中。 如果没有,而不是错误,您可能会从下一个键中获取值。

另一种方法是使用 searchsorted 搜索具有整数“键”的 numpy 数组,并返回范围n <= x < n+1的初始加载值。 这可能对将来提出类似问题的人有用。

import numpy as np

class NpIntDict:
    """ Class to simulate a python dict get for a numpy array.  """
    def __init__( self, dict_in, default = np.nan ):
        """  dict_in: a dictionary with integer keys.
             default: the value to be returned for keys not in the dictionary.
                      defaults to np.nan
             default must be consistent with the dtype of values
        """
        # Create list of dict items sorted by key.
        list_in = sorted([ item for item in dict_in.items() ])
        # Create three empty lists.
        key_list = []   
        val_list = [] 
        is_def_mask = []
        for key, value in list_in:
            key = int(key)
            if not key in key_list:   # key not yet in key list
                # Update the three lists for key as default.
                key_list.append( key )      
                val_list.append( default )
                is_def_mask.append( True )
            # Update the lists for key+1.  With searchsorted this gives the required results.
            key_list.append( key + 1 )
            val_list.append( value )
            is_def_mask.append( False )
        # Add the key > max(key) to the val and is_def_mask lists.
        val_list.append( default )
        is_def_mask.append( True )
        self.keys = np.array( key_list, dtype = np.int )
        self.values = np.array( val_list )
        self.default_mask = np.array( is_def_mask )

    def set_default( self, default = 0 ):
        """  Set the default to a new default value.  Using self.default_mask.
             Changes the default value for all future self.get(arr).
        """
        self.values[ self.default_mask ] = default

    def get( self, arr, default = None ):
        """  Returns an array looking up the values in `arr` in the dict.
             default can be used to change the default value returned for this get only.
        """
        if default is None:
            values = self.values
        else:
            values= self.values.copy()
            values[ self.default_mask ] = default
        return values[ np.searchsorted( self.keys, arr, side = 'right' ) ]
        # side = 'right' to ensure key[ix] <= x < key[ix+1]
        # side = 'left' would mean key[ix] < x <= key[ix+1]

如果在创建 NpIntDict 后不需要更改返回的默认值,则可以简化此操作。

来测试一下。

d = { 2: 5.1, 3: 10.2, 5: 47.1, 8: -6}

# x <2 Return default
# 2 <= x <3 return 5.1
# 3 <= x < 4 return 10.2
# 4 <= x < 5 return default
# 5 <= x < 6 return 47.1
# 6 <= x < 8 return default
# 8 <= x < 9 return -6.
# 9 <= x return default

test = NpIntDict( d, default = 0.0 )
arr = np.arange( 0., 100. ).reshape(10,10)/10
print( arr )
"""
[[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
 [1.  1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9]
 [2.  2.1 2.2 2.3 2.4 2.5 2.6 2.7 2.8 2.9]
 [3.  3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8 3.9]
 [4.  4.1 4.2 4.3 4.4 4.5 4.6 4.7 4.8 4.9]
 [5.  5.1 5.2 5.3 5.4 5.5 5.6 5.7 5.8 5.9]
 [6.  6.1 6.2 6.3 6.4 6.5 6.6 6.7 6.8 6.9]
 [7.  7.1 7.2 7.3 7.4 7.5 7.6 7.7 7.8 7.9]
 [8.  8.1 8.2 8.3 8.4 8.5 8.6 8.7 8.8 8.9]
 [9.  9.1 9.2 9.3 9.4 9.5 9.6 9.7 9.8 9.9]]
"""

print( test.get( arr ) )
"""
[[ 0.   0.   0.   0.   0.   0.   0.   0.   0.   0. ]
 [ 0.   0.   0.   0.   0.   0.   0.   0.   0.   0. ]
 [ 5.1  5.1  5.1  5.1  5.1  5.1  5.1  5.1  5.1  5.1]
 [10.2 10.2 10.2 10.2 10.2 10.2 10.2 10.2 10.2 10.2]
 [ 0.   0.   0.   0.   0.   0.   0.   0.   0.   0. ]
 [47.1 47.1 47.1 47.1 47.1 47.1 47.1 47.1 47.1 47.1]
 [ 0.   0.   0.   0.   0.   0.   0.   0.   0.   0. ]
 [ 0.   0.   0.   0.   0.   0.   0.   0.   0.   0. ]
 [-6.  -6.  -6.  -6.  -6.  -6.  -6.  -6.  -6.  -6. ]
 [ 0.   0.   0.   0.   0.   0.   0.   0.   0.   0. ]]
"""

如果任何 arr 元素不在键列表中,则可以对其进行修改以引发异常。 对我来说,返回默认值会更有用。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM