[英]Spark dataset count rows matching condition with agg() method (in Java)
我在 Java 中使用 Apache Spark 2.3.1。 我想通过使用Dataset
class 的agg()
方法计算与给定条件匹配的数据集中的行数。
例如,我想计算以下数据集中label
等于1.0
的行数:
SparkSession spark = ...
List<Row> rows = new ArrayList<>();
rows.add(RowFactory.create(0, 0.0));
rows.add(RowFactory.create(1, 1.0));
rows.add(RowFactory.create(2, 1.0));
Dataset<Row> ds =
spark.sqlContext().createDataFrame(rows,
new StructType(new StructField[] {
new StructField("id", DataTypes.LongType, false, Metadata.empty()),
new StructField("label", DataTypes.DoubleType, false, Metadata.empty())}));
我的猜测是使用以下代码:
ds.agg(functions.count(ds.col("label").equalTo(1.0))).show();
但是,显示错误的结果:
+--------------------+
|count((label = 1.0))|
+--------------------+
| 3|
+--------------------+
正确的结果当然应该是2
。
agg()
方法不应该以这种方式工作吗?
agg()
中的计数只会计算不 null 值,因此可以这样做:
import org.apache.spark.sql.functions._
ds.agg(count(when('label.equalTo(1.0),1).otherwise(null))).show()
我在这里找到了这个解决方案https://stackoverflow.com/a/1400115/9687910
agg
方法不应该像这样工作。 实际上,您需要首先根据label对数据进行分组,然后应用计数、最大值等聚合。
df.filter("label".equalTo(1.0)).groupBy('label').agg(count("*").alias("cnt"))
它指的是以下文档。
chlebek 的回答是正确的。
使用 Java 语法:
ds.agg(functions.count(functions.when(ds.col("label").equalTo(1.0), 0))).show();
请注意,使用count
时, when
function 的value
参数无关紧要(相当于 SQL count(*)
)。
另一种实现相同的方法是 output a 1
并将所有结果sum
:
ds.agg(functions.sum(functions.when(ds.col("label").equalTo(1.0), 1))).show();
在这种情况下,该value
必须恰好是1
。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.