
How to get topN elements from this dict with list values?

For example:

from collections import defaultdict

tag_dict = {'a':[('mary', 0.99, 'f'), ('tom', 0.87, 'm')],
            'b':[('jack', 0.43, 'm')],
            'c':[('martin', 0.987, 'm'), ('alice', 0.973, 'f')]}

I want to get a new defaultdict that contains the top 3 elements overall, sorted by the float in each element. For the example above, the result would be:

top3_dict = {'a':[('mary', 0.99, 'f')],
             'c':[('martin', 0.987, 'm'), ('alice', 0.973, 'f')]}

Is there an efficient way to do this?

Edit:

My actual question is a little more complicated. In some cases the original dict holds fewer than 3 elements in total, so the problem is better stated as:

Get up to the top N elements sorted by the float number; if there are fewer than N elements, return all of them in sorted order.

You can flatten the structure and store the outer key alongside each tuple. Then sort the pairs, take the top three, and build the dictionary:

from collections import defaultdict

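result = defaultdict(list)

# Flatten to (key, entry) pairs and sort by the score at entry[1], highest first.
sorted_items = sorted(
    [(k, item) for k, v in tag_dict.items() for item in v],
    key=lambda p: p[1][1], reverse=True)

# Slicing returns fewer than three pairs when that is all there is.
for k, v in sorted_items[:3]:
    result[k].append(v)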

Result will be:

defaultdict(list,
            {'a': [('mary', 0.99, 'f')],
             'c': [('martin', 0.987, 'm'), ('alice', 0.973, 'f')]})

The intermediate sorted_items list looks like:

[('a', ('mary', 0.99, 'f')),
 ('c', ('martin', 0.987, 'm')),
 ...
]

At that point, building the new dict is simple.
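
Since the question asks about efficiency, one possible alternative is heapq.nlargest from the standard library, which avoids sorting the whole flattened list when only a few items are needed. The sketch below is a variant of the approach above, not the answerer's code; the function name top_n_by_score and the parameter n are hypothetical names chosen for illustration.

```python
import heapq
from collections import defaultdict


def top_n_by_score(tag_dict, n=3):
    """Return a defaultdict holding the n highest-scoring entries overall.

    Hypothetical helper: assumes every entry is a tuple whose second
    field (index 1) is the float score, as in the question's data.
    """
    # Lazily pair every entry with the key it came from.
    pairs = ((key, entry)
             for key, entries in tag_dict.items()
             for entry in entries)
    # nlargest returns at most n pairs, highest score first.
    best = heapq.nlargest(n, pairs, key=lambda pair: pair[1][1])

    result = defaultdict(list)
    for key, entry in best:
        result[key].append(entry)
    return result


tag_dict = {'a': [('mary', 0.99, 'f'), ('tom', 0.87, 'm')],
            'b': [('jack', 0.43, 'm')],
            'c': [('martin', 0.987, 'm'), ('alice', 0.973, 'f')]}

top3_dict = top_n_by_score(tag_dict, 3)
print(dict(top3_dict))
```

For small inputs the difference from a full sort is negligible; the heap-based version mainly pays off when the flattened list is large and n is small.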

Try the following:

from collections import defaultdict

tag_dict = {'a': [('mary', 0.99, 'f'), ('tom', 0.87, 'm')],
            'b': [('jack', 0.43, 'm')],
            'c': [('martin', 0.987, 'm'), ('alice', 0.973, 'f')]}

# Flatten to (key, name, score, gender) rows.
data = []
for k, v in tag_dict.items():
    for entry in v:
        data.append((k, *entry))

# Sort by the score at index 2, highest first.
data = sorted(data, key=lambda x: x[2], reverse=True)

result = defaultdict(list)
# Slice instead of indexing with range(0, 3), which raises IndexError
# when there are fewer than three rows.
for row in data[:3]:
    result[row[0]].append(row[1:])
print(result)

Output:

defaultdict(<class 'list'>, {'a': [('mary', 0.99, 'f')], 'c': [('martin', 0.987, 'm'), ('alice', 0.973, 'f')]})
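
To check the "up to N" case from the question's edit, the same flatten, sort, and take-the-top steps can be wrapped in a function and run on a dict with fewer than three entries. This is only a sketch; top_n_sorted and small_dict are made-up names for illustration. It relies on the fact that slicing a list past its end returns whatever elements exist, whereas indexing past the end raises IndexError.

```python
from collections import defaultdict


def top_n_sorted(tag_dict, n):
    """Hypothetical helper: up to n highest-scoring entries, grouped by key."""
    # Flatten to (key, entry) pairs; the score is assumed to sit at entry[1].
    flat = [(key, entry)
            for key, entries in tag_dict.items()
            for entry in entries]
    flat.sort(key=lambda pair: pair[1][1], reverse=True)

    result = defaultdict(list)
    # flat[:n] is simply shorter than n when there are not enough entries.
    for key, entry in flat[:n]:
        result[key].append(entry)
    return result


# Only two entries in total, but three are requested.
small_dict = {'a': [('mary', 0.99, 'f')],
              'b': [('jack', 0.43, 'm')]}

small_result = top_n_sorted(small_dict, 3)
print(dict(small_result))
```

Both entries come back, in descending score order, and no exception is raised.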
