How to test exceptions thrown inside map in Scala
I have the following Scala function:
def throwError(spark: SparkSession, df: DataFrame): Unit = {
  import spark.implicits._
  val predictionAndLabels = df.select("prediction", "label").map {
    case Row(prediction: Double, label: Double) => (prediction, label)
    case other => throw new IllegalArgumentException(s"Illegal arguments")
  }
  predictionAndLabels.show()
}
I want to test the exception thrown by the function above, but my test fails:
"Testing" should "throw error for datetype" in withSparkSession {
spark => {
// Creating a dataframe
val someData = Seq(
Row(8, Date.valueOf("2016-09-30")),
Row(9, Date.valueOf("2017-09-30")),
Row(10, Date.valueOf("2018-09-30"))
)
val someSchema = List(
StructField("prediction", IntegerType, true),
StructField("label", DateType , true)
)
val someDF = spark.createDataFrame(
spark.sparkContext.parallelize(someData),
StructType(someSchema)
)
// Testing exception
val caught = intercept[IllegalArgumentException] {
throwError(spark,someDF)
}
assert(caught.getMessage.contains("Illegal arguments"))
}
}
If I move throw new IllegalArgumentException(s"Illegal arguments") outside the map call, the test passes.
How can I test the exception thrown by the throwError function?
A Spark DataFrame cannot handle exceptions at the row level. The body of your map runs lazily on the executors once show() triggers a job, so the IllegalArgumentException thrown there reaches the driver wrapped in an org.apache.spark.SparkException, and intercept[IllegalArgumentException] never sees the original type. If you use an RDD, you can achieve the row-level handling you are after.
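If all you need is for the test to observe the failure, you can assert on the wrapper instead. A minimal sketch, assuming the task failure surfaces as org.apache.spark.SparkException with the original message embedded in its text (on recent Spark versions the original is also reachable via getCause), and with someDF built exactly as in your test:

import org.apache.spark.SparkException

val caught = intercept[SparkException] {
  throwError(spark, someDF)
}
// The executor-side IllegalArgumentException is reported inside the
// wrapper's message, so match on the text rather than the type
assert(caught.getMessage.contains("Illegal arguments"))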
Check out this blog post: https://www.nicolaferraro.me/2016/02/18/exception-handling-in-apache-spark/
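In the spirit of that post, here is a sketch of row-level handling via the RDD API; the Try-based pattern below is my own illustration of the idea, not code taken from the blog:

import org.apache.spark.sql.Row
import scala.util.{Failure, Try}

// Wrap the per-row work in Try so that bad rows become values
// instead of failing the task on the executor
val attempted = df.select("prediction", "label").rdd.map { row =>
  Try {
    row match {
      case Row(prediction: Double, label: Double) => (prediction, label)
      case other => throw new IllegalArgumentException(s"Illegal arguments")
    }
  }
}
// Re-throw the first failure on the driver, where intercept can see it
attempted.filter(_.isFailure).take(1).foreach {
  case Failure(e) => throw e
  case _ => ()
}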
A workaround for your problem:
def throwError(spark: SparkSession, df: DataFrame): Unit = {
  import spark.implicits._
  val countOfRowsBeforeCheck = df.count()
  // Drop rows that do not match the expected (Double, Double) shape
  // instead of throwing inside the executors
  val predictionAndLabels = df.select("prediction", "label").flatMap {
    case Row(prediction: Double, label: Double) => Iterator((prediction, label))
    case other => Iterator.empty
  }
  val countOfRowsAfterCheck = predictionAndLabels.count()
  // If any row was dropped, fail on the driver, where the test can catch it
  if (countOfRowsAfterCheck != countOfRowsBeforeCheck) {
    throw new IllegalArgumentException(s"Illegal arguments")
  }
  predictionAndLabels.show()
}
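With this version the check runs on the driver, so your original assertion works unchanged:

val caught = intercept[IllegalArgumentException] {
  throwError(spark, someDF)
}
assert(caught.getMessage.contains("Illegal arguments"))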
Hope this helps!!