
Spark Scala: join an aggregated table multiple times

I have two dataframes: one contains people's features, the other contains the statistics. The two tables look like this:

df_base:

user_id gender  platform
id1     1       Android
id2     2       Android
id3     1       Android
id4     1       iOS
id5     2       Android

df_time_series:

time_prefix gender  platform    gender_lt7  platform_lt7
m1          1       Android     22486       48185
m1          2       Android     15791       48185
m1          0       Android     18940       48185
m2          0       Android     16401       40852
m2          0       Android     16401       40852
m2          0       iOS         16401       8475
m3          0       Android     15507       39472
m3          1       Android     19205       39472
m3          2       Android     12999       39472

The first dataframe left-joins the second one six times in a double for-loop. Here is the code, written in Python:

import pandas as pd

def process_time_feature(df_base, df_time_series):
    for time_prefix in ['m1', 'm2', 'm3']:
        # keep only the rows for this time prefix
        time_df = df_time_series[df_time_series['time_prefix'] == time_prefix]
        for key in ['gender', 'platform']:
            # mean of <key>_lt7 per value of <key>, for this time prefix
            df_key_agg = time_df[[key, key + '_lt7']].groupby(key).agg({key + '_lt7': 'mean'}).reset_index()
            df_key_agg = df_key_agg.rename(columns={key + '_lt7': key + '_lt7_' + time_prefix})
            # left join the aggregate back onto the base table
            df_base = pd.merge(df_base, df_key_agg, on=key, how='left')
    return df_base

The result looks like this, with six new columns added:

[image: the resulting dataframe with the six added columns]

Could someone help me translate this code into Spark Scala in an elegant and efficient way? My Scala code mirrors the Python version above, and it keeps throwing an "Out of memory" exception once there are more keys and more data.

Here is my Scala code. df_base may have 5,000,000 rows with 100 features, and there are 14 keys, so df_base joins df_time_series 14 * 3 = 42 times.

val agg_cols = List("gender", "platform")
val df_time_m1 = df_time_series.filter(col("time_prefix") === "m1")
val df_time_m2 = df_time_series.filter(col("time_prefix") === "m2")
val df_time_m3 = df_time_series.filter(col("time_prefix") === "m3")
val time_prefixs = List("m1", "m2", "m3")
time_prefixs.foreach(println)

// For each key, aggregate each time slice of df_time_series and
// left join the result back onto df_base (2 keys * 3 prefixes = 6 joins).
val result = agg_cols.foldLeft(df_base)((df_base, key) => {
  df_base
    .alias("p")
    .join(
      df_time_m1.groupBy(key).agg((key + "_lt7", "mean")).withColumnRenamed("avg(" + key + "_lt7)", key + "_lt7_" + "m1").alias("c"),
      col("p." + key) === col("c." + key), "left_outer"
    ).drop(col("c." + key))
    .join(
      df_time_m2.groupBy(key).agg((key + "_lt7", "mean")).withColumnRenamed("avg(" + key + "_lt7)", key + "_lt7_" + "m2").alias("c"),
      col("p." + key) === col("c." + key), "left_outer"
    ).drop(col("c." + key))
    .join(
      df_time_m3.groupBy(key).agg((key + "_lt7", "mean")).withColumnRenamed("avg(" + key + "_lt7)", key + "_lt7_" + "m3").alias("c"),
      col("p." + key) === col("c." + key), "left_outer"
    ).drop(col("c." + key))
})
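
For reference, a generalised version of that double loop over all keys would look something like the sketch below (hypothetical: it assumes every key `k` has a matching statistics column named `k + "_lt7"`), which is where the 42 left joins come from:

import org.apache.spark.sql.functions.{avg, col}

// Sketch of the full double loop: one aggregation + left join per (key, time prefix),
// i.e. 14 keys * 3 prefixes = 42 joins in total.
val keys = List("gender", "platform") // extended to all 14 keys
val timePrefixes = List("m1", "m2", "m3")

val result = keys.foldLeft(df_base) { (acc, key) =>
  timePrefixes.foldLeft(acc) { (inner, prefix) =>
    val agg = df_time_series
      .filter(col("time_prefix") === prefix)
      .groupBy(key)
      .agg(avg(key + "_lt7").alias(key + "_lt7_" + prefix))
    // joining on the key's column name avoids having to drop a duplicate column
    inner.join(agg, Seq(key), "left")
  }
}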

You can use pivot on a grouped dataset to create columns from rows.

Given the setup:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import spark.implicits._

val base = Seq(
  ("id1", 1, "Android"),
  ("id2", 2, "Android"),
  ("id3", 1, "Android"),
  ("id4", 1, "iOS"),
  ("id5", 2, "Android")
).toDF("user_id", "gender", "platform")

val timeSeries = Seq(
  ("m1", 1, "Android", 22486, 48185),
  ("m1", 2, "Android", 15791, 48185),
  ("m1", 0, "Android", 18940, 48185),
  ("m2", 0, "Android", 16401, 40852),
  ("m2", 0, "Android", 16401, 40852),
  ("m2", 0, "iOS", 16401, 8475),
  ("m3", 0, "Android", 15507, 39472),
  ("m3", 1, "Android", 19205, 39472),
  ("m3", 2, "Android", 12999, 39472)
).toDF("time_prefix", "gender", "platform", "gender_lt7", "platform_lt7")

Then an implementation using pivot could be:

def createPivot(keyColumn: String, valueColumn: String) = {
  val df = timeSeries.groupBy(col(keyColumn))
      .pivot(col("time_prefix"))
      .agg(first(valueColumn))
  df.select(
    df.columns
      .map(name => if (name == keyColumn) col(name) else col(name).alias(s"${valueColumn}_$name")): _*
  )  
}

val genders = createPivot("gender", "gender_lt7")
val platforms = createPivot("platform", "platform_lt7")
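
// For the sample data, `genders` comes out roughly like this
// (one row per gender value, one pivoted column per time_prefix):
// +------+-------------+-------------+-------------+
// |gender|gender_lt7_m1|gender_lt7_m2|gender_lt7_m3|
// +------+-------------+-------------+-------------+
// |     1|        22486|         null|        19205|
// |     2|        15791|         null|        12999|
// |     0|        18940|        16401|        15507|
// +------+-------------+-------------+-------------+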

base
  .join(genders, Seq("gender"), "left")      // left joins, like the pandas merge(how='left')
  .join(platforms, Seq("platform"), "left")

Which yields the result:

// +--------+------+-------+-------------+-------------+-------------+---------------+---------------+---------------+
// |platform|gender|user_id|gender_lt7_m1|gender_lt7_m2|gender_lt7_m3|platform_lt7_m1|platform_lt7_m2|platform_lt7_m3|
// +--------+------+-------+-------------+-------------+-------------+---------------+---------------+---------------+
// | Android|     1|    id1|        22486|         null|        19205|          48185|          40852|          39472|
// | Android|     2|    id2|        15791|         null|        12999|          48185|          40852|          39472|
// | Android|     1|    id3|        22486|         null|        19205|          48185|          40852|          39472|
// |     iOS|     1|    id4|        22486|         null|        19205|           null|           8475|           null|
// | Android|     2|    id5|        15791|         null|        12999|          48185|          40852|          39472|
// +--------+------+-------+-------------+-------------+-------------+---------------+---------------+---------------+
// 
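
To handle the real dataset with 14 keys, the two hard-coded joins can be folded over the full key list instead. This is only a sketch, assuming every key `k` has a corresponding `k + "_lt7"` column in df_time_series:

// Sketch: one pivoted lookup table per key, left joined onto the base table.
val keys = Seq("gender", "platform") // extend to all 14 keys

val wide = keys.foldLeft(base) { (acc, key) =>
  acc.join(createPivot(key, key + "_lt7"), Seq(key), "left")
}

Each pivoted table has only one row per distinct key value, so these 14 joins against small lookup tables should be far cheaper than repeatedly joining and aggregating the raw time series.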
