
How to find the element-wise average of arrays in an array column (along axis 0) in a PySpark DataFrame?

I have a PySpark DataFrame:

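from pyspark.sql import SparkSession

# Reuse the active session (e.g. the one provided by the pyspark shell)
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([
    ("u1", [[1., 2., 3.], [1., 2., 0.], [1., 0., 0.]]),
    ("u2", [[1., 10., 0.]]),
    ("u3", [[1., 0., 3.], [10., 0., 0.]]),
    ],
    ['user_id', 'features'])

df.printSchema()  # printSchema() prints directly and returns None, so no print() needed
df.show(truncate=False)

Output:

root
 |-- user_id: string (nullable = true)
 |-- features: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: double (containsNull = true)
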
+-------+---------------------------------------------------+
|user_id|features                                           |
+-------+---------------------------------------------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|
|u2     |[[1.0, 10.0, 0.0]]                                 |
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |
+-------+---------------------------------------------------+

For every user, I want the element-wise average of these arrays, i.e. the mean along the 0th axis (position by position across the user's inner arrays). The desired output would look like this:

+-------+---------------------------------------------------+----------------+
|user_id|features                                           |avg_features    |
+-------+---------------------------------------------------+----------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|[1.0, 1.33, 1.0]|
|u2     |[[1.0, 10.0, 0.0]]                                 |[1.0, 10.0, 0.0]|
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |[5.5, 0.0, 1.5] |
+-------+---------------------------------------------------+----------------+
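For reference, "average on the 0th axis" here means the element-wise mean across a user's inner arrays, which is what `numpy.mean(rows, axis=0)` would return. A minimal plain-Python sketch of that computation (the helper name `mean_axis0` is hypothetical, and it assumes every inner array of a user has the same length) reproduces the desired `avg_features` values:

```python
def mean_axis0(rows):
    """Element-wise mean of equally long lists, i.e. the mean along axis 0.

    Assumes `rows` is non-empty and every inner list has the same length.
    """
    # zip(*rows) regroups the values by position: all x[0], then all x[1], ...
    return [sum(column) / len(rows) for column in zip(*rows)]


# Same data as the PySpark DataFrame above.
data = {
    "u1": [[1., 2., 3.], [1., 2., 0.], [1., 0., 0.]],
    "u2": [[1., 10., 0.]],
    "u3": [[1., 0., 3.], [10., 0., 0.]],
}

for user_id, features in data.items():
    print(user_id, mean_axis0(features))
# u1 [1.0, 1.3333333333333333, 1.0]
# u2 [1.0, 10.0, 0.0]
# u3 [5.5, 0.0, 1.5]
```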

How do I achieve this?

EDIT: a more scalable solution, which works for inner arrays of any length instead of hard-coding three positions:

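import pyspark.sql.functions as F

df2 = df.withColumn(
    # one row per inner array
    'exploded_features', F.explode('features')
).select(
    # one row per element: its position goes to 'pos', its value to 'col'
    'user_id', 'features', F.posexplode('exploded_features')
).groupBy(
    'user_id', 'features', 'pos'
).agg(
    # average of each position across a user's inner arrays -> 'avg(col)'
    F.mean('col')
).groupBy(
    'user_id', 'features'
).agg(
    # collect [pos, avg] pairs; collect_list gives no ordering guarantee,
    # so array_sort orders the pairs by their first element, pos
    F.array_sort(
        F.collect_list(
            F.array('pos', 'avg(col)')
        )
    ).alias('avg_features')
).withColumn(
    # keep only the averages, dropping the positions
    'avg_features',
    F.expr('transform(avg_features, x -> x[1])')
)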

df2.show(truncate=False)
+-------+---------------------------------------------------+------------------------------+
|user_id|features                                           |avg_features                  |
+-------+---------------------------------------------------+------------------------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|[1.0, 1.3333333333333333, 1.0]|
|u2     |[[1.0, 10.0, 0.0]]                                 |[1.0, 10.0, 0.0]              |
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |[5.5, 0.0, 1.5]               |
+-------+---------------------------------------------------+------------------------------+
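To make the data flow of the pipeline above concrete, here is a plain-Python emulation of each stage on the same data. It is only an illustrative model, not Spark code: dictionaries stand in for `groupBy`, `enumerate` for `posexplode`, and `sorted` for `array_sort`. Unlike the original, it groups by `user_id` alone, since `features` is grouped on there only to carry that column into the result.

```python
from collections import defaultdict

# Same data as the PySpark DataFrame above.
data = [
    ("u1", [[1., 2., 3.], [1., 2., 0.], [1., 0., 0.]]),
    ("u2", [[1., 10., 0.]]),
    ("u3", [[1., 0., 3.], [10., 0., 0.]]),
]

# Stage 1, explode('features'): one row per inner array.
exploded = [(user_id, inner) for user_id, features in data for inner in features]

# Stage 2, posexplode('exploded_features'): one row per (position, value).
pos_exploded = [
    (user_id, pos, value)
    for user_id, inner in exploded
    for pos, value in enumerate(inner)
]

# Stage 3, groupBy(user_id, pos).agg(mean(col)).
sums = defaultdict(float)
counts = defaultdict(int)
for user_id, pos, value in pos_exploded:
    sums[(user_id, pos)] += value
    counts[(user_id, pos)] += 1
means = {key: sums[key] / counts[key] for key in sums}

# Stage 4, groupBy(user_id) + collect_list(array(pos, avg)) + array_sort.
collected = defaultdict(list)
for (user_id, pos), avg in means.items():
    collected[user_id].append((pos, avg))
sorted_pairs = {user_id: sorted(pairs) for user_id, pairs in collected.items()}

# Stage 5, transform(avg_features, x -> x[1]): keep only the averages.
avg_features = {
    user_id: [pair[1] for pair in pairs] for user_id, pairs in sorted_pairs.items()
}
print(avg_features)
```

The sort in stage 4 matters: without it, the positions could come back in any order after the shuffle, which is why the original wraps `collect_list` in `array_sort`.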

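The original solution uses the higher-order SQL functions aggregate and transform to operate on the arrays directly. It hard-codes the positions x[0], x[1] and x[2], so it only works when every inner array has exactly three elements: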

df2 = df.selectExpr(
    'user_id',
    'features',
    # For each position i: sum x[i] over the inner arrays, then divide by their count.
    # A triple-quoted string is needed because the expression spans several lines.
    '''array(
        aggregate(transform(features, x -> x[0]), cast(0 as double), (x, y) -> (x + y)) / size(features),
        aggregate(transform(features, x -> x[1]), cast(0 as double), (x, y) -> (x + y)) / size(features),
        aggregate(transform(features, x -> x[2]), cast(0 as double), (x, y) -> (x + y)) / size(features)
    ) as avg'''
)

df2.show(truncate=False)
+-------+---------------------------------------------------+------------------------------+
|user_id|features                                           |avg                           |
+-------+---------------------------------------------------+------------------------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|[1.0, 1.3333333333333333, 1.0]|
|u2     |[[1.0, 10.0, 0.0]]                                 |[1.0, 10.0, 0.0]              |
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |[5.5, 0.0, 1.5]               |
+-------+---------------------------------------------------+------------------------------+
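The expression above hard-codes three positions. Assuming Spark 2.4 or later (where `sequence`, `transform` and `aggregate` are available as SQL functions), the index list could instead be generated with `sequence(0, size(features[0]) - 1)` and mapped with an outer `transform`; the `GENERAL_AVG_EXPR` string below is an untested sketch of that idea, not part of the original answer. The plain-Python function `avg_by_aggregate` (a hypothetical name) mirrors the expression's structure, with `functools.reduce` standing in for the left fold that `aggregate` performs:

```python
from functools import reduce

# Untested sketch, assuming Spark 2.4+ and that every inner array has the same
# length as features[0]. Intended usage:
#   df.selectExpr('user_id', 'features', GENERAL_AVG_EXPR)
GENERAL_AVG_EXPR = '''
transform(
    sequence(0, size(features[0]) - 1),
    i -> aggregate(transform(features, x -> x[i]), cast(0 as double), (acc, v) -> acc + v)
         / size(features)
) as avg
'''


def avg_by_aggregate(features):
    """Plain-Python model of GENERAL_AVG_EXPR for one row's `features`."""
    n = len(features[0])                    # size(features[0])
    return [
        # aggregate(..., 0.0, (acc, v) -> acc + v) is a left fold, like reduce
        reduce(lambda acc, v: acc + v, [x[i] for x in features], 0.0)
        / len(features)                     # ... / size(features)
        for i in range(n)                   # sequence(0, n - 1)
    ]


print(avg_by_aggregate([[1., 0., 3.], [10., 0., 0.]]))
```

Because the positions come from `sequence` rather than literals, the same expression would also cover inner arrays with four or more elements.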
