
Pyspark - How to apply a function only to a subset of columns in a DataFrame?

I want to apply functions to some columns of a Spark DataFrame, using a different function for different columns: fn and fn1. Here is how I did that:

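from pyspark.sql.functions import udf
from pyspark.sql.types import DecimalType

# Doubles the value (the parameter name must match the one used in the body).
# DecimalType() expects decimal.Decimal results, e.g. when the input column is decimal.
def fn(x):
    return x * 2

udf_1 = udf(fn, DecimalType())

# Triples the value.
def fn1(x):
    return x * 3

udf_2 = udf(fn1, DecimalType())

# Replace col_name with udf_1 applied to it.
def process_df1(df, col_name):
    return df.withColumn(col_name, udf_1(col_name))

# Replace col_name with udf_2 applied to it.
def process_df2(df, col_name):
    return df.withColumn(col_name, udf_2(col_name))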

For a single column this works fine. But now I get a list of dicts containing information on various columns:

cols_info = [{'col_name': 'metric_1', 'process': 'True', 'method':'simple'}, {'col_name': 'metric_2', 'process': 'False', 'method':'hash'}] 

How should I parse the cols_info list and apply the above logic only to the columns that have process: 'True', using the method each entry requires?

The first thing that comes to mind is to filter out the columns with process: 'False':

list(filter(lambda col_info: col_info['process'] == 'True', cols_info))
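Because 'process' holds the strings 'True' and 'False' rather than booleans, comparing with == 'True' would silently skip a value such as 'true' or a real True. The helper below is a minimal sketch, assuming only those spellings (or real booleans) ever occur; is_enabled and columns_to_process are hypothetical names, not part of PySpark.

```python
from typing import Any, Dict, List


def is_enabled(flag: Any) -> bool:
    """Interpret a flag that may be a real bool or a 'True'/'False' string."""
    if isinstance(flag, bool):
        return flag
    if isinstance(flag, str) and flag.strip().lower() in ("true", "false"):
        return flag.strip().lower() == "true"
    # Anything else is treated as a configuration error rather than guessed at.
    raise ValueError(f"unrecognised flag value: {flag!r}")


def columns_to_process(cols_info: List[Dict[str, Any]],
                       key: str = "process") -> List[Dict[str, Any]]:
    """Keep only the entries whose flag under `key` is enabled (missing = disabled)."""
    return [info for info in cols_info if is_enabled(info.get(key, False))]
```

The `key` parameter lets the same filter work for a differently named flag, such as 'encrypt'.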

But I'm still missing a generic way to apply the right method to each remaining column.
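One generic pattern is a lookup table from 'method' to a processing function, folded over the selected columns with functools.reduce. This is a sketch, not a PySpark API: apply_methods is a hypothetical helper, and it only assumes each processor has the (df, col_name) signature that process_df1 and process_df2 already use.

```python
from functools import reduce
from typing import Any, Callable, Dict, List

# A processor takes (df, col_name) and returns a new DataFrame,
# the same shape as process_df1 / process_df2 in the question.
Processor = Callable[[Any, str], Any]


def apply_methods(df: Any, cols_info: List[Dict[str, str]],
                  processors: Dict[str, Processor]) -> Any:
    """Apply processors[method] to every column whose 'process' flag is 'True'."""
    selected = [info for info in cols_info if info["process"] == "True"]
    # Thread the DataFrame through one processor call per selected column;
    # an unknown 'method' raises KeyError instead of being silently skipped.
    return reduce(
        lambda acc, info: processors[info["method"]](acc, info["col_name"]),
        selected,
        df,
    )
```

With the question's helpers this would be called as apply_methods(df, cols_info, {'simple': process_df1, 'hash': process_df2}); which method name maps to which helper is an assumption here. Each step adds a withColumn projection, and the PySpark withColumn documentation warns that calling it many times in a loop can produce large plans, so for many columns a single select (as in the selectExpr approach) is the cheaper shape.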

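The selectExpr function is useful here. Register each Python function as a SQL function, then build one SQL expression per column. In this answer the flag is called 'encrypt' and holds real booleans, and 'method' holds the registered SQL function name: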

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Test data
tst = spark.createDataFrame([(1,2,3,4),(1,3,4,1),(1,4,5,5),(1,6,7,8),(2,1,9,2),(2,2,9,9)], schema=['col1','col2','col3','col4'])

def fn(x):
    return x * 2

def fn1(x):
    return x * 3

# Register the Python functions under the SQL names used in 'method'.
# No return type is given, so the results come back as strings.
spark.udf.register("fn1", fn)
spark.udf.register("fn2", fn1)

cols_info = [{'col_name': 'col1', 'encrypt': False},
             {'col_name': 'col2', 'encrypt': True, 'method': 'fn1'},
             {'col_name': 'col3', 'encrypt': True, 'method': 'fn2'}]
# Determine which columns get one of the functions applied
modified_columns = [x['col_name'] for x in cols_info if x['encrypt']]
# Columns kept unchanged (set difference does not preserve column order)
columns_retain = list(set(tst.columns) - set(modified_columns))
# Build e.g. "fn1(col2) as col2" for every flagged column
expr = columns_retain + [x['method'] + '(' + x['col_name'] + ') as ' + x['col_name'] for x in cols_info if x['encrypt']]
tst_res = tst.selectExpr(*expr)
tst_res.show()

The results will be:

+----+----+----+----+
|col4|col1|col2|col3|
+----+----+----+----+
|   4|   1|   4|   9|
|   1|   1|   6|  12|
|   5|   1|   8|  15|
|   8|   1|  12|  21|
|   2|   2|   2|  27|
|   9|   2|   4|  27|
+----+----+----+----+
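Because the set difference above does not preserve column order, the unchanged columns can come out in any order. A variant that walks the DataFrame's own column list keeps every column in place; it is a sketch, build_select_exprs is a hypothetical helper, and it assumes the flag is a real boolean as in this answer's cols_info. Names are wrapped in backticks so columns containing spaces or dots stay valid SQL.

```python
from typing import Any, Dict, List


def build_select_exprs(columns: List[str], cols_info: List[Dict[str, Any]],
                       flag_key: str = "encrypt") -> List[str]:
    """Return one selectExpr string per column, keeping the DataFrame's column order.

    Assumes the flag under `flag_key` is a real boolean, as in the answer's cols_info.
    """
    # Flagged column -> name of the registered SQL function to apply to it
    methods = {info["col_name"]: info["method"]
               for info in cols_info if info.get(flag_key)}
    exprs = []
    for name in columns:
        quoted = f"`{name}`"  # backticks keep names with spaces or dots valid in SQL
        if name in methods:
            exprs.append(f"{methods[name]}({quoted}) AS {quoted}")
        else:
            exprs.append(quoted)
    return exprs
```

Used as tst.selectExpr(*build_select_exprs(tst.columns, cols_info)), the result keeps the order col1, col2, col3, col4.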
