I have a Spark data frame with more than 1,200 columns and need to take the average of sets of columns, by row, excluding zero values.
For example, given the following data frame:
id | col1 | col2 | col3
1 | 0 | 2 | 3
2 | 4 | 2 | 3
3 | 1 | 0 | 3
4 | 0 | 0 | 0
I'd expect:
id | mean
1 | 2.5
2 | 3
3 | 2
4 | 0
In Python I know it's possible to achieve something like this with a strategy similar to:
import numpy as np

data[data == 0] = np.nan                  # treat zeros as missing
means = np.nanmean(data[:, 1:], axis=1)   # row-wise mean, ignoring NaN
But I'm not sure how to do the same in PySpark.
You can use something like the below:
>>> import pyspark.sql.functions as F
>>>
>>> df.show()
+---+----+----+----+
| id|col1|col2|col3|
+---+----+----+----+
| 1| 0| 2| 3|
| 2| 4| 2| 3|
| 3| 1| 0| 3|
| 4| 0| 0| 0|
+---+----+----+----+
>>> cols = [i for i in df.columns if i != 'id']
>>> df = df.withColumn('mean',\
... sum([df[i] for i in cols])/ \
... sum([F.when(df[i]>0,1).otherwise(0) for i in cols])). \
... fillna(0,'mean')
>>>
>>> df.show()
+---+----+----+----+----+
| id|col1|col2|col3|mean|
+---+----+----+----+----+
| 1| 0| 2| 3| 2.5|
| 2| 4| 2| 3| 3.0|
| 3| 1| 0| 3| 2.0|
| 4| 0| 0| 0| 0.0|
+---+----+----+----+----+
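If you only need the id and mean columns, as in the expected output, the same expression can be used with select instead of withColumn. Here is a minimal, self-contained sketch of that variant; the SparkSession setup and the sample data are assumptions added for illustration, and the condition != 0 is used so that negative values still count toward the average (the > 0 condition above would skip them).

import pyspark.sql.functions as F
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sample data mirroring the question (assumption: all value columns are numeric)
df = spark.createDataFrame(
    [(1, 0, 2, 3), (2, 4, 2, 3), (3, 1, 0, 3), (4, 0, 0, 0)],
    ['id', 'col1', 'col2', 'col3'])

cols = [c for c in df.columns if c != 'id']

# Sum of the values divided by the count of non-zero values;
# the division yields null when every value is zero, so fill with 0 afterwards.
non_zero_mean = (sum(F.col(c) for c in cols) /
                 sum(F.when(F.col(c) != 0, 1).otherwise(0) for c in cols))

df.select('id', non_zero_mean.alias('mean')).fillna(0, subset=['mean']).show()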