I'm trying to quickly check how many items in a list are below a series of thresholds, similar to what's described here but repeated many times. The point of this is to run diagnostics on a machine learning model that are a little more in-depth than what is built into scikit-learn (ROC curves, etc.).
Imagine preds is a list of predictions (probabilities between 0 and 1). In reality, I will have over 1 million of them, which is why I'm trying to speed this up.
This creates some fake scores: draws from a standard normal distribution, shifted and rescaled into [0, 1].
import numpy as np

fake_preds = [np.random.normal(0, 1) for i in range(1000)]
fake_preds = [(pred + np.abs(min(fake_preds))) / max(fake_preds + np.abs(min(fake_preds))) for pred in fake_preds]
Now, the way I am doing this is to loop through 100 threshold levels and check how many predictions are lower at any given threshold:
thresholds = [round(n, 2) for n in np.arange(0.01, 1.0, 0.01)]
thresh_cov = [sum(fake_preds < thresh) for thresh in thresholds]
(Note: fake_preds has to be a NumPy array for the elementwise comparison to work; comparing a plain list to a float raises a TypeError in Python 3.)
This takes about 1.5 secs for 10k (less time than generating the fake predictions) but you can imagine it takes a lot longer with a lot more predictions. And I have to do this a few thousand times to compare a bunch of different models.
Any thoughts on a way to make that second code block faster? I'm thinking there must be a way to order the predictions to make it easier for the computer to check the thresholds (similar to indexing in a SQL-like scenario), but I can't figure out any way other than sum(fake_preds < thresh) to check them, and that doesn't take advantage of any indexing or ordering.
Thanks in advance for the help!
One way would be to use numpy.histogram:
thresh_cov = np.histogram(fake_preds, len(thresholds))[0].cumsum()
From timeit, I'm getting:
%timeit my_cov = np.histogram(fake_preds, len(thresholds))[0].cumsum()
169 µs ± 6.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit thresh_cov = [sum(fake_preds < thresh) for thresh in thresholds]
172 ms ± 1.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
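One caveat, sketched below under the assumption that the data lies in [0, 1] as in the question: passing an integer bin count lets np.histogram choose its own edges over [min, max], so the cumulative counts only approximate the per-threshold counts. Passing the thresholds themselves as bin edges makes them exact, because histogram bins are half-open [a, b):

```python
import numpy as np

rng = np.random.default_rng(0)
fake_preds = rng.random(10_000)
thresholds = np.round(np.arange(0.01, 1.0, 0.01), 2)

# Use the thresholds (plus the data-range endpoints 0 and 1) as bin
# edges; since bins are half-open [a, b), cumsum()[k] is then exactly
# the number of predictions strictly below thresholds[k].
edges = np.concatenate(([0.0], thresholds, [1.0]))
counts = np.histogram(fake_preds, bins=edges)[0].cumsum()[:-1]

# Check against the question's naive per-threshold count
exact = np.array([(fake_preds < t).sum() for t in thresholds])
assert np.array_equal(counts, exact)
```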
Method #1
You can sort the predictions array and then use searchsorted or np.digitize, like so -
np.searchsorted(np.sort(fake_preds), thresholds, 'right')
np.digitize(thresholds, np.sort(fake_preds))
If you don't mind mutating the predictions array, sort it in-place with fake_preds.sort() and then use fake_preds in place of np.sort(fake_preds). This should be much more performant, as we would avoid allocating any extra memory.
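A minimal sketch of the in-place variant. Note the side argument: 'left' counts elements strictly below each threshold, while 'right' counts elements less than or equal to it; with continuous data the two almost never differ.

```python
import numpy as np

rng = np.random.default_rng(1)
preds = rng.random(1_000_000)
thresholds = np.arange(0.01, 1.0, 0.01)

preds.sort()  # in-place, so no temporary sorted copy is allocated

# side='left' gives, for each threshold, the count of elements
# strictly below it
counts = np.searchsorted(preds, thresholds, side='left')

# Check against the naive per-threshold count
naive = np.array([(preds < t).sum() for t in thresholds])
assert np.array_equal(counts, naive)
```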
Method #2
Now, with 100 thresholds running from 0 to 1, those thresholds are multiples of 0.01. Thus, we can simply digitize each prediction by scaling it up by 100 and converting to int, which can be fed straight as bins to np.bincount. Then, to get our desired result, use cumsum, like so -
np.bincount((fake_preds*100).astype(int),minlength=99)[:99].cumsum()
Approaches -
def searchsorted_app(fake_preds, thresholds):
return np.searchsorted(np.sort(fake_preds), thresholds, 'right')
def digitize_app(fake_preds, thresholds):
return np.digitize(thresholds, np.sort(fake_preds) )
def bincount_app(fake_preds, thresholds):
return np.bincount((fake_preds*100).astype(int),minlength=99)[:99].cumsum()
Runtime test and verification on 10000 elements -
In [210]: np.random.seed(0)
...: fake_preds = np.random.rand(10000)
...: thresholds = [round(n,2) for n in np.arange(0.01, 1.0, 0.01)]
...: thresh_cov = [sum(fake_preds < thresh) for thresh in thresholds]
...:
In [211]: print(np.allclose(thresh_cov, searchsorted_app(fake_preds, thresholds)))
     ...: print(np.allclose(thresh_cov, digitize_app(fake_preds, thresholds)))
     ...: print(np.allclose(thresh_cov, bincount_app(fake_preds, thresholds)))
...:
True
True
True
In [214]: %timeit [sum(fake_preds < thresh) for thresh in thresholds]
1 loop, best of 3: 1.43 s per loop
In [215]: %timeit searchsorted_app(fake_preds, thresholds)
...: %timeit digitize_app(fake_preds, thresholds)
...: %timeit bincount_app(fake_preds, thresholds)
...:
1000 loops, best of 3: 528 µs per loop
1000 loops, best of 3: 535 µs per loop
10000 loops, best of 3: 24.9 µs per loop
That's a 2,700x+ speedup for searchsorted and a 57,000x+ one for bincount! With larger datasets, the gap between bincount and searchsorted is bound to increase, as bincount doesn't need to sort.
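One caveat about bincount_app, sketched below: it hard-codes the assumption that the thresholds are exactly the multiples 0.01 through 0.99, since (fake_preds*100).astype(int) maps each prediction straight to its bucket without ever looking at the thresholds argument. A quick sanity check against the naive count on uniform data:

```python
import numpy as np

def bincount_app(fake_preds, thresholds):
    # bucket k holds predictions in [k/100, (k+1)/100); the cumulative
    # sum over buckets 0..k is then the count below threshold (k+1)/100.
    # thresholds is unused: the 0.01 grid is baked into the *100 scaling.
    return np.bincount((fake_preds * 100).astype(int), minlength=99)[:99].cumsum()

rng = np.random.default_rng(2)
fake_preds = rng.random(10_000)
thresholds = np.round(np.arange(0.01, 1.0, 0.01), 2)

naive = np.array([(fake_preds < t).sum() for t in thresholds])
assert np.array_equal(bincount_app(fake_preds, thresholds), naive)
```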
You can reshape thresholds here to enable broadcasting. First, here are a few possible changes to your creation of fake_preds and thresholds that get rid of the loops.
np.random.seed(123)
fake_preds = np.random.normal(size=1000)
fake_preds = (fake_preds + np.abs(fake_preds.min())) \
/ (np.max(fake_preds + np.abs((fake_preds.min()))))
thresholds = np.linspace(.01, 1, 100)
Then what you want to do is accomplishable in 1 line:
print(np.sum(np.less(fake_preds, np.tile(thresholds, (1000,1)).T), axis=1))
[ 2 2 2 2 2 2 5 5 6 7 7 11 11 11 15 18 21 26
28 34 40 48 54 63 71 77 90 100 114 129 143 165 176 191 206 222
240 268 288 312 329 361 392 417 444 479 503 532 560 598 615 648 671 696
710 726 747 768 787 800 818 840 860 877 891 902 912 919 928 942 947 960
965 970 978 981 986 987 988 991 993 994 995 995 995 997 997 997 998 998
999 999 999 999 999 999 999 999 999 999]
Walkthrough:
fake_preds has shape (1000,). You need to manipulate thresholds into a shape that is compatible for broadcasting against it. (See the general broadcasting rules.)
A broadcastable second shape would be
print(np.tile(thresholds, (1000,1)).T.shape)
# (100, 1000)
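As a sketch, the same broadcast can be had without the np.tile copy by reshaping thresholds into a column vector: thresholds[:, None] has shape (100, 1), which broadcasts directly against the (1000,) predictions.

```python
import numpy as np

rng = np.random.default_rng(3)
fake_preds = rng.random(1000)
thresholds = np.linspace(0.01, 1, 100)

# A (100, 1) column broadcasts against (1000,) into a (100, 1000)
# comparison, with no tiled copy of thresholds ever materialized
counts = (fake_preds < thresholds[:, None]).sum(axis=1)

# Same result as the tile-based one-liner above
tiled = np.sum(np.less(fake_preds, np.tile(thresholds, (1000, 1)).T), axis=1)
assert np.array_equal(counts, tiled)
```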
Option 1:
from scipy.stats import percentileofscore
thresh_cov = [percentileofscore(fake_preds, thresh) for thresh in thresholds]
(Note that percentileofscore returns a percentage rather than a count, so multiply by len(fake_preds) / 100 to recover counts; passing kind='strict' counts only values strictly below each score.)
Option 2: same as above, but sort the list first
Option 3: Insert your thresholds into the list, sort the list, find the indices of your thresholds. Note that if you have a quicksort algorithm, you can optimize it for your purposes by making your thresholds the pivots and terminating the sort once you've partitioned everything according to the thresholds.
Option 4: Building on the above: Put your thresholds in a binary tree, then for each item in the list, compare it to the thresholds in a binary search. You can either do it item by item, or split the list into subsets at each step.
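Option 4 can be sketched with the standard library's bisect module (the variable names below are illustrative): one binary search per prediction against the sorted thresholds, then a cumulative sum of the bucket tallies.

```python
import numpy as np
from bisect import bisect_left

rng = np.random.default_rng(4)
preds = rng.random(10_000)
thresholds = [round(t, 2) for t in np.arange(0.01, 1.0, 0.01)]

# bisect_left(thresholds, p) is the number of thresholds strictly
# below p, so tallying those indices buckets the predictions
buckets = np.zeros(len(thresholds) + 1, dtype=int)
for p in preds:
    buckets[bisect_left(thresholds, p)] += 1

# cumulative bucket totals give the count below each threshold
below = buckets.cumsum()[:len(thresholds)]

# Check against the naive per-threshold count
naive = np.array([(preds < t).sum() for t in thresholds])
assert np.array_equal(below, naive)
```

In NumPy this per-item loop is exactly what np.searchsorted / np.digitize vectorize, which is why the answers above reach for them.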