
How to pass a slice into a function by reference

If I have

a = [1, 2, 3]

def foo(arr):
    for i in range(len(arr)):
        arr[i] += 1

def bar(arr):
    foo(arr[:2])

bar(a)
print(a)

I want the output to be:

>>> [2, 3, 3]

How do I go about this?

Motivation: I want a priority queue in which I can freeze the last N elements, i.e. pass only queue[:N] to heapq.heappush(). But every time I pass a slice to it, or to any function in general, the function receives a copy of the slice rather than the original list, so my list is unchanged at the end.

Use a list comprehension and update the original list in place with a full slice assignment, arr[:] = ...:

def foo(arr):
    arr[:] = [x+1 for x in arr]

Trial:

>>> a = [1, 2, 3]
>>> def foo(arr):
...    arr[:] = [x+1 for x in arr]
...
>>> foo(a)
>>> a
[2, 3, 4]
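The same slice-assignment trick also works on just part of the list: read the slice, build the new values, and assign them back to the same slice of the original list object. A minimal sketch, where the extra parameter n (how many leading elements to update) is an assumption added for illustration:

```python
def foo(arr, n):
    # Assigning to arr[:n] replaces those elements inside the list object
    # that was passed in; elements from index n onward are left untouched.
    # (n is a hypothetical parameter added for this sketch.)
    arr[:n] = [x + 1 for x in arr[:n]]

a = [1, 2, 3]
foo(a, 2)
print(a)  # [2, 3, 3]
```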

Your original code works if a is a NumPy array instead of a list: basic NumPy slicing returns a view that shares memory with the original array, so changes made through the slice are visible in a.

import numpy as np
a = np.array([1, 2, 3])

def foo(arr):
    for i in range(len(arr)): arr[i] += 1
    # or just arr += 1

def bar(arr):
    foo(arr[:2])

bar(a)
print(a)
# [2 3 3]
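Plain lists have no view type, but the standard library offers similar behavior for homogeneous numeric data: slicing a memoryview over an array.array yields another view of the same buffer rather than a copy. A sketch, assuming the data fits in a typed array of C ints (typecode 'i'); note that the array cannot be resized while a view of it is still alive.

```python
from array import array

a = array('i', [1, 2, 3])

def foo(arr):
    # arr is a memoryview; item assignment writes through to the array's buffer
    for i in range(len(arr)):
        arr[i] += 1

def bar(arr):
    # Slicing a memoryview returns a new memoryview over the same memory
    foo(memoryview(arr)[:2])

bar(a)
print(a.tolist())  # [2, 3, 3]
```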

Slicing a list creates a new list containing the sliced elements, so arr[:2] has no connection to the original a.

As a result, your loop modifies items of that temporary copy, which is thrown away when foo returns; the original list is never changed.
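A quick check in the interpreter makes the copy visible: changes made to the sliced list never reach a.

```python
a = [1, 2, 3]
part = a[:2]        # slicing a list builds a new, independent list
print(part is a)    # False
part[0] += 1        # modifies only the copy
print(part, a)      # [2, 2] [1, 2, 3]
```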

If you want to alter specific parts of the list, you must keep a reference to the original list. For example, pass arr itself, iterate over a slice of it with enumerate(arr[:2]), and assign back into arr:

a = [1, 2, 3]

def foo(arr):
    for i, item in enumerate(arr[:2]): 
        arr[i] = item + 1

def bar(arr):
    foo(arr)

bar(a)
print(a)

Printing now yields [2, 3, 3]; iterating over enumerate(arr) without the slice would give [2, 3, 4]. Of course, bar serves no purpose here: you could drop it and call foo(a) directly.
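If you would rather keep the original foo unchanged and still hand it only part of the list, one option is a small wrapper that forwards indexing to the original list instead of copying it. ListSliceView below is a hypothetical helper written for this sketch, not a standard-library class, and it only supports non-negative indices and len():

```python
class ListSliceView:
    """Hypothetical helper: a mutable window onto data[start:stop].

    Indices are relative to the window, and writes go straight to the
    underlying list, so no copy is made. Assumes start >= 0.
    """

    def __init__(self, data, start, stop):
        self.data = data
        self.start = start
        self.stop = min(stop, len(data))

    def __len__(self):
        return max(0, self.stop - self.start)

    def _to_index(self, i):
        # Map a window index to an index in the underlying list
        if not 0 <= i < len(self):
            raise IndexError("ListSliceView index out of range")
        return self.start + i

    def __getitem__(self, i):
        return self.data[self._to_index(i)]

    def __setitem__(self, i, value):
        self.data[self._to_index(i)] = value


def foo(arr):
    for i in range(len(arr)):
        arr[i] += 1          # __getitem__, then __setitem__ on the view

def bar(arr):
    foo(ListSliceView(arr, 0, 2))

a = [1, 2, 3]
bar(a)
print(a)  # [2, 3, 3]
```

Because __getitem__ raises IndexError past the end, Python's fallback iteration protocol also lets you loop over the view with a plain for statement.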

To be honest, instead of slicing, I would just pass the indices:

a = [1, 2, 3]

def foo(array, start, stop, jmp=1):
    # stop is inclusive here, unlike Python slicing
    for idx in range(start, stop + 1, jmp):
        array[idx] += 1

def bar(array):
    foo(array, 1, 2)  # foo(array, 0, 1) would give [2, 3, 3]

bar(a)
print(a)
# [1, 3, 4]
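A variation on the index-passing idea, assuming you would rather keep Python's usual slice semantics (exclusive stop, negative indices counted from the end): pass a built-in slice object and let slice.indices() resolve it against the list's length. The function name increment_slice is hypothetical. This also covers leaving the last N elements alone, e.g. with slice(None, -N):

```python
def increment_slice(array, sl):
    # slice.indices(length) normalizes None and negative bounds exactly
    # as list slicing does, returning concrete (start, stop, step) values.
    for idx in range(*sl.indices(len(array))):
        array[idx] += 1

a = [1, 2, 3]
increment_slice(a, slice(None, 2))   # same positions as a[:2]
print(a)                             # [2, 3, 3]

b = [1, 2, 3, 4, 5]
increment_slice(b, slice(None, -2))  # everything except the last 2 elements
print(b)                             # [2, 3, 4, 4, 5]
```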
