
How can I find the median of the first values of each list in a PySpark DataFrame?

values = [(u'[23,4,77,890,455]',10),(u'[11,2,50,1,11]',20),(u'[10,5,1,22,04]',30)]
df = sqlContext.createDataFrame(values,['list_1','A'])
df.show()

+-----------------+---+
|           list_1|  A|
+-----------------+---+
|[23,4,77,890,455]| 10|
|   [11,2,50,1,11]| 20|
|   [10,5,1,22,04]| 30|
+-----------------+---+

I want to convert the above Spark DataFrame so that the elements at each position of the lists in column "list_1" end up in their own column, i.e. 23, 11, 10 in the first column, 4, 2, 5 in the second column, and so on. I tried

df.select([df.list_1[i] for i in range(5)])

But since each list has around 4000 values, this approach seems time-consuming. The end goal is to find the median of each column in the resulting DataFrame.

I am using PySpark.

You can have a look at posexplode. Using your small example, I transformed the DataFrame into another DataFrame with five new columns, each holding the respective value from the array in that row.

from pyspark.sql.functions import *  # also imports Spark's max, which shadows Python's built-in max
# spark is the SparkSession (predefined in the pyspark shell)
df1 = spark.createDataFrame([([23,4,77,890,455],10),([11,2,50,1,11],20),
                             ([10,5,1,22,4],30)], ["list1","A"])
(df1.select(posexplode("list1"),"list1","A")  # explodes the array into one row per element, with its position in "pos" and its value in "col"
    .groupBy("list1","A").pivot("pos")        # groups by the initial values and pivots on "pos" to create one new column per element
    .agg(max("col")).show(truncate=False))    # collects the values (each cell holds exactly one value, so max just returns it)

Output:

+---------------------+---+---+---+---+---+---+
|list1                |A  |0  |1  |2  |3  |4  |
+---------------------+---+---+---+---+---+---+
|[10, 5, 1, 22, 4]    |30 |10 |5  |1  |22 |4  |
|[11, 2, 50, 1, 11]   |20 |11 |2  |50 |1  |11 |
|[23, 4, 77, 890, 455]|10 |23 |4  |77 |890|455|
+---------------------+---+---+---+---+---+---+
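To make the reshaping concrete, here is a plain-Python illustration (no Spark involved) of what the posexplode, groupBy and pivot chain computes for the three example rows. The helper names posexplode_rows and pivot_on_pos are hypothetical; the sketch only mirrors the Spark semantics on in-memory data.

```python
# Plain-Python illustration (not Spark) of posexplode -> groupBy -> pivot,
# using the same example rows as df1 above.
from collections import defaultdict

rows = [([23, 4, 77, 890, 455], 10), ([11, 2, 50, 1, 11], 20), ([10, 5, 1, 22, 4], 30)]

def posexplode_rows(rows):
    """Hypothetical helper: yield (pos, col, list1, A), one tuple per array element."""
    for list1, a in rows:
        for pos, col in enumerate(list1):
            yield pos, col, tuple(list1), a  # tuple so list1 can be a grouping key

def pivot_on_pos(exploded):
    """Hypothetical helper: group by (list1, A) and turn each pos into its own column."""
    wide = defaultdict(dict)
    for pos, col, list1, a in exploded:
        # each (group, pos) cell receives exactly one value, which is why
        # max("col") in the Spark version simply returns that value
        wide[(list1, a)][pos] = col
    return wide

wide = pivot_on_pos(posexplode_rows(rows))
for (list1, a), cols in wide.items():
    print(list(list1), a, [cols[p] for p in sorted(cols)])
```

With around 4000 elements per list, this pivot produces around 4000 columns, one per distinct position.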

Afterwards you can compute the median (or any other aggregate) of each of the new columns.
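If only the per-position medians are needed, the pivot can arguably be skipped: group the exploded (pos, value) pairs by position alone. Below is a plain-Python sketch of that idea on the example data (the function name median_per_position is hypothetical). In Spark, a comparable route, which is an assumption to check against your Spark version, would be to aggregate the posexplode output with groupBy("pos") and percentile_approx("col", 0.5), available in pyspark.sql.functions from Spark 3.1. That function returns an approximate median rather than the exact one computed here.

```python
# Plain-Python sketch (not Spark): median of every array position, computed by
# grouping the exploded (pos, value) pairs on pos alone instead of pivoting.
from collections import defaultdict
from statistics import median

rows = [([23, 4, 77, 890, 455], 10), ([11, 2, 50, 1, 11], 20), ([10, 5, 1, 22, 4], 30)]

def median_per_position(rows):
    """Hypothetical helper: map each array position to the exact median of its values."""
    by_pos = defaultdict(list)
    for values, _a in rows:
        for pos, value in enumerate(values):  # same (pos, col) pairs posexplode emits
            by_pos[pos].append(value)
    return {pos: median(vals) for pos, vals in sorted(by_pos.items())}

print(median_per_position(rows))  # {0: 11, 1: 4, 2: 50, 3: 22, 4: 11}
```

The result has one entry per position (around 4000 for the real data) instead of one column per position.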

If your list1 column contains strings rather than actual arrays, you need to extract the array first. You can do this with regexp_extract and split; this also works for float values in the string. Note that split returns an array of strings, so cast the values to a numeric type before computing medians.

df1 = spark.createDataFrame([(u'[23.1,4,77,890,455]',10),(u'[11,2,50,1.1,11]',20),(u'[10,5,1,22,04.1]',30)], ["list1","A"])
# extract the comma-separated body between the brackets (raw string keeps the \d escapes intact), then split it into an array of strings
df1 = df1.withColumn("list2", split(regexp_extract("list1", r"(([\d\.]+,)+[\d\.]+)", 1), ","))
df1.select(posexplode("list2"),"list1","A").groupBy("list1","A").pivot("pos").agg(max("col")).show(truncate=False)
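The string-parsing step can be checked outside Spark with Python's re module and the same pattern. This is a plain-Python sketch, not the Spark code path. The function name parse_list_string is hypothetical, and the float conversion stands in for the numeric cast (for example col("col").cast("double")) that the Spark version still needs before medians make sense.

```python
# Plain-Python sketch of the regexp_extract + split step, plus the numeric cast.
import re

PATTERN = r"(([\d\.]+,)+[\d\.]+)"  # same pattern as the regexp_extract call above

def parse_list_string(s):
    """Hypothetical helper: return the numbers in a string such as '[10,5,1,22,04.1]' as floats."""
    match = re.search(PATTERN, s)  # like regexp_extract, finds the first match anywhere in s
    if match is None:
        return []  # Spark's regexp_extract would return an empty string here instead
    # group(1) is the comma-separated body; str.split mirrors Spark's split(..., ",")
    return [float(tok) for tok in match.group(1).split(",")]

print(parse_list_string(u'[10,5,1,22,04.1]'))  # [10.0, 5.0, 1.0, 22.0, 4.1]
```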
