繁体   English   中英

如何编写用户定义的聚合函数?

[英]How to write a user-defined aggregate function?

我正在尝试了解Java Spark文档。 有一个名为“ 无类型用户定义的聚合函数”的部分 ,其中包含一些我无法理解的示例代码。 这是代码:

package org.apache.spark.examples.sql;

// $example on:untyped_custom_aggregation$
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off:untyped_custom_aggregation$

public class JavaUserDefinedUntypedAggregation {

  // $example on:untyped_custom_aggregation$
  public static class MyAverage extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MyAverage() {
      List<StructField> inputFields = new ArrayList<>();
      inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
      inputSchema = DataTypes.createStructType(inputFields);

      List<StructField> bufferFields = new ArrayList<>();
      bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
      bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
      bufferSchema = DataTypes.createStructType(bufferFields);
    }
    // Data types of input arguments of this aggregate function
    public StructType inputSchema() {
      return inputSchema;
    }
    // Data types of values in the aggregation buffer
    public StructType bufferSchema() {
      return bufferSchema;
    }
    // The data type of the returned value
    public DataType dataType() {
      return DataTypes.DoubleType;
    }
    // Whether this function always returns the same output on the identical input
    public boolean deterministic() {
      return true;
    }
    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
    // immutable.
    public void initialize(MutableAggregationBuffer buffer) {
      buffer.update(0, 0L);
      buffer.update(1, 0L);
    }
    // Updates the given aggregation buffer `buffer` with new input data from `input`
    public void update(MutableAggregationBuffer buffer, Row input) {
      if (!input.isNullAt(0)) {
        long updatedSum = buffer.getLong(0) + input.getLong(0);
        long updatedCount = buffer.getLong(1) + 1;
        buffer.update(0, updatedSum);
        buffer.update(1, updatedCount);
      }
    }
    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
      long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
      long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
      buffer1.update(0, mergedSum);
      buffer1.update(1, mergedCount);
    }
    // Calculates the final result
    public Double evaluate(Row buffer) {
      return ((double) buffer.getLong(0)) / buffer.getLong(1);
    }
  }
  // $example off:untyped_custom_aggregation$

  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("Java Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate();

    // $example on:untyped_custom_aggregation$
    // Register the function to access it
    spark.udf().register("myAverage", new MyAverage());

    Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
    df.createOrReplaceTempView("employees");
    df.show();
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
    result.show();
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    // $example off:untyped_custom_aggregation$

    spark.stop();
  }
}

我对上述代码的怀疑是:

  • 每当我要创建UDF时,是否都应该具有initializeupdatemerge
  • 什么是变量的意义inputSchemabufferSchema 我很惊讶它们的存在,因为它们从未被用来创建任何DataFrame。 它们应该存在于每个UDF中吗? 如果是,那么它们应该是完全相同的名称吗?
  • 为什么的干将inputSchemabufferSchema未命名getInputSchema()getBufferSchema() 为什么这些变量没有设置器?
  • 这里称为deterministic()的函数的意义是什么? 请给出一个场景,在该场景下调用此函数会很有用。

总的来说,我想知道如何在Spark中编写用户定义的聚合函数。

每当我要创建UDF时,都应该具有初始化,更新和合并功能

UDF代表用户定义的函数,initializeupdatemerge则代表用户定义的聚合函数 (又名UDAF )。

UDF是一种用于单行(通常)产生一行的函数(例如, upper函数)。

UDAF是用于零行或多行以产生一行的函数(例如, count聚合函数)。

您当然不必(也将不能)使用用户定义的函数(UDF)的initializeupdatemerge功能。

使用任何udf 函数来定义和注册UDF。

val myUpper = udf { (s: String) => s.toUpperCase }

如何在Spark中编写用户定义的聚合函数。

什么是变量的意义inputSchemabufferSchema

无耻的插件 :我在UserDefinedAggregateFunction的 Mastering Spark SQL书籍中描述了UDAF —用户定义的无类型聚合函数(UDAF)的合同

引用无类型的用户定义的聚合函数

 // Data types of input arguments of this aggregate function def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil) // Data types of values in the aggregation buffer def bufferSchema: StructType = { StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil) } 

换句话说, inputSchema是你输入什么期望,而bufferSchema是你保持暂时什么,而这样做的聚集。

为什么这些变量没有设置器?

它们是由Spark管理的扩展点。

这里称为deterministic()的函数的意义是什么?

引用无类型的用户定义的聚合函数

 // Whether this function always returns the same output on the identical input def deterministic: Boolean = true 

请给出一个场景,在该场景下调用此函数会很有用。

这是我仍在努力的工作,因此今天无法回答。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM