
pyspark: groupby and then get max value of each group

I would like to group by a value and then find the max value in each group using PySpark. I have the following code, but now I am a bit stuck on how to extract the max value.

# some file contains tuples ('user', 'item', 'occurrences')
data_file = sc.textFile('file:///some_file.txt')
# Create the triplet so I can index stuff
data_file = data_file.map(lambda l: l.split()).map(lambda l: (l[0], l[1], float(l[2])))
# Group by the user i.e. r[0]
grouped = data_file.groupBy(lambda r: r[0])
# Here is where I am stuck 
group_list = grouped.map(lambda x: (list(x[1]))) #?

This returns something like:

[[(u'u1', u's1', 20), (u'u1', u's2', 5)], [(u'u2', u's3', 5), (u'u2', u's2', 10)]]

Now I want to find the max 'occurrences' for each user. The final result after taking the max would be an RDD that looks like this:

[[(u'u1', u's1', 20)], [(u'u2', u's2', 10)]]

Only the record with the max occurrences would remain for each user in the file. In other words, I want the RDD to contain only a single triplet per user: that user's max-occurrence triplet.

There is no need for groupBy here. A simple reduceByKey would do just fine and will be more efficient most of the time:

data_file = sc.parallelize([
   (u'u1', u's1', 20), (u'u1', u's2', 5),
   (u'u2', u's3', 5), (u'u2', u's2', 10)])

max_by_group = (data_file
  .map(lambda x: (x[0], x))  # Convert to a pair RDD keyed by user
  # Take maximum of the passed arguments by the last element (key)
  # equivalent to:
  # lambda x, y: x if x[-1] > y[-1] else y
  .reduceByKey(lambda x1, x2: max(x1, x2, key=lambda x: x[-1])) 
  .values()) # Drop keys

max_by_group.collect()
## [('u2', 's2', 10), ('u1', 's1', 20)]
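
For comparison only (this is not part of the original answer), the same result can be expressed with the DataFrame API, assuming Spark 2.x with a SparkSession; the column names user, item and occurrences are my own labels, not from the original file:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [(u'u1', u's1', 20.0), (u'u1', u's2', 5.0),
     (u'u2', u's3', 5.0), (u'u2', u's2', 10.0)],
    ['user', 'item', 'occurrences'])

# Rank the rows within each user by occurrences and keep only the top row
w = Window.partitionBy('user').orderBy(F.col('occurrences').desc())
top_per_user = (df
    .withColumn('rn', F.row_number().over(w))
    .where(F.col('rn') == 1)
    .drop('rn'))

top_per_user.show()  # keeps ('u1', 's1', 20.0) and ('u2', 's2', 10.0)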

I think I found the solution:

from pyspark import SparkContext, SparkConf

def reduce_by_max(triplets):
    """
    Helper function that returns the triplet with the max occurrence
    value from a list of (user, item, occurrences) triplets.
    """
    max_val = triplets[0][2]
    the_index = 0

    for idx, val in enumerate(triplets):
        if val[2] > max_val:
            max_val = val[2]
            the_index = idx

    return triplets[the_index]

conf = SparkConf() \
    .setAppName("Collaborative Filter") \
    .set("spark.executor.memory", "5g")
sc = SparkContext(conf=conf)

# some file contains tuples ('user', 'item', 'occurrences')
data_file = sc.textFile('file:///some_file.txt')

# Create the triplet so I can index stuff
data_file = data_file.map(lambda l: l.split()).map(lambda l: (l[0], l[1], float(l[2])))

# Group by the user i.e. r[0]
grouped = data_file.groupBy(lambda r: r[0])

# Get the values as a list
group_list = grouped.map(lambda x: (list(x[1]))) 

# Get the max value for each user. 
max_list = group_list.map(reduce_by_max).collect()
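
As a side note (a sketch, not from the original post), the reduce_by_max helper can be replaced with Python's built-in max applied to each group, since groupBy already yields (user, iterable-of-triplets) pairs:

# Pick the triplet with the highest occurrence count in each group,
# then drop the user keys so only the winning triplets remain.
max_list = (grouped
    .mapValues(lambda triplets: max(triplets, key=lambda t: t[2]))
    .values()
    .collect())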
