I am working on a pipeline that obtains predictions from a machine learning model, and I'm trying to use Ray to speed it up. The inputs can be repetitive, so I'd like all the workers of a remote function to share access to a cache, search it, and reuse values. Something like the below:
@ray.remote
def f(x):
    # y1, y2, cached and predictor are defined elsewhere
    # create inputs from x
    # do work
    unknown_y1 = []
    obtained_y1 = []
    for index, y in enumerate(y1):
        key = '|'.join([str(v) for v in y.values()])
        if key in cached:
            obtained_y1.append(cached[key])
        else:
            obtained_y1.append(np.inf)  # np.inf marks a cache miss
            unknown_y1.append(y)
    unknown_y2 = []
    obtained_y2 = []
    for index, y in enumerate(y2):
        key = '|'.join([str(v) for v in y.values()])
        if key in cached:
            obtained_y2.append(cached[key])
        else:
            obtained_y2.append(np.inf)
            unknown_y2.append(y)
    known_y1, known_y2 = predictor.predict(unknown_y1, unknown_y2)
    unknown_index = 0
    for index in range(len(y1)):
        if obtained_y1[index] == np.inf:
            obtained_y1[index] = known_y1[unknown_index]
            key = '|'.join([str(v) for v in y1[index].values()])
            if key not in cached:
                cached[key] = obtained_y1[index]
            unknown_index += 1
    unknown_index = 0
    for index in range(len(y2)):
        if obtained_y2[index] == np.inf:
            obtained_y2[index] = known_y2[unknown_index]
            key = '|'.join([str(v) for v in y2[index].values()])
            if key not in cached:
                cached[key] = obtained_y2[index]
            unknown_index += 1
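As an aside, the key construction repeated four times above can be factored into a small helper. This is a sketch of my own (make_key is a name I'm introducing for illustration, not part of the original code):

```python
def make_key(d):
    # Build a cache key by joining the dict's values with '|'
    # (relies on dicts preserving insertion order, Python 3.7+).
    return '|'.join(str(v) for v in d.values())

key = make_key({'store': 12, 'sku': 'A1'})  # -> '12|A1'
```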
I've tried creating a global dictionary by adding global cached; cached = dict()
at the top of my script, but it seems that variable is a separate copy in each worker and the data is not shared. Previously I was doing this with dogpile.cache.redis,
but the region is not serializable since it uses a thread lock. I've also tried creating a dict and putting it in Ray's object store with ray.put(cached),
but I think I read somewhere that Ray cannot share dictionaries in memory.
I am currently returning the cache from each worker, merging them in the main process, and then putting the result in the object store again. Is there a better way of sharing a dictionary/cache between Ray workers?
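For context, the merge step I'm describing is just plain dict merging in the main process. A minimal sketch, assuming each worker returns its local cache as a dict (merge_caches is an illustrative name):

```python
def merge_caches(worker_caches):
    # Combine the per-worker caches; later workers win on key collisions.
    merged = {}
    for local in worker_caches:
        merged.update(local)
    return merged

worker_caches = [{'1|a': 0.5}, {'2|b': 0.7}, {'1|a': 0.9}]
merged = merge_caches(worker_caches)  # -> {'1|a': 0.9, '2|b': 0.7}
```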
Unfortunately, you did not provide a minimal, reproducible example, so I cannot see how you are doing your multiprocessing. For the sake of argument, I will assume you are using the Pool
class from the multiprocessing
module (concurrent.futures.ProcessPoolExecutor
offers a similar facility). Then you want to use a managed, sharable dict
as follows:
from multiprocessing import Pool, Manager

def init_pool(the_cache):
    # initialize each process in the pool with the following global variable:
    global cached
    cached = the_cache

def main():
    with Manager() as manager:
        cached = manager.dict()
        with Pool(initializer=init_pool, initargs=(cached,)) as pool:
            ...  # code that creates tasks

# required by Windows:
if __name__ == '__main__':
    main()
This stores in the variable cached
a reference to a proxy for the managed dictionary. All dictionary accesses therefore become essentially remote procedure calls and execute much more slowly than accesses to a "normal" dictionary. Just be aware...
If there is some other mechanism for creating workers (the decorator @ray.remote?),
the cached
variable can instead be passed as an argument to function f.
You may be interested in this question/answer about writing a function cache for Ray: Implementing cache for Ray actor function
You have the right idea, but the key detail you're missing is that with Ray, global state should be kept in an actor (or in the object store, if it's immutable).
In your case, it looks like you are trying to cache parts of your remote function, not the whole thing. Here's a simplified version of how you might write your function:
@ray.remote
class Cache:
    def __init__(self):
        self.cache = {}

    def put(self, x, y):
        self.cache[x] = y

    def get(self, x):
        return self.cache.get(x)

global_cache = Cache.remote()

@ray.remote
def f(x):
    all_inputs = list(range(x))  # A simplified set of generated inputs based on x
    obtained_output = ray.get([global_cache.get.remote(i) for i in all_inputs])
    unknown_inputs = []
    for i, output in enumerate(obtained_output):
        if output is None:
            unknown_inputs.append(i)
    # Now go through and calculate all the unknown inputs
    for i in unknown_inputs:
        output = predict(all_inputs[i])  # calculate the output
        global_cache.put.remote(all_inputs[i], output)  # Cache it so it's available next time
        obtained_output[i] = output
    return obtained_output
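If you don't have a Ray cluster handy, the same fill-in-the-misses flow can be traced with a plain dict standing in for the actor. This is only an illustrative sketch (f_local and the multiply-by-ten predict are names I'm introducing, not part of the answer's API):

```python
def f_local(x, cache, predict):
    all_inputs = list(range(x))
    obtained = [cache.get(i) for i in all_inputs]  # None marks a miss
    unknown = [i for i, out in enumerate(obtained) if out is None]
    for i in unknown:
        y = predict(all_inputs[i])
        cache[all_inputs[i]] = y  # fill the cache for the next call
        obtained[i] = y
    return obtained

cache = {}
first = f_local(4, cache, lambda v: v * 10)   # all four inputs are misses
second = f_local(4, cache, lambda v: v * 10)  # all hits on the second call
```

The second call returns the same outputs without invoking predict at all, which is exactly the behavior the actor-based version gives you across workers.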