简体   繁体   English

数组的副本在 function 中被覆盖

[英]copy of array gets overwritten in function

I am trying to create an array np.zeros((3, 3)) outside a function and use it inside a function over and over again.我正在尝试在 function 之外创建一个数组np.zeros((3, 3))并一遍又一遍地在 function 中使用它。 The reason for that is numba's cuda implementation, which does not support array creation inside functions that are to be run on a gpu .原因是numba's cuda实现,它不支持在gpu上运行的函数内部array creation So I create aforementioned array ar_ref , and pass it as argument to function .所以我创建了上述数组ar_ref ,并将其作为参数传递给function ar creates a copy of ar_ref (this is supposed to be used as "fresh" np.zeros((3, 3)) copy). ar创建ar_ref的副本(这应该用作“新鲜” np.zeros((3, 3))副本)。 Then I perform some changes to ar and return it.然后我对ar进行一些更改并返回它。 But in the process ar_ref gets overwritten inside the function by the last iteration of ar .但是在这个过程中ar_ref在 function 的最后一次迭代中被ar覆盖。 How do I start every new iteration of the function with ar = np.zeros((3, 3)) without having to call np.zeros inside the function ?如何使用ar = np.zeros((3, 3))开始 function 的每次新迭代,而不必在np.zeros内调用function

import numpy as np

def function(ar_ref=None):
    for n in range(3):
        print(n)
        ar = ar_ref
        print(ar)
        for i in range(3):
            ar[i] = 1
        print(ar)
    return ar

ar_ref = np.zeros((3, 3))
function(ar_ref=ar_ref)

Output: Output:

0
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
1
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
2
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]

simple assignment will only assign pointer, so when you change ar , ar_ref changes too.简单的分配只会分配指针,所以当你改变ar时, ar_ref改变。 try to use shallow copy for this issue尝试对这个问题使用浅拷贝

import numpy as np
import copy 

def function(ar_ref=None):
    for n in range(3):
        print(n)
        ar = copy.copy(ar_ref)
        print(ar)
        for i in range(3):
            ar[i] = 1
        print(ar)
    return ar

ar_ref = np.zeros((3, 3))
function(ar_ref=ar_ref)

output: output:

0
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
1
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]
2
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM