
Adding multiple, arbitrary Spark columns per row using Java

I figured out how to get a Spark UDF to return a fixed number of columns known ahead of time. But how can a Spark UDF return an arbitrary number of columns which might be different for each row?

I'm using Spark 3.3.0 with Java 17. Let's say I have a DataFrame containing 1,000,000 people. For each person I want to look up each year's salary (e.g. from a database), but each person might have salaries available for different years. If I knew that only years 2020 and 2021 were available, I would do this:

// Static imports from org.apache.spark.sql.types.DataTypes and
// org.apache.spark.sql.functions are assumed throughout.
StructType salarySchema = createStructType(List.of(
    createStructField("salary2020", createDecimalType(12, 2), true),
    createStructField("salary2021", createDecimalType(12, 2), true)));
UserDefinedFunction lookupSalariesForId = udf((String id) -> {
  BigDecimal salary2020 = null; // TODO look up salaries for this id
  BigDecimal salary2021 = null;
  return RowFactory.create(salary2020, salary2021);
}, salarySchema).asNondeterministic();

df = df.withColumn("salaries", lookupSalariesForId.apply(col("id")))
    .select("*", "salaries.*");

That is Spark's roundabout way of loading multiple values from a UDF into a single column and then splitting them out into separate columns.

So what if one person only has salaries from 2003 and 2004, while another person has salaries from 2007, 2008, and 2009? I would want to create columns salary2003 and salary2004 for the first person, and salary2007, salary2008, and salary2009 for the second. How would I do that with a UDF? (I know how to dynamically build an array of values to pass back via RowFactory.create(). The problem is that the UDF's return schema is defined outside the UDF logic, so it cannot vary from row to row.)

Or is there some better approach altogether with Spark? Should I be creating a separate lookup DataFrame of just person IDs and a column for each possible salary year, and then joining the two somehow, as we would do in the relational database world? But what benefit would a separate DataFrame give me, and wouldn't I be back to square one constructing it? Of course I could construct it manually in Java, but then I wouldn't gain the benefit of the Spark engine, parallel executors, and so on.
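For concreteness, here is a minimal sketch of what I imagine that relational-style alternative would look like (assuming an existing SparkSession spark and the same static imports as above; the lookup table and its contents are purely illustrative):

// Hypothetical lookup table: one row per person and year.
StructType lookupSchema = createStructType(List.of(
    createStructField("id", StringType, false),
    createStructField("year", IntegerType, false),
    createStructField("salary", createDecimalType(12, 2), true)));
Dataset<Row> salaries = spark.createDataFrame(List.of(
    RowFactory.create("1", 2020, new BigDecimal("50000.00")),
    RowFactory.create("1", 2021, new BigDecimal("52000.00")),
    RowFactory.create("2", 2019, new BigDecimal("48000.00"))), lookupSchema);

// Pivot the years into columns (one column per distinct year), then join back.
Dataset<Row> salariesWide = salaries.groupBy("id")
    .pivot("year")
    .agg(first("salary"));
df = df.join(salariesWide, "id");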

In short, what is the best way in Spark to dynamically add an arbitrary number of columns to each row of an existing DataFrame, based upon each row's identifier?

You can return a map from the UDF:

// A map column allows a different set of keys (years) for each row.
MapType salarySchema2 = createMapType(StringType, createDecimalType(12, 2));
UserDefinedFunction lookupSalariesForId2 = udf((String id) -> {
    // the Java map containing the result of the UDF
    Map<String, BigDecimal> result = new HashMap<>();
    // generate some random test data
    Random r = new Random();
    int years = r.nextInt(10) + 1; // at most 10 years
    int startYear = r.nextInt(5) + 2010;
    for (int i = 0; i < years; i++) {
        result.put("salary" + (startYear + i), BigDecimal.valueOf(r.nextDouble() * 1000));
    }
    return result;
}, salarySchema2).asNondeterministic();

df = df.withColumn("salaries2", lookupSalariesForId2.apply(col("id"))).cache();
df.show(false);

Output:

+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+
|id |salaries2                                                                                                                                                |
+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+
|1  |{salary2014 -> 333.74}                                                                                                                                   |
|2  |{salary2010 -> 841.83, salary2011 -> 764.24, salary2012 -> 703.35, salary2013 -> 727.06, salary2014 -> 314.52}                                           |
|3  |{salary2012 -> 770.90, salary2013 -> 790.92}                                                                                                             |
|4  |{salary2011 -> 696.24, salary2012 -> 420.56, salary2013 -> 566.10, salary2014 -> 160.99}                                                                 |
|5  |{salary2011 -> 60.59, salary2012 -> 313.57, salary2013 -> 770.82, salary2014 -> 641.90, salary2015 -> 776.13, salary2016 -> 145.28, salary2017 -> 216.02}|
|6  |{salary2011 -> 842.02, salary2012 -> 565.32}                                                                                                             |
+---+---------------------------------------------------------------------------------------------------------------------------------------------------------+

The cache() in the second-to-last line matters for the second part below: the UDF is nondeterministic, so without caching, re-evaluating the DataFrame would produce a fresh set of random salaries. Using some SQL functions, it is possible to get a (Java) collection of all keys that appear in the map column. This collection can then be used to create a single column for each year:

// Collect every distinct map key across all rows into one sorted array on
// the driver. first().get(0) returns a Scala WrappedArray, which is
// converted to a Java collection here.
Collection<String> years = JavaConverters.asJavaCollection((WrappedArray<String>)
        df.withColumn("years", functions.map_keys(col("salaries2")))
                .agg(functions.array_sort(
                        functions.array_distinct(
                                functions.flatten(
                                        functions.collect_set(col("years"))))))
                .first().get(0));

// Build one column per year by indexing into the map column, then prepend id.
// toCollection(ArrayList::new) guarantees a mutable list for the add() below.
List<Column> salaries2 = years.stream()
    .map(year -> col("salaries2").getItem(year).alias(year))
    .collect(Collectors.toCollection(ArrayList::new));
salaries2.add(0, col("id"));

df.select(salaries2.toArray(new Column[0])).show();

Output:

+---+----------+----------+----------+----------+----------+----------+----------+----------+
| id|salary2010|salary2011|salary2012|salary2013|salary2014|salary2015|salary2016|salary2017|
+---+----------+----------+----------+----------+----------+----------+----------+----------+
|  1|      null|      null|      null|      null|    333.74|      null|      null|      null|
|  2|    841.83|    764.24|    703.35|    727.06|    314.52|      null|      null|      null|
|  3|      null|      null|    770.90|    790.92|      null|      null|      null|      null|
|  4|      null|    696.24|    420.56|    566.10|    160.99|      null|      null|      null|
|  5|      null|     60.59|    313.57|    770.82|    641.90|    776.13|    145.28|    216.02|
|  6|      null|    842.02|    565.32|      null|      null|      null|      null|      null|
+---+----------+----------+----------+----------+----------+----------+----------+----------+

Collecting all years in all maps could take some time on large datasets as Spark has to process all UDF calls first and then collect the map keys.
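As a variation on the same idea (a sketch, not part of the original answer): the map column can also be exploded into key/value rows and pivoted, letting Spark discover the distinct years itself rather than collecting them by hand:

// explode() on a map column yields one row per entry, in columns "key" and "value".
df.select(col("id"), functions.explode(col("salaries2")))
  .groupBy("id")
  .pivot("key")                  // one column per distinct map key (year)
  .agg(functions.first("value"))
  .show();

This still requires a pass over the data to determine the pivot values, so it is not necessarily faster, but it avoids the manual driver-side key collection.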
