
How to avoid automatic cast for ArrayType in Spark (2.4) SQL - Scala 2.11

Given the following code in Spark 2.4 and Scala 2.11:

val df = spark.sql("""select array(45, "something", 45)""")

If I print the schema with df.printSchema(), I see that Spark automatically casts the integer to a string, CAST(45 AS STRING):

root
 |-- array(CAST(45 AS STRING), something, CAST(45 AS STRING)): array (nullable = false)
 |    |-- element: string (containsNull = false)

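I would like to know whether there is a way to avoid this automatic cast and instead make Spark SQL fail with an exception, assuming I call an action such as df.collect() afterwards.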

This is just one example query, but the solution should work for any query.

The array function creates a column of type ArrayType in the DataFrame.

From the Scaladoc: An ArrayType object comprises two fields, elementType: DataType and containsNull: Boolean. The field of elementType is used to specify the type of array elements. The field of containsNull is used to specify if the array has null values.

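So an ArrayType holds elements of only one type. If values of different types are passed to the array function, Spark first tries to cast them all to the most suitable common type. If the types are completely incompatible, Spark throws an exception. For example: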
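The widening rule described above can be illustrated with a small, self-contained Scala model. This is a simplified sketch, not Spark's actual type-coercion code: the names SqlType, IntT, BigIntT, StringT, BooleanT, DecimalT and ArrayCoercion are hypothetical, only a handful of atomic types are covered, and Spark's 38-digit decimal precision cap is ignored. It assumes that an int literal fits in decimal(10,0), a bigint in decimal(20,0), and that a literal such as 45.45 is decimal(4,2).

```scala
// Simplified, illustrative model of how array(...) picks one element type.
// Assumption: this mimics Spark 2.4's widening rules for a few atomic types only;
// it is not Spark's implementation, and all names here are hypothetical.
sealed trait SqlType
case object IntT extends SqlType { override def toString = "int" }
case object BigIntT extends SqlType { override def toString = "bigint" }
case object StringT extends SqlType { override def toString = "string" }
case object BooleanT extends SqlType { override def toString = "boolean" }
final case class DecimalT(precision: Int, scale: Int) extends SqlType {
  override def toString = s"decimal($precision,$scale)"
}

object ArrayCoercion {
  // Integral types seen as decimals: int fits in decimal(10,0), bigint in decimal(20,0).
  private def asDecimal(t: SqlType): Option[DecimalT] = t match {
    case IntT        => Some(DecimalT(10, 0))
    case BigIntT     => Some(DecimalT(20, 0))
    case d: DecimalT => Some(d)
    case _           => None
  }

  // Wider of two decimals: larger scale plus the larger number of integer digits
  // (Spark's 38-digit precision cap is ignored in this sketch).
  private def widerDecimal(a: DecimalT, b: DecimalT): DecimalT = {
    val scale = math.max(a.scale, b.scale)
    val intDigits = math.max(a.precision - a.scale, b.precision - b.scale)
    DecimalT(intDigits + scale, scale)
  }

  // Common type of two element types, or None if they cannot be reconciled.
  def widerOf(a: SqlType, b: SqlType): Option[SqlType] = (a, b) match {
    case (x, y) if x == y                          => Some(x)
    case (IntT, BigIntT) | (BigIntT, IntT)         => Some(BigIntT)
    // Assumption: string promotion does not apply to boolean, so these stay incompatible.
    case (StringT, BooleanT) | (BooleanT, StringT) => None
    case (StringT, _) | (_, StringT)               => Some(StringT)
    case _ =>
      for (da <- asDecimal(a); db <- asDecimal(b)) yield widerDecimal(da, db)
  }

  // Element type for array(e1, e2, ...): fold the pairwise rule over all inputs and
  // report an analyzer-style error when no common type exists.
  def elementType(types: Seq[SqlType]): Either[String, SqlType] = {
    require(types.nonEmpty, "array(...) needs at least one input in this model")
    val common = types.tail.foldLeft(Option(types.head)) { (acc, t) =>
      acc.flatMap(widerOf(_, t))
    }
    common.toRight(
      "input to function array should all be the same type, but it's " +
        types.mkString("[", ", ", "]"))
  }
}
```

With this model, ArrayCoercion.elementType(Seq(IntT, StringT, IntT)) gives Right(string), matching the schema in the question, and Seq(IntT, BigIntT, DecimalT(4, 2)) gives Right(decimal(22,2)), matching the decimal example. Seq(IntT, BigIntT, BooleanT) gives a Left whose message has the same shape as the AnalysisException for array(45, 46L, True). Folding pairwise keeps the sketch short; it does not attempt to reproduce every ordering subtlety of Spark's real rules.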

val df = spark.sql("""select array(45, 46L, 45.45)""")
df.printSchema()

root
 |-- array(CAST(45 AS DECIMAL(22,2)), CAST(46 AS DECIMAL(22,2)), CAST(45.45 AS DECIMAL(22,2))): array (nullable = false)
 |    |-- element: decimal(22,2) (containsNull = false)

df: org.apache.spark.sql.DataFrame = [array(CAST(45 AS DECIMAL(22,2)), CAST(46 AS DECIMAL(22,2)), CAST(45.45 AS DECIMAL(22,2))): array<decimal(22,2)>]

The next example, however, fails with an error:

val df = spark.sql("""select array(45, 46L, True)""")
df.printSchema()

org.apache.spark.sql.AnalysisException: cannot resolve 'array(45, 46L, true)' due to data type mismatch: input to function array should all be the same type, but it's [int, bigint, boolean]; line 1 pos 7;
'Project [unresolvedalias(array(45, 46, true), None)]
+- OneRowRelation

    at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:126)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:111)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$6.apply(TreeNode.scala:304)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$6.apply(TreeNode.scala:304)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:303)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:301)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:301)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$8.apply(TreeNode.scala:354)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:352)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:301)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:94)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:94)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:106)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:106)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)

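I assume you are creating an array from columns of some DataFrame. In that case, you can check in that DataFrame's schema that the input columns are of type StringType before building the array. In Scala, it could look like this: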

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.DataTypes
import spark.implicits._  // enables the 'id column syntax (spark is the SparkSession, e.g. in spark-shell)

// some dataframe with a long and a string
val df = spark.range(3).select('id, 'id cast "string" as "id_str")

// a function that throws if any of the provided columns is not a string
def check_df(df: DataFrame, cols: Seq[String]): Unit = {
    val non_string_column = df
        .schema
        .find(field => cols.contains(field.name) &&
                       field.dataType != DataTypes.StringType)
    if (non_string_column.isDefined)
        throw new Error(s"The column ${non_string_column.get.name} has type " +
                        s"${non_string_column.get.dataType} instead of StringType")
}

Let's try it:

scala> check_df(df, Seq("id", "id_str"))
java.lang.Error: The column id has type LongType instead of StringType
  at check_df(<console>:36)
  ... 50 elided

scala> check_df(df, Seq("id_str"))
// no exception


