
pySpark - Row means excluding zeros

I have a Spark data frame with more than 1200 columns and need to take the average of sets of columns, by row, excluding zero values.

For example, for the following data frame

id | col1 | col2 | col3
1  |  0   |  2   |  3
2  |  4   |  2   |  3
3  |  1   |  0   |  3
4  |  0   |  0   |  0

I'd expect:

id | mean 
1  |  2.5  
2  |  3  
3  |  2
4  |  0  

In Python, I'm aware it is possible to achieve something like this using a strategy similar to:

import numpy as np

data[data == 0] = np.nan                  # zeros become NaN so they are ignored
means = np.nanmean(data[:, 1:], axis=1)   # row means over the value columns

But I'm not sure how to do the same in pySpark.

You can use something like the following:

>>> import pyspark.sql.functions as F
>>> 
>>> df.show()
+---+----+----+----+
| id|col1|col2|col3|
+---+----+----+----+
|  1|   0|   2|   3|
|  2|   4|   2|   3|
|  3|   1|   0|   3|
|  4|   0|   0|   0|
+---+----+----+----+

>>> cols = [c for c in df.columns if c != 'id']
>>> # Row sum divided by the number of non-zero entries. Using != 0
>>> # (rather than > 0) also counts negative values; the all-zero row
>>> # divides by zero, which Spark returns as null, hence the fillna.
>>> df = df.withColumn('mean',
...     sum(df[c] for c in cols) /
...     sum(F.when(df[c] != 0, 1).otherwise(0) for c in cols)
...     ).fillna(0, 'mean')
>>> 
>>> df.show()
+---+----+----+----+----+
| id|col1|col2|col3|mean|
+---+----+----+----+----+
|  1|   0|   2|   3| 2.5|
|  2|   4|   2|   3| 3.0|
|  3|   1|   0|   3| 2.0|
|  4|   0|   0|   0| 0.0|
+---+----+----+----+----+
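
If you are on Spark 3.1 or later, the same row-wise mean can also be written with the higher-order array functions filter and aggregate. A minimal sketch, assuming the same df as above and that every non-zero value should be counted:

import pyspark.sql.functions as F

cols = [c for c in df.columns if c != 'id']
values = F.array(*[F.col(c) for c in cols])
nonzero = F.filter(values, lambda x: x != 0)   # per-row array of non-zero values

df = df.withColumn('mean',
    F.coalesce(
        F.aggregate(nonzero, F.lit(0.0), lambda acc, x: acc + x)
        / F.size(nonzero),   # an all-zero row divides by zero, giving null
        F.lit(0.0)))         # coalesce maps that null back to 0

Both versions compute the sum of non-zero values divided by the count of non-zero values; the when/otherwise version above has the advantage of also running on Spark 2.x.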
