[英]How to write a user-defined aggregate function?
我正在尝试了解Java Spark文档。 有一个名为“ 无类型用户定义的聚合函数”的部分 ,其中包含一些我无法理解的示例代码。 这是代码:
package org.apache.spark.examples.sql;
// $example on:untyped_custom_aggregation$
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off:untyped_custom_aggregation$
public class JavaUserDefinedUntypedAggregation {
// $example on:untyped_custom_aggregation$
public static class MyAverage extends UserDefinedAggregateFunction {
private StructType inputSchema;
private StructType bufferSchema;
public MyAverage() {
List<StructField> inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
inputSchema = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
// Data types of input arguments of this aggregate function
public StructType inputSchema() {
return inputSchema;
}
// Data types of values in the aggregation buffer
public StructType bufferSchema() {
return bufferSchema;
}
// The data type of the returned value
public DataType dataType() {
return DataTypes.DoubleType;
}
// Whether this function always returns the same output on the identical input
public boolean deterministic() {
return true;
}
// Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
// standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
// the opportunity to update its values. Note that arrays and maps inside the buffer are still
// immutable.
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0L);
buffer.update(1, 0L);
}
// Updates the given aggregation buffer `buffer` with new input data from `input`
public void update(MutableAggregationBuffer buffer, Row input) {
if (!input.isNullAt(0)) {
long updatedSum = buffer.getLong(0) + input.getLong(0);
long updatedCount = buffer.getLong(1) + 1;
buffer.update(0, updatedSum);
buffer.update(1, updatedCount);
}
}
// Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
buffer1.update(0, mergedSum);
buffer1.update(1, mergedCount);
}
// Calculates the final result
public Double evaluate(Row buffer) {
return ((double) buffer.getLong(0)) / buffer.getLong(1);
}
}
// $example off:untyped_custom_aggregation$
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("Java Spark SQL user-defined DataFrames aggregation example")
.getOrCreate();
// $example on:untyped_custom_aggregation$
// Register the function to access it
spark.udf().register("myAverage", new MyAverage());
Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
df.createOrReplaceTempView("employees");
df.show();
// +-------+------+
// | name|salary|
// +-------+------+
// |Michael| 3000|
// | Andy| 4500|
// | Justin| 3500|
// | Berta| 4000|
// +-------+------+
Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
result.show();
// +--------------+
// |average_salary|
// +--------------+
// | 3750.0|
// +--------------+
// $example off:untyped_custom_aggregation$
spark.stop();
}
}
我对上述代码的怀疑是:
initialize
, update
和merge
? inputSchema
和bufferSchema
? 我很惊讶它们的存在,因为它们从未被用来创建任何DataFrame。 它们应该存在于每个UDF中吗? 如果是,那么它们应该是完全相同的名称吗? inputSchema
和bufferSchema
未命名getInputSchema()
和getBufferSchema()
为什么这些变量没有设置器? deterministic()
的函数的意义是什么? 请给出一个场景,在该场景下调用此函数会很有用。 总的来说,我想知道如何在Spark中编写用户定义的聚合函数。
每当我要创建UDF时,都应该具有初始化,更新和合并功能
UDF代表用户定义的函数,而initialize
, update
和merge
则代表用户定义的聚合函数 (又名UDAF )。
UDF是一种用于单行(通常)产生一行的函数(例如, upper
函数)。
UDAF是用于零行或多行以产生一行的函数(例如, count
聚合函数)。
您当然不必(也将不能)使用用户定义的函数(UDF)的initialize
, update
和merge
功能。
使用任何udf
函数来定义和注册UDF。
val myUpper = udf { (s: String) => s.toUpperCase }
如何在Spark中编写用户定义的聚合函数。
什么是变量的意义
inputSchema
和bufferSchema
?
( 无耻的插件 :我在UserDefinedAggregateFunction的 Mastering Spark SQL书籍中描述了UDAF —用户定义的无类型聚合函数(UDAF)的合同 )
引用无类型的用户定义的聚合函数 :
// Data types of input arguments of this aggregate function def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil) // Data types of values in the aggregation buffer def bufferSchema: StructType = { StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil) }
换句话说, inputSchema
是你输入什么期望,而bufferSchema
是你保持暂时什么,而这样做的聚集。
为什么这些变量没有设置器?
它们是由Spark管理的扩展点。
这里称为
deterministic()
的函数的意义是什么?
引用无类型的用户定义的聚合函数 :
// Whether this function always returns the same output on the identical input def deterministic: Boolean = true
请给出一个场景,在该场景下调用此函数会很有用。
这是我仍在努力的工作,因此今天无法回答。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.