
How to prepare, populate and explore an n-dimensional numpy array?

I have modeled a physical system with about 28 input parameters. The simulation computes a list of about 10 output parameters. Now I need to explore the parameter space: I keep some of the input parameters constant, while others take several values. The input structure looks like this:

input_params = {
    'input1': [0.3], # fixed
    'input2': [1.5, 4.5, 4], # variable param: [start, end, number_of_values] (endpoints included)
    'input3': [1200.0], # fixed
    'input4': [-0.1, -0.5, 10], # variable param: [start, end, number_of_values] (endpoints included)
    'input5': [1e-3], # fixed
}

The output of the simulation program is like this (for a combination of inputs):

output_params = {
    'output1': 3.9,
    'output2': -2.5,
    'output3': 100.0,
}

I would like to generate an n-dimensional array so that I can explore it afterwards with maximum flexibility. For the above example, it should be an array like this:

results = np.zeros(shape=(1,4,1,10,1,8))

where the first axis is for input1 (one value), the second axis is for input2 (four values), and so on. The last axis contains all the data [input1, input2, input3, input4, input5, output1, output2, output3] (5 + 3 = 8 values). For this example, the array would hold 4 x 10 x 8 = 320 values, shaped as described.

My question is: How can I generate this structure, and then populate it (iterate over each axis) without writing 28 nested for loops by hand?

Or maybe my data structures aren't right and there exists a better solution?

I am open to a solution using pandas (because I would like to be able to handle parameter names as strings), or plain Python dicts. Execution speed is not that important: the bottleneck is the computing time of each simulation (it needs to reach a steady state), so I can afford to spend a few milliseconds between simulations.

I also need flexibility in how I choose which parameters are fixed and which are variable (and how many values they have).

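To generate all the input combinations, you can use pd.MultiIndex.from_product(). It builds an index holding the Cartesian product of the per-parameter value lists, so no nested loops are needed.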

Here is the code:

import numpy as np
import pandas as pd

input_params = {
    'input1': [0.3], # fixed
    'input2': [1.5, 4.5, 4], # variable param: [start, end, number_of_values] (endpoints included)
    'input3': [1200.0], # fixed
    'input4': [-0.1, -0.5, 10], # variable param: [start, end, number_of_values] (endpoints included)
    'input5': [1e-3], # fixed
}

def expand_input(inputs):
    if len(inputs) == 1:
        return inputs
    return np.linspace(*inputs).tolist()

def sim(in_pars):
    "dummy simulation that returns three results"
    return np.min(in_pars), np.mean(in_pars), np.max(in_pars)

items = sorted(input_params.items())
keys = [item[0] for item in items]
inputs = [expand_input(item[1]) for item in items]

idx = pd.MultiIndex.from_product(inputs, names=keys)
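# Run the simulation once per combination; iterating a MultiIndex yields one tuple per row.
# (Assigning to the rows yielded by df.iterrows() is not guaranteed to modify df.)
results = [sim(key) for key in idx]
df = pd.DataFrame(results, columns=["res1", "res2", "res3"], index=idx)

print(df)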

output:

                                           res1        res2  res3
input1 input2 input3 input4    input5                            
0.3    1.5    1200   -0.100000 0.001  -0.100000  240.340200  1200
                     -0.144444 0.001  -0.144444  240.331311  1200
                     -0.188889 0.001  -0.188889  240.322422  1200
                     -0.233333 0.001  -0.233333  240.313533  1200
                     -0.277778 0.001  -0.277778  240.304644  1200
                     -0.322222 0.001  -0.322222  240.295756  1200
                     -0.366667 0.001  -0.366667  240.286867  1200
                     -0.411111 0.001  -0.411111  240.277978  1200
                     -0.455556 0.001  -0.455556  240.269089  1200
                     -0.500000 0.001  -0.500000  240.260200  1200
       2.5    1200   -0.100000 0.001  -0.100000  240.540200  1200
                     -0.144444 0.001  -0.144444  240.531311  1200
                     -0.188889 0.001  -0.188889  240.522422  1200
                     -0.233333 0.001  -0.233333  240.513533  1200
                     -0.277778 0.001  -0.277778  240.504644  1200
                     -0.322222 0.001  -0.322222  240.495756  1200
                     -0.366667 0.001  -0.366667  240.486867  1200
                     -0.411111 0.001  -0.411111  240.477978  1200
                     -0.455556 0.001  -0.455556  240.469089  1200
                     -0.500000 0.001  -0.500000  240.460200  1200
       3.5    1200   -0.100000 0.001  -0.100000  240.740200  1200
                     -0.144444 0.001  -0.144444  240.731311  1200
                     -0.188889 0.001  -0.188889  240.722422  1200
                     -0.233333 0.001  -0.233333  240.713533  1200
                     -0.277778 0.001  -0.277778  240.704644  1200
                     -0.322222 0.001  -0.322222  240.695756  1200
                     -0.366667 0.001  -0.366667  240.686867  1200
                     -0.411111 0.001  -0.411111  240.677978  1200
                     -0.455556 0.001  -0.455556  240.669089  1200
                     -0.500000 0.001  -0.500000  240.660200  1200
       4.5    1200   -0.100000 0.001  -0.100000  240.940200  1200
                     -0.144444 0.001  -0.144444  240.931311  1200
                     -0.188889 0.001  -0.188889  240.922422  1200
                     -0.233333 0.001  -0.233333  240.913533  1200
                     -0.277778 0.001  -0.277778  240.904644  1200
                     -0.322222 0.001  -0.322222  240.895756  1200
                     -0.366667 0.001  -0.366667  240.886867  1200
                     -0.411111 0.001  -0.411111  240.877978  1200
                     -0.455556 0.001  -0.455556  240.869089  1200
                     -0.500000 0.001  -0.500000  240.860200  1200
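
Since the question asked for an n-dimensional array, here is a minimal sketch of how the DataFrame above could be turned into one and how a slice could be selected by parameter name. It rebuilds the same dummy setup so it runs on its own; the variable names `grid` and `subset` are only illustrative, and the last axis here holds just the three outputs (the input values for each axis remain available in `inputs`).

```python
import numpy as np
import pandas as pd

def expand_input(spec):
    """[value] is a fixed parameter; [start, end, count] expands to evenly spaced values."""
    if len(spec) == 1:
        return list(spec)
    start, end, count = spec
    return np.linspace(start, end, int(count)).tolist()

def sim(in_pars):
    """Dummy simulation standing in for the real one; returns three results."""
    return float(np.min(in_pars)), float(np.mean(in_pars)), float(np.max(in_pars))

input_params = {
    'input1': [0.3],
    'input2': [1.5, 4.5, 4],
    'input3': [1200.0],
    'input4': [-0.1, -0.5, 10],
    'input5': [1e-3],
}

keys = sorted(input_params)
inputs = [expand_input(input_params[k]) for k in keys]

idx = pd.MultiIndex.from_product(inputs, names=keys)
df = pd.DataFrame([sim(combo) for combo in idx],
                  index=idx, columns=["res1", "res2", "res3"])

# from_product orders the rows with the last parameter varying fastest,
# so a plain reshape gives one axis per input plus a trailing axis of outputs.
shape = [len(values) for values in inputs] + [df.shape[1]]
grid = df.to_numpy().reshape(shape)
print(grid.shape)  # (1, 4, 1, 10, 1, 3)

# Select by parameter name instead of by axis position:
# all rows where input2 equals its second value (2.5).
subset = df.xs(inputs[1][1], level="input2")
print(subset)
```

With 28 parameters the array would have 29 axes, so selecting by level name with `xs` is probably easier to read than positional indexing.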

If there are very many input combinations, it may not be a good idea to keep all the results in memory. Instead, you can write each result to a file as it is computed:

from itertools import product

with open("result.txt", "w") as f:
    for in_pars in product(*inputs):
        res = sim(in_pars)
        f.write(",".join(str(x) for x in in_pars + res))
        f.write("\n")
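
A file written this way can be loaded back into a labeled DataFrame later. The sketch below is one possible variant, not the only way to do it: it adds a header row, uses the standard csv module, and writes to a temporary directory so the example has no side effects. The output column names and the file name `result.csv` are assumptions made for illustration.

```python
import csv
import os
import tempfile
from itertools import product

import numpy as np
import pandas as pd

def sim(in_pars):
    """Dummy simulation standing in for the real one; returns three results."""
    return float(np.min(in_pars)), float(np.mean(in_pars)), float(np.max(in_pars))

keys = ["input1", "input2", "input3", "input4", "input5"]
inputs = [
    [0.3],
    np.linspace(1.5, 4.5, 4).tolist(),
    [1200.0],
    np.linspace(-0.1, -0.5, 10).tolist(),
    [1e-3],
]
out_names = ["output1", "output2", "output3"]  # hypothetical output names

# Hypothetical location; any writable path would do.
path = os.path.join(tempfile.gettempdir(), "result.csv")

with open(path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(keys + out_names)  # header row with the parameter names
    for in_pars in product(*inputs):
        writer.writerow(in_pars + sim(in_pars))
        f.flush()  # keep finished rows on disk if a long run is interrupted

# Load the results with the inputs as a MultiIndex, as in the DataFrame version.
loaded = pd.read_csv(path, index_col=keys)
print(loaded.shape)  # (40, 3)
```

Writing one row per simulation means a crash partway through loses at most the run in progress, which matters when each simulation is slow.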
