Add a tag to the list in a DataFrame based on a threshold for the values in the list in Scala Spark

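I have a DataFrame with a column "grades" that contains a list of grade objects with 2 fields: name (String) and value (Double). I want to add the word PASS to the tags list if the list contains a grade named HOME with a value of at least 20.0. Example below: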

INPUT:
+------+-----+----+-------+---------------------------------------------------------+
| model| cnd | age| tags  | grades                                                  |
+------+-----+----+-------+---------------------------------------------------------+
|  foo1|   xx|  10|  []   | [{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}] |
|  foo2|   xz|  12|  []   | [{name:"ATW", value: 70.0}]                             |
|  foo3|   xc|  13|  []   | [{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}] |
+------+-----+----+-------+---------------------------------------------------------+



 OUTPUT:

+------+-----+----+-------+---------------------------------------------------------+
| model| cnd | age| tags  | grades                                                  |
+------+-----+----+-------+---------------------------------------------------------+
|  foo1|   xx|  10| [PASS]| [{name:"ATW", value: 10.0}, {name:"HOME", value: 20.0}] |
|  foo2|   xz|  12|  []   | [{name:"ATW", value: 70.0}]                             |
|  foo3|   xc|  13|  []   | [{name:"ATW", value: 90.0}, {name:"HOME", value: 10.0}] |
+------+-----+----+-------+---------------------------------------------------------+

I haven't been able to find a reasonable solution. So far I have this:

    dataFrame.withColumn("tags",
    when(
      array_contains(
        col("grades.name"),
        lit("HOME")
      ) && col("grades.value") >= lit(20.0),
      array_union(col("tags"), lit(Array("PASS")))
    ).otherwise(col("tags")))

But for some reason this code throws:

org.apache.spark.sql.AnalysisException: cannot resolve '(`grades`.`value` >= 20.0D)' due to data type mismatch: differing types in '(`grades`.`value` >= 20.0D)' (array<double> and double).;;

The data is read from BigQuery, and the value field cannot possibly contain an array of doubles.

Assume your dataset is called data (reduced here to the two relevant columns for simplicity):

+----+---------------------------+
|tags|grades                     |
+----+---------------------------+
|[]  |[{ATW, 10.0}, {HOME, 20.0}]|
|[]  |[{ATW, 70.0}]              |
|[]  |[{ATW, 90.0}, {HOME, 10.0}]|
+----+---------------------------+

If your grades column happens to be a string, convert the JSON into an array of structs first (otherwise you can skip this step):

data = data.withColumn("grades",
  expr("from_json(grades, 'array<struct<name:string,value:double>>')")
)
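
As a hedged sketch of this optional parsing step, the same schema can also be built from Spark's type objects instead of a DDL string. The names `ParseGradesSketch`, `gradesSchema`, and `parseGrades` are hypothetical, and the sample strings assume strictly valid JSON with quoted field names, which may not match how your data is actually stored.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, from_json}
import org.apache.spark.sql.types.{ArrayType, DoubleType, StringType, StructField, StructType}

// Hypothetical helper object, for illustration only.
object ParseGradesSketch {
  // Equivalent to the DDL string 'array<struct<name:string,value:double>>'.
  val gradesSchema: ArrayType = ArrayType(
    StructType(Seq(
      StructField("name", StringType),
      StructField("value", DoubleType)
    ))
  )

  // Replaces the string column `grades` with a parsed array of structs.
  def parseGrades(df: DataFrame): DataFrame =
    df.withColumn("grades", from_json(col("grades"), gradesSchema))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("parse-grades-sketch")
      .getOrCreate()
    import spark.implicits._

    // Assumed sample input: one JSON array per row, stored as a string.
    val raw = Seq(
      """[{"name":"ATW","value":10.0},{"name":"HOME","value":20.0}]""",
      """[{"name":"ATW","value":70.0}]"""
    ).toDF("grades")

    val parsed = parseGrades(raw)
    parsed.printSchema()
    parsed.show(false)

    spark.stop()
  }
}
```

Strings that cannot be parsed typically come back as null rather than raising an error, so it may be worth checking for null grades after this step.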

Once that is in place, we can apply the following:

data = data.withColumn("tags",
  when(
    // condition: at least one grade has name = HOME and value >= 20
    expr("size(filter(grades, x -> x.name == 'HOME' and x.value >= 20))").geq(1),
    // union the existing tags with array("PASS"); array_union drops duplicates
    array_union(col("tags"), array(lit("PASS")))
    // otherwise, leave the tags column unchanged
  ).otherwise(col("tags")))

The final output looks like this:

+------+---------------------------+
|tags  |grades                     |
+------+---------------------------+
|[PASS]|[{ATW, 10.0}, {HOME, 20.0}]|
|[]    |[{ATW, 70.0}]              |
|[]    |[{ATW, 90.0}, {HOME, 10.0}]|
+------+---------------------------+
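
For context on the error in the question: selecting a struct field through an array column, as in `col("grades.value")`, extracts that field from every element and therefore yields an `array<double>`, which cannot be compared with a single double. That is also why `array_contains(col("grades.name"), ...)` type-checks, but it tests the name and the value independently rather than on the same element. The following is a minimal self-contained sketch of the per-element check, written with the SQL higher-order function `exists` (Spark 2.4+) instead of `size(filter(...))`. The names `Grade`, `Record`, `PassTagSketch`, and `addPassTag` are hypothetical, and the sketch assumes `grades` is already an array of structs.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{array, array_union, col, expr, lit, when}

// Hypothetical case classes mirroring the relevant columns of the question.
case class Grade(name: String, value: Double)
case class Record(model: String, tags: Seq[String], grades: Seq[Grade])

object PassTagSketch {
  // Adds "PASS" to `tags` when a single grade has name HOME and value >= 20.0.
  def addPassTag(df: DataFrame): DataFrame =
    df.withColumn(
      "tags",
      when(
        // Both conditions are evaluated on the same array element g.
        expr("exists(grades, g -> g.name = 'HOME' AND g.value >= 20.0)"),
        array_union(col("tags"), array(lit("PASS")))
      ).otherwise(col("tags"))
    )

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("pass-tag-sketch")
      .getOrCreate()
    import spark.implicits._

    // Sample rows taken from the question's input table.
    val df = Seq(
      Record("foo1", Seq.empty, Seq(Grade("ATW", 10.0), Grade("HOME", 20.0))),
      Record("foo2", Seq.empty, Seq(Grade("ATW", 70.0))),
      Record("foo3", Seq.empty, Seq(Grade("ATW", 90.0), Grade("HOME", 10.0)))
    ).toDF()

    addPassTag(df).show(false)

    spark.stop()
  }
}
```

Because `array_union` removes duplicates, applying `addPassTag` to a row that already carries PASS should not add the tag a second time.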

Good luck!
