
How to create a NumPy array based on the values of another NumPy array?

I would like to create a NumPy array whose element values depend on the values of the elements in another NumPy array. Presently, I use a for-loop in a list comprehension to iterate through array a to get b. What is the NumPy way to achieve this?

Test script:

import numpy as np

def get_b( a ):
    b_dict = {  1:10., 2:20., 3:30. }
    return b_dict[ a ]

a = np.full( 10, 2 )
print( f'a = {a}' )
b = np.array( [get_b(i) for i in a] )
print( f'b = {b}' )

Output:

a = [2 2 2 2 2 2 2 2 2 2]
b = [20. 20. 20. 20. 20. 20. 20. 20. 20. 20.]

You can use np.vectorize to map dictionary values onto an array:

In [6]: b_dict = {  1:10., 2:20., 3:30. }

In [7]: a = np.full( 10, 2 )

In [8]: np.vectorize(b_dict.get)(a)
Out[8]: array([20., 20., 20., 20., 20., 20., 20., 20., 20., 20.])
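A caveat worth knowing with this approach: unless `otypes` is given, `np.vectorize` evaluates the function on the first element to decide the output dtype, so a dict that mixes int and float values (e.g. `3: 30` instead of `3: 30.`) can silently produce an integer array. `dict.get` also returns `None` for keys that are not in the dict. The sketch below pins the dtype explicitly; mapping missing keys to NaN is an assumed choice, not something from the answer.

```python
import numpy as np

b_dict = {1: 10., 2: 20., 3: 30}  # note: the value for key 3 is an int here

# Without otypes, np.vectorize calls the function on the first element to
# decide the output dtype, so an int value for the first key gives an int array.
inferred = np.vectorize(b_dict.get)(np.array([3, 1]))

# Passing otypes fixes the output dtype up front; the default in the lambda
# turns keys missing from the dict into NaN instead of None.
lookup = np.vectorize(lambda k: b_dict.get(k, np.nan), otypes=[float])
explicit = lookup(np.array([3, 1, 5]))

print(inferred, inferred.dtype)
print(explicit, explicit.dtype)
```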

What about using map and np.fromiter?

def get_b( a ):
    b_dict = {  1:10., 2:20., 3:30. }
    return b_dict[ a ]

a = np.full( 10, 2 )
b = np.fromiter(map(get_b, a), dtype=np.float64)
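`np.fromiter` only produces 1-D arrays. A minimal sketch, assuming a 2-D key array purely for illustration, flattens the keys, passes `count` so NumPy can allocate the result once instead of resizing it on demand, and then restores the original shape:

```python
import numpy as np

B_DICT = {1: 10., 2: 20., 3: 30.}  # built once, outside the per-element function

def get_b(x):
    return B_DICT[x]

a = np.full((2, 5), 2)  # a 2-D key array, for illustration only

# Iterate over the flattened keys; count=a.size lets fromiter preallocate.
b = np.fromiter(map(get_b, a.ravel()), dtype=np.float64, count=a.size).reshape(a.shape)
print(b)
```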

Edit 1: A small timing comparison:

%timeit np.array( [get_b(i) for i in a] )
5.58 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.fromiter(map(get_b, a), dtype=np.float64)
5.77 µs ± 177 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.vectorize(b_dict.get)(a)
12.9 µs ± 76.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Edit 2: It seems that example was too small. With a larger array:

a = np.full( 1000, 2 )

%timeit np.array( [get_b(i) for i in a] )
415 µs ± 9.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit np.fromiter(map(get_b, a), dtype=np.float64)
383 µs ± 2.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit np.vectorize(b_dict.get)(a)
68.6 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Another approach to the problem:

from operator import itemgetter
np.array(itemgetter(*a)(b_dict))
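One edge case with `itemgetter`: when `a` has exactly one element, `itemgetter(*a)` returns the bare value rather than a 1-tuple, so `np.array(...)` gives a 0-d array; with an empty `a`, the `itemgetter()` call raises `TypeError`. The helper below guards both cases and keeps the shape of `a`. The name `lookup_itemgetter` is hypothetical and the float64 result dtype is an assumption.

```python
import numpy as np
from operator import itemgetter

b_dict = {1: 10., 2: 20., 3: 30.}

def lookup_itemgetter(a, mapping):
    """Map each key in a through mapping, keeping a's shape (hypothetical helper)."""
    keys = a.ravel().tolist()  # plain Python ints, one per element
    if not keys:
        # itemgetter() with no arguments raises TypeError, so handle empty input.
        return np.empty(a.shape, dtype=np.float64)
    values = itemgetter(*keys)(mapping)
    if len(keys) == 1:
        # With a single key, itemgetter returns the bare value, not a 1-tuple.
        values = (values,)
    return np.array(values, dtype=np.float64).reshape(a.shape)

print(lookup_itemgetter(np.full(10, 2), b_dict))
print(lookup_itemgetter(np.array([3]), b_dict))
```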

Output:

[20., 20., 20., 20., 20., 20., 20., 20., 20., 20.]

Comparison:

#@kmundnic solution
def m1(a):
  def get_b(x):
    b_dict = {  1:10., 2:20., 3:30. }
    return b_dict[x]
  return np.fromiter(map(get_b, a), dtype=np.float64)  # np.float was removed in NumPy 1.24

#@bigbounty solution
def m2(a):
  b_dict = {  1:10., 2:20., 3:30. }
  return np.vectorize(b_dict.get)(a)

#@Ehsan solution
def m3(a):
  b_dict = {  1:10., 2:20., 3:30. }
  return np.array(itemgetter(*a)(b_dict))

#@Sun Bear solution
def m4(a):
  def get_b( a ):
    b_dict = {  1:10., 2:20., 3:30. }
    return b_dict[ a ]
  return np.array( [get_b(i) for i in a] )

in_ = [np.full( n, 2 ) for n in [10,100,1000,10000]]

For a small dictionary, m2 seems fastest for large inputs and m3 for smaller ones.

[Figure: benchmark of m1 to m4 for the small dictionary (image not available)]

And for a larger dictionary:

b_dict = dict(zip(np.arange(100),np.arange(100)))
in_ = [np.full(n,50) for n in [10,100,1000,10000]]

m3 is the fastest approach here. You can choose between them based on your dictionary size and key-array size.

[Figure: benchmark of m1 to m4 for the larger dictionary (image not available)]

I'd like to stress the value of @hpaulj's comment on my question:

Does b_dict have to be a dict? If you had an array, e.g. ref = np.array([0, 10, 20, 30]), you could quickly select the values by index, ref[a]. I would try to avoid dicts when working with numpy.
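The comment's idea can be applied to an existing dict by converting it into a lookup array once, up front. This sketch assumes the keys are small non-negative integers (negative keys would silently wrap around under NumPy indexing, and very large keys would make the table huge); filling positions that have no key with NaN is an assumed choice.

```python
import numpy as np

b_dict = {1: 10., 2: 20., 3: 30.}

# Build a lookup table once: position k holds the value for key k.
# keys() and values() iterate in the same order, so they stay aligned.
keys = np.fromiter(b_dict.keys(), dtype=np.intp)
vals = np.fromiter(b_dict.values(), dtype=np.float64)
lut = np.full(keys.max() + 1, np.nan)  # unused slots (here index 0) stay NaN
lut[keys] = vals

a = np.full(10, 2)
b = lut[a]  # integer ("fancy") indexing: one vectorized gather, no Python loop
print(b)
```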

I found that NumPy indexing can be one to several orders of magnitude faster than working with a Python dict (from about 3x to over 18,000x in the results below, depending on len(a) and len(b)).
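When the keys are sparse or large, a dense lookup table stops being practical. One alternative that stays vectorized, not taken from the answers in this thread, is a binary search over the sorted keys with `np.searchsorted`. In this sketch the example dict is made up, and keys that are absent are assumed to map to NaN.

```python
import numpy as np

# Keys that are sparse or too large for a dense lookup table (made-up example).
b_dict = {5: 50., 1000: 1.0e4, 70000: 7.0e5}

keys = np.array(sorted(b_dict))             # sorted unique keys
vals = np.array([b_dict[k] for k in keys])  # values aligned with keys

def lookup_sorted(a):
    """Vectorized dict lookup via binary search over the sorted keys."""
    idx = np.searchsorted(keys, a)
    idx = np.clip(idx, 0, len(keys) - 1)    # keep indices in range
    found = keys[idx] == a                  # False where a key is absent
    return np.where(found, vals[idx], np.nan)

a = np.array([1000, 5, 70000, 6])
print(lookup_sorted(a))
```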

Building on @Ehsan's solution, below is a script that makes such a comparison.

import numpy as np
from operator import itemgetter
import timeit
import matplotlib.pyplot as plt


#@kmundnic solution
def m1(a):
    def get_b(x):
        b = {  1:10., 2:20., 3:30. }
        #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
        return b[x]
    return np.fromiter(map(get_b, a), dtype=np.float64)  # np.float was removed in NumPy 1.24

#@bigbounty solution
def m2(a):
    b = {  1:10., 2:20., 3:30. }
    #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
    return np.vectorize(b.get)(a)

#@Ehsan solution
def m3(a):
    b = {  1:10., 2:20., 3:30. }
    #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
    return np.array(itemgetter(*a)(b))

#@Sun Bear solution
def m4(a):
    def get_b( a ):
        b = {  1:10., 2:20., 3:30. }
        #b = dict( zip( np.arange(1,101),np.arange(10,1001,10) ) )
        return b[ a ]
    return np.array( [get_b(i) for i in a] )

#@hpaulj solution
def m5(a):
    # Index 0 is padding so that b[k] matches the dict value for key k.
    b = np.array([0., 10., 20., 30.])
    #b = np.arange(0., 1001., 10.)
    return b[a]

        
sizes=[10,100,1000,10000]
pm1 = []
pm2 = []
pm3 = []
pm4 = []
pm5 = []
for size in sizes:
    a = np.full( size, 2 )
    pm1.append( timeit.timeit( 'm1(a)', number=1000, globals=globals() ) )
    pm2.append( timeit.timeit( 'm2(a)', number=1000, globals=globals() ) )
    pm3.append( timeit.timeit( 'm3(a)', number=1000, globals=globals() ) )
    pm4.append( timeit.timeit( 'm4(a)', number=1000, globals=globals() ) )
    pm5.append( timeit.timeit( 'm5(a)', number=1000, globals=globals() ) )

print( 'm1 slower than m5 by :',np.array(pm1) / np.array(pm5) )
print( 'm2 slower than m5 by :',np.array(pm2) / np.array(pm5) )
print( 'm3 slower than m5 by :',np.array(pm3) / np.array(pm5) )
print( 'm4 slower than m5 by :',np.array(pm4) / np.array(pm5) )

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot( sizes, pm1, label='m1' )
ax.plot( sizes, pm2, label='m2' )
ax.plot( sizes, pm3, label='m3' )
ax.plot( sizes, pm4, label='m4' )
ax.plot( sizes, pm5, label='m5' )
ax.grid( which='both' )
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
ax.get_xaxis().set_label_text( label='len(a)', fontweight='bold' )
ax.get_yaxis().set_label_text( label='Runtime (sec)', fontweight='bold' )
plt.show()

Results:

len(b) = 3:

m1 slower than m5 by : [  4.22462367  29.79407905  85.03454097 339.2915358 ]
m2 slower than m5 by : [  8.64220685 11.57175871 13.76761749 46.1940683 ]
m3 slower than m5 by : [  3.25785432  21.63131578  54.71305704 220.15777696 ]
m4 slower than m5 by : [  4.60710166  30.93616607  91.8936744  371.00398273 ]

len(b) = 100:

m1 slower than m5 by : [  218.98603678  1976.50128737  9697.76615006 17742.79151719 ]
m2 slower than m5 by : [  41.76535891  53.85600913 109.35129345 164.13075291 ]
m3 slower than m5 by : [  24.82715462  36.77830986  87.56253196 141.04493237 ]
m4 slower than m5 by : [  222.04184193  2001.72120836  9775.22464369 18431.00155305 ]
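Part of the gap for m1 and m4 likely comes from the benchmark itself: their inner `get_b` builds the dict on every call, that is, once per element of `a`, which gets expensive with the 100-entry dict. A sketch of a fairer comparison, with the dict and the lookup table each built once, follows. The names `m4_hoisted` and `m5_hoisted` are hypothetical, and since timings vary by machine, none are claimed here.

```python
import timeit
import numpy as np

# Build the 100-entry dict and the equivalent lookup table once, outside the
# timed functions (the script above rebuilds them inside every call).
B_DICT = dict(zip(range(1, 101), np.arange(10., 1001., 10.)))
LUT = np.arange(0., 1001., 10.)  # LUT[k] == B_DICT[k] for k in 1..100

def m4_hoisted(a):
    """List comprehension over the keys, dict built once (hypothetical variant of m4)."""
    return np.array([B_DICT[i] for i in a])

def m5_hoisted(a):
    """Direct NumPy indexing into the prebuilt lookup table."""
    return LUT[a]

a = np.full(1000, 50)
assert np.array_equal(m4_hoisted(a), m5_hoisted(a))

for f in (m4_hoisted, m5_hoisted):
    t = timeit.timeit(lambda: f(a), number=1000)
    print(f"{f.__name__}: {t:.4f} s for 1000 calls")
```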

[Figure: comparison plot (image not available)]
