I am running a program in PyCharm on a Linux server. It uses multiprocessing.Pool().map for increased performance.
The code looks something like this:
import multiprocessing
from functools import partial

for episode in episodes:
    with multiprocessing.Pool() as mpool:
        func_part = partial(worker_function)
        mpool.map(func_part, range(step))
The weird thing is that it runs perfectly fine on my Windows 10 laptop, but as soon as I try to run it on a Linux server the program gets stuck at the very last "Process measurement count 241/242" line, i.e. right before proceeding to the next iteration of the loop (the next episode).
No error message is given. I am running PyCharm on both machines. The Step layer is where I placed the multiprocessing.Pool().map call.
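One difference between the two machines that may be worth ruling out: Windows only supports the spawn start method, whereas Linux has traditionally defaulted to fork, so the worker processes begin in a different state on each system. Whether this is related to the hang is not established here. The sketch below only reports what each machine uses; the helper name describe_start_methods is made up for illustration.

```python
import multiprocessing
import sys


def describe_start_methods():
    # Collect the platform name plus the default and available start methods.
    return {
        "platform": sys.platform,
        "default": multiprocessing.get_start_method(),
        "available": multiprocessing.get_all_start_methods(),
    }


if __name__ == "__main__":
    info = describe_start_methods()
    print(info)
    # Assumption for illustration: to mimic the Windows behaviour on Linux,
    # a pool can be created from an explicit "spawn" context instead, e.g.
    #   ctx = multiprocessing.get_context("spawn")
    #   with ctx.Pool() as mpool: ...
```

Running this on both machines and comparing the "default" entry shows whether the two runs actually start their workers the same way.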
Edit:
I've added mpool.close() and mpool.join(), but it seems to have no effect:
import multiprocessing
from functools import partial

for episode in episodes:
    with multiprocessing.Pool() as mpool:
        func_part = partial(worker_function)
        mpool.map(func_part, range(step))
        mpool.close()
        mpool.join()
It still gets stuck at the last process.
Edit2:
This is the worker function:
def worker_func(steplength, episode, episodes, env, agent, state, log_data_qvalues, log_data, steps):
    env.time_ = step
    action = agent.act(state, env)  # given the state, the agent acts (eps-greedy) either by choosing randomly or relying on its own prediction (weights are considered here to sum up the q-values of all objectives)
    next_state, reward = env.steplength(action, state)  # given the action, the environment gives back the next_state and the reward for the transaction for all objectives separately
    agent.remember(state, action, reward, next_state, env.future_reward)  # agent puts the experience in its memory
    # This part is not necessary for the framework, but lets the agent predict every time_
    # to store the development of the prediction and to investigate the development of the Q-values
    q_values = agent.model.predict(np.reshape(state, [1, env.state_size]))
    start = 0
    machine_start = 0
    for counter, machine in enumerate(env.list_of_machines):
        liste = [episode, steplength, state[counter]]
        q_values_objectives = []
        for objective in range(1, env.number_of_objectives + 1):
            liste.append(objective)
            liste.append(q_values[0][start:machine.actions + start])
            start = int(agent.action_size / env.number_of_objectives) + start
        log_data_qvalues.append(liste)
        machine_start += machine.actions
        start = machine_start
    state = next_state
    steps.append(state)
    env.current_step += 1
    # If the agent has collected more than batch_size experiences, the networks of the agents start
    # to be trained: with the replay function, batch_size samples are selected from the agent's memory
    if len(agent.memory) > agent.batch_size:
        agent.replay(env)
        agent.update_target_model()  # the Q-target is updated after one batch training
    if steplength == env.steplength-2:  # for plotting the process during training
        #agent.update_target_model()
        print(f'Episode: {episode + 1}/{episodes} Score: {steplength} e: {agent.epsilon:.5}')
        log_data.append([episode, agent.epsilon])
As you can see, it uses several classes to pass attributes, so I don't know how I would make it reproducible. I am still experimenting to find out where exactly the process gets stuck. The worker function communicates with the env and agent classes and passes the information required to train a neural network. The agent class controls the learning process, while the env class simulates the environment the network has control over.
step is an integer variable:
step = 12
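Note that the loop above calls partial(worker_function) with nothing bound, while worker_func takes nine parameters and map supplies only one value per call. That may just be a simplification in the snippet. As a minimal sketch of how the remaining arguments could be bound up front, here is a trimmed-down, hypothetical stand-in for the worker (its three-parameter signature and the values 0 and 5 are assumptions for illustration only):

```python
from functools import partial


def worker_func(steplength, episode, episodes):
    # Hypothetical, trimmed-down stand-in for the real worker_func.
    return f"Episode: {episode + 1}/{episodes} Score: {steplength}"


# Bind everything except the first positional parameter up front.
func_part = partial(worker_func, episode=0, episodes=5)

# The built-in map is used here only to keep the sketch free of processes;
# Pool.map calls func_part the same way, with one item per call.
results = list(map(func_part, range(3)))
print(results)
```

With Pool.map, every bound argument also has to be picklable, since it is sent to the worker processes.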
Are you calling mpool.close() and mpool.join() at the end?
EDIT
The problem is not with multiprocessing but with the "measurement count" part. According to the screenshot, the pool map finishes successfully with step 11 (range(12) starts at 0, so 11 is the last step). "measurement count" appears nowhere in the provided snippets, so that part cannot be debugged from here.
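One way to check this on the Linux server is to put a timeout on the map call and print something right after the pool block. The following is only a sketch: worker_function here is a hypothetical stand-in that echoes its index, run_steps is a made-up helper name, and the 60 second timeout is an arbitrary choice.

```python
import multiprocessing


def worker_function(step_index):
    # Hypothetical stand-in for the real worker: just echo the index it received.
    return step_index


def run_steps(pool, step, timeout=60):
    # map_async returns immediately; get(timeout=...) raises
    # multiprocessing.TimeoutError instead of blocking forever.
    async_result = pool.map_async(worker_function, range(step))
    return async_result.get(timeout=timeout)


if __name__ == "__main__":
    step = 12
    with multiprocessing.Pool() as mpool:
        try:
            results = run_steps(mpool, step)
            print(f"map finished, last step index: {results[-1]}")
        except multiprocessing.TimeoutError:
            print("map did not finish in time, so the hang is inside the pool")
    print("after the pool block; anything stuck from here on is outside multiprocessing")
```

If the last line is printed and the program still hangs, the place to look is the code that runs after the pool, such as whatever produces the "measurement count" output.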