How to define and use a User-Defined Aggregate Function in Spark SQL?
I know how to write a UDF in Spark SQL:
def belowThreshold(power: Int): Boolean = {
  power < -40
}
sqlContext.udf.register("belowThreshold", belowThreshold _)
Can I do something similar to define an aggregate function? How is this done?
For context, I want to run the following SQL query:
val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT null
                              GROUP BY span, timestamp
                              ORDER BY span""")
It should return something like
Row(span1, false, T0)
I want the aggregate function to tell me whether there are any values of opticalReceivePower in the groups defined by span and timestamp which are below the threshold. Do I need to write my UDAF differently to the UDF I pasted above?
Spark >= 3.0
Scala UserDefinedAggregateFunction is being deprecated (SPARK-30423 Deprecate UserDefinedAggregateFunction) in favor of registered Aggregator.
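In practice this means a typed Aggregator (such as the BelowThreshold class shown in the Spark >= 2.0 section below) can be wrapped with functions.udaf and registered for use from SQL, which covers the query in the question. A minimal sketch, assuming a SparkSession named spark and the example DataFrame df used later in this answer:

import org.apache.spark.sql.functions.udaf

// Wrap the typed Aggregator so it can be registered and called from SQL
spark.udf.register("below_threshold", udaf(new BelowThreshold[Int](_ < -40)))

df.createOrReplaceTempView("df")
spark.sql("SELECT `group`, below_threshold(power) FROM df GROUP BY `group`").show()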
Spark >= 2.3
Vectorized udf (Python only):
from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType
from pyspark.sql.types import *
import pandas as pd

df = sc.parallelize([
    ("a", 0), ("a", 1), ("b", 30), ("b", -50)
]).toDF(["group", "power"])

def below_threshold(threshold, group="group", power="power"):
    @pandas_udf("struct<group: string, below_threshold: boolean>", PandasUDFType.GROUPED_MAP)
    def below_threshold_(df):
        df = pd.DataFrame(
            df.groupby(group).apply(lambda x: (x[power] < threshold).any()))
        df.reset_index(inplace=True, drop=False)
        return df
    return below_threshold_
Example usage:
df.groupBy("group").apply(below_threshold(-40)).show()
## +-----+---------------+
## |group|below_threshold|
## +-----+---------------+
## | b| true|
## | a| false|
## +-----+---------------+
See also Applying UDFs on GroupedData in PySpark (with a working Python example).
Spark >= 2.0 (optionally 1.6, but with a slightly different API):
It is possible to use Aggregators on typed Datasets:
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean) extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  def zero = false
  def reduce(acc: Boolean, x: I) = acc | f(x)
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
  def finish(acc: Boolean) = acc

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}
val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)
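A quick way to try this out, assuming the same example data as used in the sections below (the expected result is sketched in the comment):

val df = sc.parallelize(Seq(
  ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df.as[(String, Int)]
  .groupByKey(_._1)
  .agg(belowThreshold.name("belowThreshold"))
  .show()
// group "a" -> false, group "b" -> true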
Spark >= 1.5:
In Spark 1.5 you can create a UDAF like this, although it is most likely overkill:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
  // Schema you get as an input
  def inputSchema = new StructType().add("power", IntegerType)
  // Schema of the row which is used for aggregation
  def bufferSchema = new StructType().add("ind", BooleanType)
  // Returned type
  def dataType = BooleanType
  // Self-explaining
  def deterministic = true
  // zero value
  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
  // Similar to seqOp in aggregate
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
  }
  // Similar to combOp in aggregate
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))
  }
  // Called on exit to get return value
  def evaluate(buffer: Row) = buffer.getBoolean(0)
}
Example usage:
df
  .groupBy($"group")
  .agg(belowThreshold($"power").alias("belowThreshold"))
  .show
// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// | a| false|
// | b| true|
// +-----+--------------+
Spark 1.4 workaround:
I am not sure if I correctly understand your requirements, but as far as I can tell plain old aggregation should be enough here:
import org.apache.spark.sql.functions.sum

val df = sc.parallelize(Seq(
  ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df
  .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
  .groupBy($"group")
  .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
  .show
// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// | a| false|
// | b| true|
// +-----+--------------+
Spark <= 1.4:
As far as I know, at this moment (Spark 1.4.1), there is no support for UDAFs other than the Hive ones. It should be possible with Spark 1.5 (see SPARK-3947).
Internally Spark uses a number of classes, including ImperativeAggregates and DeclarativeAggregates.
These are intended for internal usage and may change without further notice, so it is probably not something you want to use in your production code, but just for completeness, BelowThreshold with DeclarativeAggregate could be implemented like this (tested with Spark 2.2-SNAPSHOT):
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression)
    extends DeclarativeAggregate {

  override def children: Seq[Expression] = Seq(child, threshold)
  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left, belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold

  override def defaultResult: Option[Literal] = Option(Literal(false))
}
It should be further wrapped with an equivalent of withAggregateFunction.
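Spark's own withAggregateFunction helper is private, but an equivalent wrapper only needs to convert the aggregate into an AggregateExpression and put it in a Column. A sketch, assuming the public Column(expr) constructor of Spark 2.x; the helper name below is made up for illustration:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit

// Hypothetical wrapper playing the role of Spark's internal withAggregateFunction
def belowThresholdAgg(child: Column, threshold: Column): Column =
  new Column(BelowThreshold(child.expr, threshold.expr).toAggregateExpression())

df.groupBy($"group").agg(belowThresholdAgg($"power", lit(-40)).alias("belowThreshold"))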
To define and use a UDF in Spark (3.0+) with Java:
private static UDF1<Integer, Boolean> belowThreshold = (power) -> power < -40;
Register the UDF:
SparkSession.builder()
    .appName(appName)
    .master(master)
    .getOrCreate()
    .udf()
    .register("belowThreshold", belowThreshold, DataTypes.BooleanType);
Use the UDF via Spark SQL:
spark.sql("SELECT belowThreshold('50')");