
How to avoid automatic cast for ArrayType in Spark (2.4) SQL - Scala 2.11

Given the following code in Spark 2.4 and Scala 2.11:

val df = spark.sql("""select array(45, "something", 45)""")

If I print the schema with df.printSchema(), I can see that Spark automatically casts the integers to String: CAST(45 AS STRING)

root
 |-- array(CAST(45 AS STRING), something, CAST(45 AS STRING)): array (nullable = false)
 |    |-- element: string (containsNull = false)

I would like to know whether there is a way to avoid this automatic cast and instead have Spark SQL fail with an exception, assuming I call any action afterwards, e.g. df.collect().

This is just an example query, but it should work for any query.

This creates an ArrayType column in the DataFrame.

From the scaladocs: An ArrayType object comprises two fields, elementType: DataType and containsNull: Boolean. The field of elementType is used to specify the type of array elements. The field of containsNull is used to specify if the array has null values.

So ArrayType only accepts a single element type inside the Array. If values of different types are passed to the array function, it first tries to cast the columns to the most suitable common type. If the columns are completely incompatible, Spark throws an exception. Examples below.

val df = spark.sql("""select array(45, 46L, 45.45)""")
df.printSchema()

root
 |-- array(CAST(45 AS DECIMAL(22,2)), CAST(46 AS DECIMAL(22,2)), CAST(45.45 AS DECIMAL(22,2))): array (nullable = false)
 |    |-- element: decimal(22,2) (containsNull = false)

df: org.apache.spark.sql.DataFrame = [array(CAST(45 AS DECIMAL(22,2)), CAST(46 AS DECIMAL(22,2)), CAST(45.45 AS DECIMAL(22,2))): array<decimal(22,2)>]

The next one below fails with an error:

val df = spark.sql("""select array(45, 46L, True)""")
df.printSchema()

org.apache.spark.sql.AnalysisException: cannot resolve 'array(45, 46L, true)' due to data type mismatch: input to function array should all be the same type, but it's [int, bigint, boolean]; line 1 pos 7;
'Project [unresolvedalias(array(45, 46, true), None)]
+- OneRowRelation

    at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:126)
    at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$3.applyOrElse(CheckAnalysis.scala:111)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$6.apply(TreeNode.scala:304)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$6.apply(TreeNode.scala:304)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:303)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:301)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:301)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$8.apply(TreeNode.scala:354)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:352)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:301)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:94)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$transformExpressionsUp$1.apply(QueryPlan.scala:94)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:106)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$3.apply(QueryPlan.scala:106)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)
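
As an aside, here is a minimal sketch (not part of the original answer; the helper name assertElementType is hypothetical) of how the elementType and containsNull fields described above can be read back from the result's schema. This gives one way to fail with an exception before calling any action, as the question asks:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType}

// Hypothetical helper: throw if the first column is an array whose element type
// differs from the expected one, i.e. if an implicit cast has happened.
def assertElementType(df: DataFrame, expected: DataType): Unit =
  df.schema.head.dataType match {
    case ArrayType(elementType, containsNull) if elementType != expected =>
      throw new IllegalStateException(
        s"array element type is $elementType (containsNull = $containsNull), " +
        s"expected $expected")
    case _ => () // expected element type, or not an array column at all
  }

// the df from the question
val df = spark.sql("""select array(45, "something", 45)""")
assertElementType(df, IntegerType)  // throws: the elements were cast to string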

I assume you are creating the array from columns of some DataFrame. In that case, you can check in the schema of that DataFrame that the input columns are of type StringType. In Scala, it could look like this:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.DataTypes

// some dataframe with a long and a string
val df = spark.range(3).select('id, 'id cast "string" as "id_str")

// a function that checks if the provided columns are strings
def check_df(df: DataFrame, cols: Seq[String]): Unit = {
    val non_string_column = df
        .schema
        .find(field => cols.contains(field.name) &&
                       field.dataType != DataTypes.StringType)
    if (non_string_column.isDefined)
        throw new Error(s"The column ${non_string_column.get.name} has type " +
                        s"${non_string_column.get.dataType} instead of StringType")
}

Then let's try it out:

scala> check_df(df, Seq("id", "id_str"))
java.lang.Error: The column id has type LongType instead of StringType
  at check_df(<console>:36)
  ... 50 elided

scala> check_df(df, Seq("id_str"))
// no exception
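
For completeness, a small sketch (assuming the df and check_df defined above) of how this check could be wired in before building an array column, so the failure happens before anything is silently cast:

import org.apache.spark.sql.functions.{array, col}

// Validate the input columns first; "id" is a LongType column, so this throws
// and the array (which would silently coerce "id" to string) is never built.
check_df(df, Seq("id", "id_str"))
val withArr = df.select(array(col("id"), col("id_str")) as "arr")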
