
Adding multiple, arbitrary Spark columns per row using Java

I figured out how to get a Spark UDF to return a fixed number of columns known ahead of time. But how can a Spark UDF return an arbitrary number of columns which might be different for each row?

I'm using Spark 3.3.0 with Java 17. Let's say I have a DataFrame containing 1,000,000 people. For each person I want to look up each year's salary (e.g. from a database), but each person might have salaries available for different years. If I knew that only years 2020 and 2021 were available, I would do this:

StructType salarySchema = createStructType(List.of(createStructField("salary2020",
    createDecimalType(12, 2), true), createStructField("salary2021",
    createDecimalType(12, 2), true)));
UserDefinedFunction lookupSalariesForId = udf((String id) -> {
  BigDecimal salary2020 = null; // TODO look up this person's salaries, e.g. from a database
  BigDecimal salary2021 = null;
  return RowFactory.create(salary2020, salary2021);
}, salarySchema).asNondeterministic();

df = df.withColumn("salaries", lookupSalariesForId.apply(col("id")))
    .select("*", "salaries.*");

That is Spark's roundabout way of loading multiple values from a UDF into a single column and then splitting them out into separate columns.

So what if one person only has salaries from 2003 and 2004, while another person has salaries from 2007, 2008, and 2009? I would want to create columns salary2003 and salary2004 for the first person, and then salary2007, salary2008, and salary2009 for the second person. How would I do that with a UDF? (I know how to dynamically create an array to pass back via RowFactory.create(). The problem is that the schema associated with the UDF is defined outside the UDF logic.)

Or is there some better approach altogether with Spark? Should I be creating a separate lookup DataFrame of just person IDs and a column for each possible salary year, and then join them somehow, like we would do in the relational database world? But what benefit would a separate DataFrame give me, and wouldn't I be back to square one to construct it? Of course I could construct it manually in Java, but I wouldn't gain the benefit of the Spark engine, parallel executors, etc.
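
For concreteness, here is a rough, untested sketch of what I imagine that lookup-and-join approach would look like. The salaries sample data and the longSchema / salariesByYear names are made up; it assumes the same static imports from DataTypes and functions as above, plus a SparkSession named spark:

// hypothetical long-format lookup data: one row per (id, year, salary)
StructType longSchema = createStructType(List.of(
    createStructField("id", StringType, false),
    createStructField("year", IntegerType, false),
    createStructField("salary", createDecimalType(12, 2), true)));
Dataset<Row> salaries = spark.createDataFrame(List.of(
    RowFactory.create("1", 2003, new BigDecimal("50000.00")),
    RowFactory.create("1", 2004, new BigDecimal("52000.00")),
    RowFactory.create("2", 2007, new BigDecimal("61000.00"))), longSchema);

// reshape to one column per distinct year, named salary<year> as in the fixed-schema example above
Dataset<Row> salariesByYear = salaries
    .withColumn("year", concat(lit("salary"), col("year").cast("string")))
    .groupBy("id")
    .pivot("year")
    .agg(first("salary"));

// join back onto the people DataFrame by id
df = df.join(salariesByYear, "id"); // inner join; a left join would keep people with no salary rows

That would sidestep the dynamic UDF schema problem, at the cost of building and joining a second DataFrame.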

In short, what is the best way in Spark to dynamically add an arbitrary number of columns for each row in an existing DataFrame, based upon each row's identifier?

You can return a map from the UDF:

MapType salarySchema2 = createMapType(StringType, createDecimalType(12, 2));
UserDefinedFunction lookupSalariesForId2 = udf((String id) -> {
    // the Java map containing the result of the UDF
    Map<String, BigDecimal> result = new HashMap<String, BigDecimal>();
    //Generate some random test data
    Random r = new Random();
    int years = r.nextInt(10) + 1; // max 10 years
    int startYear = r.nextInt(5) + 2010;
    for (int i = 0; i < years; i++) {
        result.put("salary" + (startYear + i), new BigDecimal(r.nextDouble() * 1000));
    }
    return result;
}, salarySchema2).asNondeterministic();

df = df.withColumn("salaries2", lookupSalariesForId2.apply(col("id"))).cache();
df.show(false);

Output:

+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+
|id |salaries2                                                                                                                                                |
+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+
|1  |{salary2014 -> 333.74}                                                                                                                                   |
|2  |{salary2010 -> 841.83, salary2011 -> 764.24, salary2012 -> 703.35, salary2013 -> 727.06, salary2014 -> 314.52}                                           |
|3  |{salary2012 -> 770.90, salary2013 -> 790.92}                                                                                                             |
|4  |{salary2011 -> 696.24, salary2012 -> 420.56, salary2013 -> 566.10, salary2014 -> 160.99}                                                                 |
|5  |{salary2011 -> 60.59, salary2012 -> 313.57, salary2013 -> 770.82, salary2014 -> 641.90, salary2015 -> 776.13, salary2016 -> 145.28, salary2017 -> 216.02}|
|6  |{salary2011 -> 842.02, salary2012 -> 565.32}                                                                                                             |
+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+

The cache in the second-to-last line matters for the second part: using a few SQL functions, it is possible to get a (Java) collection of all keys that appear in the map column. This collection can then be used to create a separate column for each year:

// collect the map keys from every row, merge them into a single sorted, distinct array,
// and convert the resulting Scala WrappedArray into a Java collection
Collection<String> years = JavaConverters.asJavaCollection((WrappedArray<String>)
        df.withColumn("years", functions.map_keys(col("salaries2")))
                .agg(functions.array_sort(
                        functions.array_distinct(
                                functions.flatten(
                                        functions.collect_set(col("years"))))))
                .first().get(0));

// one column per collected year, each looked up from the map by its key, plus the id column in front
List<Column> salaries2 = years.stream().map((year) ->
    col("salaries2").getItem(year).alias(year)).collect(Collectors.toList());
salaries2.add(0, col("id"));

df.select(salaries2.toArray(new Column[0])).show();

Output:

+---+----------+----------+----------+----------+----------+----------+----------+----------+
| id|salary2010|salary2011|salary2012|salary2013|salary2014|salary2015|salary2016|salary2017|
+---+----------+----------+----------+----------+----------+----------+----------+----------+
|  1|      null|      null|      null|      null|    333.74|      null|      null|      null|
|  2|    841.83|    764.24|    703.35|    727.06|    314.52|      null|      null|      null|
|  3|      null|      null|    770.90|    790.92|      null|      null|      null|      null|
|  4|      null|    696.24|    420.56|    566.10|    160.99|      null|      null|      null|
|  5|      null|     60.59|    313.57|    770.82|    641.90|    776.13|    145.28|    216.02|
|  6|      null|    842.02|    565.32|      null|      null|      null|      null|      null|
+---+----------+----------+----------+----------+----------+----------+----------+----------+

Collecting all years from all maps can take some time on large datasets, as Spark has to process all UDF calls first and only then collect the map keys.
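
If collecting the keys on the driver ever becomes a concern, a possible alternative (just a sketch, not tested with the data above) is to explode the map column into key/value rows and pivot them back into columns. Note that pivot without an explicit value list also has to determine the distinct years internally, so passing a known year list via pivot("key", values) avoids that extra pass:

df.select(col("id"), functions.explode(col("salaries2"))) // explode yields columns: id, key, value
    .groupBy("id")
    .pivot("key")                // or pivot("key", knownYears) if the year list is known up front
    .agg(functions.first("value"))
    .show();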
