I can't figure out why the while loop in this MPI code doesn't break

I'm doing a parallelization exercise using mpi4py in which 2 dice are thrown a defined number of times (divided among the processes, i.e., npp) and the sums of the dots are counted. The results are stored in a dictionary, the mean deviation is calculated, and the number of throws is doubled until mean_dev drops below 0.001.

All of this works as expected; the problem is that the program doesn't exit. The condition is met and there is no more output, but the code hangs.

from ctypes.wintypes import SIZE
from dice import * #This is just a class that creates the dictionaries 
from random import randint
import matplotlib.pyplot as plt
import numpy as np
from mpi4py import MPI
from math import sqrt

def simulation(f_events, f_sides, f_n_dice):
    f_X = dice(f_sides, f_n_dice).myDice() #Nested dictionary, one per die (last dict stores the sum); uses the function's own parameters
    for j in range(f_events): #for loop to handle all the dice throwings aka events
        n = [] #List to store index respective to number on each dice
        for i in range(1, f_n_dice+1): #for cycle for each dice
            k = randint(1, f_sides) #Random number
            n.append(k)
            f_X[i][k] += 1 #The index (k) related to each throw is increased for the dice (i)
        sum_throw = sum(n) #Sum of the last throw
        f_X[f_n_dice+1][sum_throw] += 1 #Sum dictionary "increases" the index respective to the sum of the last throw
    return f_X

npp = int(4)//4 #Number of events divided by the number of processes
sides = 6 #Number of sides per dice
n_dice = 2 #Number of dices

comm = MPI.COMM_WORLD #Default communicator containing all processes
rank = comm.Get_rank() #Rank (ID) of this process within the communicator
size = comm.Get_size() #Number of processes

#-------------------- Parallelization portion of the code --------------------#

seq = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
AUX = dict.fromkeys(seq, 0)
mean_dev = 1
while True:
    msg = comm.bcast(npp, root = 0)
    print("---> msg: ", msg, " for rank ", rank)
    print("The mean dev for %d" %rank + " is: ", mean_dev)

    D = simulation(npp, sides, n_dice)
    
    Dp = comm.gather(D, root = 0)
    print("This is Dp: ", Dp)
    
    summ = 0
    prob = [1/36, 2/36, 3/36, 4/36, 5/36, 6/36, 5/36, 4/36, 3/36, 2/36, 1/36]

    if rank==0:
        for p in range(0, size): 
                for n in range(dice().min, dice().max+1): #Range from minimum sum possible to the maximum sum possible depending on the number of dices used
                    AUX[n] += Dp[p][n_dice+1][n] #Adds the new data to the final sum dictionary 
                                                                #of the previously initiated nested dictionary
                print(Dp[p][n_dice+1])
    
        print("The final dictionary is: ", AUX, sum(AUX[j] for j in AUX))

        for i in range(dice().min, dice().max+1):
            exp = (prob[i-2])*(sum(AUX[j] for j in AUX))
            x = (AUX[i]-exp)/exp
            summ = summ + pow(x, 2)

        mean_dev = (1/11)*sqrt(summ)
        print("The deviation for {} is {}.".format(sum(AUX[j] for j in AUX), mean_dev))

    if mean_dev > 0.001:
        npp = 2*npp
        # new_msg = comm.bcast(npp, root = 0)
        # print("---> new_msg: ", new_msg, " for rank ", rank)
    else:
        break
        

I'm stumped on this one. Thanks in advance for any input!


The new code with the solution proposed by @victor-eijkhout:

from ctypes.wintypes import SIZE
from dice import *
from random import randint
import matplotlib.pyplot as plt
import numpy as np
from mpi4py import MPI
from math import sqrt

def simulation(f_events, f_sides, f_n_dice):
    f_X = dice(f_sides, f_n_dice).myDice() #Nested dictionary, one per die (last dict stores the sum); uses the function's own parameters
    for j in range(f_events): #for loop to handle all the dice throwings aka events
        n = [] #List to store index respective to number on each dice
        for i in range(1, f_n_dice+1): #for cycle for each dice
            k = randint(1, f_sides) #Random number
            n.append(k)
            f_X[i][k] += 1 #The index (k) related to each throw is increased for the dice (i)
        sum_throw = sum(n) #Sum of the last throw
        f_X[f_n_dice+1][sum_throw] += 1 #Sum dictionary "increases" the index respective to the sum of the last throw
    return f_X

npp = int(4)//4 #Number of events divided by the number of processes
sides = 6 #Number of sides per dice
n_dice = 2 #Number of dices

comm = MPI.COMM_WORLD #Default communicator containing all processes
rank = comm.Get_rank() #Rank (ID) of this process within the communicator
size = comm.Get_size() #Number of processes

#-------------------- Parallelization portion of the code --------------------#

seq = (2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
AUX = dict.fromkeys(seq, 0)
mean_dev = 1
while True:
    msg = comm.bcast(npp, root = 0)
    #print("---> msg: ", msg, " for rank ", rank)
    
    D = simulation(npp, sides, n_dice)
        
    Dp = comm.gather(D, root = 0)
    #if Dp != None: print("This is Dp: ", Dp)

    
    #print("The mean dev for %d" %rank + " is: ", mean_dev)

    if rank==0:
        
        summ = 0
        prob = [1/36, 2/36, 3/36, 4/36, 5/36, 6/36, 5/36, 4/36, 3/36, 2/36, 1/36]

        for p in range(0, size): 
                for n in range(dice().min, dice().max+1): #Range from minimum sum possible to the maximum sum possible depending on the number of dices used
                    AUX[n] += Dp[p][n_dice+1][n] #Adds the new data to the final sum dictionary 
                                                                #of the previously initiated nested dictionary
                print(Dp[p][n_dice+1])
    
        print("The final dictionary is: ", AUX, sum(AUX[j] for j in AUX))

        for i in range(dice().min, dice().max+1):
            exp = (prob[i-2])*(sum(AUX[j] for j in AUX))
            x = (AUX[i]-exp)/exp
            summ = summ + pow(x, 2)

        mean_dev = (1/11)*sqrt(summ)
        print("The deviation for {} is {}.".format(sum(AUX[j] for j in AUX), mean_dev))

    #new_mean_dev = comm.gather(mean_dev, root = 0)
    new_mean_dev = comm.bcast(mean_dev, root = 0)
    print("---> msg2: ", new_mean_dev, " for rank ", rank)

    if new_mean_dev < 0.001:
        break
        # new_msg = comm.bcast(npp, root = 0)
        # print("---> new_msg: ", new_msg, " for rank ", rank)
        
    else:
        npp = 2*npp
        print("The new npp is: ", npp)

You are computing the mean deviation only on process zero, so only that process will exit. The other processes never receive the value, so they never quit. You should broadcast the value after you compute it.
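A minimal sketch of that broadcast pattern, under some assumptions: the helper names (throw_sums, mean_deviation, run), the looser tolerance, the starting npp of 64, and the round cap are illustrative and not from the original post, and the dice class is replaced by a plain dictionary of sums. The small _SerialComm class is a hypothetical stand-in so the sketch also runs on one process when mpi4py is not installed.

```python
import random
from math import sqrt

try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
except ImportError:
    class _SerialComm:
        """Hypothetical one-process stand-in so the sketch runs without MPI."""
        def Get_rank(self):
            return 0
        def bcast(self, obj, root=0):
            return obj
        def gather(self, obj, root=0):
            return [obj]
    comm = _SerialComm()

SUMS = range(2, 13)
# Probability of each sum of two fair six-sided dice: 1/36, 2/36, ..., 6/36, ..., 1/36
PROB = {s: (6 - abs(s - 7)) / 36 for s in SUMS}

def throw_sums(n_throws, rng):
    """Throw two dice n_throws times and count how often each sum occurs."""
    counts = dict.fromkeys(SUMS, 0)
    for _ in range(n_throws):
        counts[rng.randint(1, 6) + rng.randint(1, 6)] += 1
    return counts

def mean_deviation(totals):
    """Same formula as in the question: (1/11) * sqrt(sum of squared relative errors)."""
    n = sum(totals.values())
    summ = sum(((totals[s] - PROB[s] * n) / (PROB[s] * n)) ** 2 for s in SUMS)
    return sqrt(summ) / len(PROB)

def run(comm, tol=0.02, max_rounds=12, seed=0):
    rank = comm.Get_rank()
    rng = random.Random(seed + rank)      # a different stream per rank
    totals = dict.fromkeys(SUMS, 0)       # only meaningful on rank 0
    npp = 64                              # throws per process in this round
    mean_dev = None
    for _ in range(max_rounds):
        gathered = comm.gather(throw_sums(npp, rng), root=0)
        if rank == 0:
            for part in gathered:
                for s, c in part.items():
                    totals[s] += c
            mean_dev = mean_deviation(totals)
        # Every rank receives rank 0's value, so every rank takes the same branch.
        mean_dev = comm.bcast(mean_dev, root=0)
        if mean_dev < tol:
            break
        npp *= 2
    return mean_dev, npp

if __name__ == "__main__":
    print(comm.Get_rank(), run(comm))
```

Because the stopping test uses the broadcast value rather than a rank-local one, no rank is left waiting in a collective call that the others have already abandoned.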

You are breaking out of your if statement. Just replace while True: with while mean_dev > 0.001: and you should be good. You can also just do the assignment at the end rather than wrapping it in the if.

If that doesn't work, it simply means mean_dev is always greater than 0.001. You calculate mean_dev as (1/11)*sqrt(sum …). Without understanding the whole algorithm: if the minimum sum of 2 dice is 2, then mean_dev will not drop below 0.14 or so. Try putting in a print statement to print mean_dev each time through the loop and see whether it behaves as expected. Should you be dividing mean_dev by npp each time, or something like that?
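To make the formula concrete, here is a small sketch that recomputes mean_dev outside MPI, assuming the same probability table as the question's code. The function name and the two sample count lists are illustrative and not from the original post.

```python
from math import sqrt

# Probabilities of the sums 2..12 for two fair six-sided dice (as in the question)
PROB = [1/36, 2/36, 3/36, 4/36, 5/36, 6/36, 5/36, 4/36, 3/36, 2/36, 1/36]

def mean_dev(counts):
    """counts[i] is the number of throws whose sum was i + 2."""
    total = sum(counts)
    summ = 0.0
    for observed, p in zip(counts, PROB):
        expected = p * total
        summ += ((observed - expected) / expected) ** 2
    return sqrt(summ) / len(PROB)

# 36 throws distributed exactly as the probabilities predict
exact = [1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1]
# Same 36 throws, but one throw moved from sum 12 to sum 2
skewed = [2, 2, 3, 4, 5, 6, 5, 4, 3, 2, 0]

print(mean_dev(exact))   # 0 up to floating-point rounding
print(mean_dev(skewed))  # sqrt(2)/11, about 0.1286
```

Printing this value once per loop iteration, as suggested above, shows directly whether it is heading toward the threshold.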

As a general rule, problems like this, where one iterates toward a closer approximation, terminate when the change in the estimate becomes very small. Should you be stopping when the change in mean_dev is less than 0.001? You would need something like abs(last_mean_dev-mean_dev)<0.001.
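A minimal sketch of that stopping rule, written as a plain helper over a sequence of successive estimates. The name first_stable and the sample values are hypothetical, chosen only to show the comparison with the previous estimate.

```python
def first_stable(estimates, tol=0.001):
    """Return (step, value) for the first estimate that differs from the
    previous one by less than tol, or None if that never happens."""
    last = None
    for step, value in enumerate(estimates):
        if last is not None and abs(last - value) < tol:
            return step, value
        last = value
    return None

# The change drops below 0.001 between the third and fourth estimates.
print(first_stable([1.0, 0.5, 0.3, 0.2995, 0.1]))  # (3, 0.2995)
```

In the MPI program, this comparison would presumably be made on rank 0 and its result broadcast, so that every process stops in the same iteration.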
