How to pass a custom equality_check function into perfplot

I'm using the perfplot library to compare the performance of three functions, f1, f2 and f3. The functions are supposed to return the same values, so I want perfplot to check that their outputs are equal. However, every perfplot example I can find online uses pd.DataFrame.equals or np.allclose as the equality checker, and neither works in my case.

For example, np.allclose doesn't work when the functions return a list of NumPy arrays of different lengths.
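A quick way to reproduce this outside perfplot is sketched below, using two identical ragged lists like the ones f1 returns. The exact exception depends on the NumPy version (recent releases raise ValueError when asked to build an array from ragged input; older ones may instead build an object array and fail with TypeError), so the sketch catches both.

```python
import numpy as np

# Two identical "ragged" outputs: lists of arrays with different lengths,
# like the ones f1/f2/f3 return.
a = [np.array(range(i)) for i in range(3)]
b = [np.array(range(i)) for i in range(3)]

try:
    np.allclose(a, b)
    allclose_failed = False
except (ValueError, TypeError) as exc:
    # Recent NumPy: ValueError (inhomogeneous shape).
    # Older NumPy: object array that isclose cannot handle, i.e. TypeError.
    allclose_failed = True
    print(type(exc).__name__, exc)
```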

import perfplot
import numpy as np

def f1(rng):
    return [np.array(range(i)) for i in rng]

def f2(rng):
    return [np.array(list(rng)[:i]) for i in range(len(rng))]

def f3(rng):
    return [np.array([*range(i)]) for i in rng]


perfplot.show(
    kernels=[f1, f2, f3],
    n_range=[10**k for k in range(4)],
    setup=lambda n: range(n),
    equality_check=np.allclose   # <-- doesn't work; neither does pd.DataFrame.equals
)

How do I pass my own function as equality_check instead?

Inspecting perfplot's source code shows how the equality check works: the output of the first function in kernels is taken as the reference, and the output of each subsequent kernel is compared against it in a loop.
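The reference-then-compare loop can be sketched on its own, without perfplot. This is a simplified, hypothetical helper (check_outputs is not a perfplot function): it passes a single setup value straight to each kernel and raises AssertionError on the first mismatch, whereas perfplot raises its own PerfplotError.

```python
def check_outputs(kernels, data, equality_check):
    """Hypothetical helper mirroring the loop described above: the first
    kernel's output is the reference, every later output is compared to it."""
    reference = None
    for k, kernel in enumerate(kernels):
        val = kernel(data)  # simplification: one setup value per kernel
        if k == 0:
            reference = val
        elif not equality_check(reference, val):
            raise AssertionError(
                f"kernel {k} ({kernel.__name__}) disagrees with kernel 0"
            )
    return True


# Usage: sorted() is the reference and the second kernel agrees with it.
print(check_outputs([sorted, lambda xs: sorted(xs)], [3, 1, 2],
                    lambda x, y: x == y))  # True
```

Swapping the second kernel for list (which does not sort) would make the helper raise, because [3, 1, 2] != [1, 2, 3].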

Oddly, the check behaves differently depending on whether the first kernel returns a tuple.

If the first kernel doesn't return a tuple, perfplot simply calls the function passed to the equality_check argument with the reference output and the current output. That function takes two arguments and can perform any comparison you like, as long as it returns True when the outputs match. For example, it could check only that the lengths are equal and call it a day (i.e. pass lambda x, y: len(x) == len(y) to equality_check). If the first kernel does return a tuple, perfplot asserts that the other kernel's output is also a tuple of the same length and then applies equality_check to each pair of elements.
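The two branches can be sketched as a small standalone function. outputs_equal is a hypothetical name, not part of perfplot, and unlike perfplot (which uses assert statements for the tuple type and length checks) it simply returns False on a mismatch. The usage lines plug in the length-only lambda mentioned above.

```python
def outputs_equal(reference, val, equality_check):
    """Hypothetical re-implementation of the comparison step in perfplot:
    tuple outputs are compared element by element, anything else as a whole."""
    if isinstance(reference, tuple):
        # Both outputs must be tuples of the same length ...
        if not isinstance(val, tuple) or len(reference) != len(val):
            return False
        # ... and equality_check is applied to each pair of elements.
        return all(equality_check(r, v) for r, v in zip(reference, val))
    # Non-tuple outputs are handed to equality_check unchanged.
    return equality_check(reference, val)


def length_only(x, y):
    # Deliberately weak check: only the lengths are compared.
    return len(x) == len(y)


print(outputs_equal([1, 2, 3], [4, 5, 6], length_only))            # True
print(outputs_equal(([1], [2, 3]), ([9], [8, 7]), length_only))  # True
```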

For the example in the question, a function that loops over the pairs of arrays and compares each pair with np.allclose works:

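def equality_check(x, y):
    # zip() stops at the shorter list, so compare the lengths explicitly first.
    if len(x) != len(y):
        return False
    for i, j in zip(x, y):
        # np.allclose broadcasts its inputs, so also require identical shapes.
        if np.shape(i) != np.shape(j) or not np.allclose(i, j):
            return False
    return True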

perfplot.show(
    kernels=[f1, f2, f3],
    n_range=[10**k for k in range(4)],
    setup=lambda n: range(n),
    equality_check=equality_check
)
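Before timing anything, it can help to run the checker on the kernels' outputs directly. The sketch below re-declares the three kernels from the question plus an equality_check that compares lengths, shapes and values (the length and shape checks guard against zip() truncation and NumPy broadcasting), then asserts that f2 and f3 agree with f1 for a few sizes. It needs only NumPy, not perfplot.

```python
import numpy as np

def f1(rng):
    return [np.array(range(i)) for i in rng]

def f2(rng):
    return [np.array(list(rng)[:i]) for i in range(len(rng))]

def f3(rng):
    return [np.array([*range(i)]) for i in rng]

def equality_check(x, y):
    # Lists of different lengths can never match; zip() alone would hide that.
    if len(x) != len(y):
        return False
    # Same shape and close values for every pair of arrays.
    return all(np.shape(i) == np.shape(j) and np.allclose(i, j)
               for i, j in zip(x, y))

# Same setup as in the perfplot call: range(n) for a few sizes.
for n in [1, 10, 100]:
    data = range(n)
    reference = f1(data)
    assert equality_check(reference, f2(data))
    assert equality_check(reference, f3(data))

# A deliberately truncated output is rejected.
print(equality_check(f1(range(5)), f1(range(5))[:-1]))  # False
```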

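all() is also accepted, since perfplot only needs a callable that returns a truthy value, but it is not an equality check: lambda x, y: all([x, y]) only tests that both outputs are truthy (here, non-empty lists). With the setup=lambda n: range(1, n+1) used below, f2 actually returns different arrays from f1, and the check still passes.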

perfplot.show(
    kernels=[f1, f2, f3],
    n_range=[10**k for k in range(4)],
    setup=lambda n: range(1, n+1),
    equality_check=lambda x,y: all([x,y])
)
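A short demonstration with plain lists (truthy_check is just an illustrative name for the lambda above). With NumPy arrays instead of lists, bool() of a multi-element array is ambiguous and raises ValueError, so this lambda would not even run on such outputs.

```python
def truthy_check(x, y):
    # Same expression as the lambda: only the truthiness of x and y matters.
    return all([x, y])

print(truthy_check([1, 2], [3, 4, 5]))  # True, although the lists differ
print(truthy_check([1, 2], []))         # False, but only because [] is falsy
```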

For reference, this is the snippet from perfplot's source that implements the equality check:

for k, kernel in enumerate(self.kernels):

    val = kernel(*data)

    if self.equality_check:
        if k == 0:
            reference = val
        else:
            try:
                if isinstance(reference, tuple):
                    assert isinstance(val, tuple)
                    assert len(reference) == len(val)
                    is_equal = True
                    for r, v in zip(reference, val):
                        if not self.equality_check(r, v):
                            is_equal = False
                            break
                else:
                    is_equal = self.equality_check(reference, val)
            except TypeError:
                raise PerfplotError(
                    "Error in equality_check. "
                    + "Try setting equality_check=None."
                )
            else:
                if not is_equal:
                    raise PerfplotError(
                        "Equality check failure.\n"
                        + f"{self.labels[0]}:\n"
                        + f"{reference}:\n\n"
                        + f"{self.labels[k]}:\n"
                        + f"{val}:\n"
                    )
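One caveat follows from this snippet: it only catches TypeError, while np.allclose broadcasts its inputs. For arrays whose shapes cannot be broadcast together it raises ValueError, which would escape the except clause; for shapes that can be broadcast, such as (1,) against (0,), it can report a match between arrays that are not equal. The sketch below demonstrates both behaviors alongside a hypothetical strict_allclose wrapper (not a NumPy or perfplot function) that compares shapes first.

```python
import numpy as np

def strict_allclose(a, b):
    """Hypothetical helper: compare shapes first, then values, so that
    np.allclose never broadcasts or raises on mismatched arrays."""
    a, b = np.asarray(a), np.asarray(b)
    return a.shape == b.shape and bool(np.allclose(a, b))

# np.allclose broadcasts (1,) against (0,), compares nothing, and says True.
print(np.allclose(np.array([0]), np.array([])))      # True
print(strict_allclose(np.array([0]), np.array([])))  # False

# Non-broadcastable shapes make np.allclose raise ValueError ...
try:
    np.allclose(np.arange(2), np.arange(3))
except ValueError as exc:
    print("ValueError:", exc)
# ... while the shape check simply reports a mismatch.
print(strict_allclose(np.arange(2), np.arange(3)))   # False
```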
