簡體   English   中英

pyspark sql函數而不是rdd與眾不同

[英]pyspark sql functions instead of rdd distinct

我一直在嘗試替換特定列的數據集中的字符串。 要么為1要么為0,如果為1,則為'Y';否則為0。

我已經成功地確定了要定位的列,使用帶有lambda的數據幀到rdd的轉換,但是要花一些時間來處理。

每列都要切換到rdd,然后執行一個獨特的操作,這需要一段時間!

如果不同結果集中存在“ Y”,則該列被標識為需要轉換。

我想知道是否有人可以建議我如何專門使用pyspark sql函數來獲得相同的結果,而不必為每個列進行切換?

關於示例數據的代碼如下:

    import pyspark.sql.types as typ
    import pyspark.sql.functions as func

    col_names = [
        ('ALIVE', typ.StringType()),
        ('AGE', typ.IntegerType()),
        ('CAGE', typ.IntegerType()),
        ('CNT1', typ.IntegerType()),
        ('CNT2', typ.IntegerType()),
        ('CNT3', typ.IntegerType()),
        ('HE', typ.IntegerType()),
        ('WE', typ.IntegerType()),
        ('WG', typ.IntegerType()),
        ('DBP', typ.StringType()),
        ('DBG', typ.StringType()),
        ('HT1', typ.StringType()),
        ('HT2', typ.StringType()),
        ('PREV', typ.StringType())
        ]

    schema = typ.StructType([typ.StructField(c[0], c[1], False) for c in col_names])
    df = spark.createDataFrame([('Y',22,56,4,3,65,180,198,18,'N','Y','N','N','N'),
                                ('N',38,79,3,4,63,155,167,12,'N','N','N','Y','N'),
                                ('Y',39,81,6,6,60,128,152,24,'N','N','N','N','Y')]
                               ,schema=schema)

    cols = [(col.name, col.dataType) for col in df.schema]

    transform_cols = []

    for s in cols:
      if s[1] == typ.StringType():
        distinct_result = df.select(s[0]).distinct().rdd.map(lambda row: row[0]).collect()
        if 'Y' in distinct_result:
          transform_cols.append(s[0])

    print(transform_cols)

輸出為:

['ALIVE', 'DBG', 'HT2', 'PREV']

我設法使用udf來完成任務。 首先,選擇帶有YN的列(這里我使用func.first來瀏覽第一行):

cols_sel = df.select([func.first(col).alias(col) for col in df.columns]).collect()[0].asDict()
cols = [col_name for (col_name, v) in cols_sel.items() if v in ['Y', 'N']]
# return ['HT2', 'ALIVE', 'DBP', 'HT1', 'PREV', 'DBG']

接下來,您可以創建udf以地圖功能YN10

def map_input(val):
    map_dict = dict(zip(['Y', 'N'], [1, 0]))
    return map_dict.get(val)
udf_map_input = func.udf(map_input, returnType=typ.IntegerType())

for col in cols:
    df = df.withColumn(col, udf_map_input(col))
df.show()

最后,您可以對列進行求和。 然后,我將輸出轉換為字典並檢查哪些列的值大於0(即包含Y

out = df.select([func.sum(col).alias(col) for col in cols]).collect()
out = out[0]
print([col_name for (col_name, val) in out.asDict().items() if val > 0])

輸出量

['DBG', 'HT2', 'ALIVE', 'PREV']

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM