[英]Pyspark: Get the maximum prediction value for each row in a new column from a dense vector column
[英]Get the column index where the value is maximum per each row in pyspark
我有一个 pyspark dataframe 例如:
ID | Col1 | Col2 | Col3 | ... | 科尔N |
---|---|---|---|---|---|
1 | 10 | 5 | 21 | ... | -9 |
2 | 87 | 1 | 1 | ... | 1 |
3 | 1 | 95 | 1 | ... | 1 |
如何创建一个 pyspark dataframe 列MAX
,它表示每行值为最大值的索引列,例如:
ID | Col1 | Col2 | Col3 | ... | 科尔N | 最大限度 |
---|---|---|---|---|---|---|
1 | 10 | 5 | 21 | ... | -9 | 3 |
2 | 87 | 1 | 1 | ... | 1 | 1 |
3 | 1 | 95 | 1 | ... | 1 | 2 |
在每行中创建一个具有最大值的列
列出可以找到最大值的列
消除列表中的 NaN
下面的代码
import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql.functions import*
w=Window.partitionBy('ID').orderBy().rowsBetween(Window.unboundedPreceding,0)
df=(df.withColumn(
"max",
F.greatest(*[F.col(x) for x in df.columns[1:]])#Find the max in each row
)
.withColumn(
'maxcol', array(*[when(col(c) ==col('max'), lit(c)) for c in df.columns])#Find intersection of max with all other columns
).withColumn(
'maxcol', expr("filter(maxcol, x -> x is not null)")#Filter ou the nans in the intersection
).show())
+---+----+----+----+----+---+------+
| ID|Col1|Col2|Col3|ColN|max|maxcol|
+---+----+----+----+----+---+------+
| 1| 10| 5| 21| -9| 21|[Col3]|
| 2| 87| 1| 1| 1| 87|[Col1]|
| 3| 1| 95| 1| 1| 95|[Col2]|
+---+----+----+----+----+---+------+
你也可以使用 pandas_udf 虽然我不确定 pyspark.sql.functions import pandas_udf 的功效
import pandas as pd
from pyspark.sql.types import *
def max_col(a:pd.DataFrame) -> pd.DataFrame:
s=a.isin(a.iloc[:,1:].max(1))
return a.assign(maxcol=s.agg(lambda x: x.index[x].values, axis=1))
schema=StructType([\
StructField('ID',LongType(),True),\
StructField('Col1',LongType(),True),\
StructField('Col2',LongType(),True),\
StructField('Col3',LongType(),True),\
StructField('ColN',LongType(),True),\
StructField('maxcol',ArrayType(StringType(),True),False)\
])
df.groupby('ID').applyInPandas(max_col, schema).show()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.