繁体   English   中英

使用 agg() 方法的 Spark 数据集计数行匹配条件(Java 中)

[英]Spark dataset count rows matching condition with agg() method (in Java)

我在 Java 中使用 Apache Spark 2.3.1。 我想通过使用Dataset class 的agg()方法计算与给定条件匹配的数据集中的行数。

例如,我想计算以下数据集中label等于1.0的行数:

SparkSession spark = ...

List<Row> rows = new ArrayList<>();
rows.add(RowFactory.create(0, 0.0));
rows.add(RowFactory.create(1, 1.0));
rows.add(RowFactory.create(2, 1.0));

Dataset<Row> ds =
    spark.sqlContext().createDataFrame(rows,
        new StructType(new StructField[] {
            new StructField("id", DataTypes.LongType, false, Metadata.empty()),
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty())}));

我的猜测是使用以下代码:

ds.agg(functions.count(ds.col("label").equalTo(1.0))).show();

但是,显示错误的结果:

+--------------------+
|count((label = 1.0))|
+--------------------+
|                   3|
+--------------------+

正确的结果当然应该是2

agg()方法不应该以这种方式工作吗?

agg()中的计数只会计算不 null 值,因此可以这样做:

 import org.apache.spark.sql.functions._
 ds.agg(count(when('label.equalTo(1.0),1).otherwise(null))).show()

我在这里找到了这个解决方案https://stackoverflow.com/a/1400115/9687910

agg方法不应该像这样工作。 实际上,您需要首先根据label对数据进行分组,然后应用计数最大值等聚合。

df.filter("label".equalTo(1.0)).groupBy('label').agg(count("*").alias("cnt"))

它指的是以下文档

chlebek 的回答是正确的。

使用 Java 语法:

ds.agg(functions.count(functions.when(ds.col("label").equalTo(1.0), 0))).show();

请注意,使用count时, when function 的value参数无关紧要(相当于 SQL count(*) )。

另一种实现相同的方法是 output a 1并将所有结果sum

ds.agg(functions.sum(functions.when(ds.col("label").equalTo(1.0), 1))).show();

在这种情况下,该value必须恰好是1

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM