How to write multiple WHEN conditions for a Spark DataFrame?

I have to join two data frames and select all of their columns based on some condition. Here is an example:

val sqlContext = new org.apache.spark.sql.SQLContext(sc)

import sqlContext.implicits._
import org.apache.spark.{SparkConf, SparkContext}
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.functions.{input_file_name, regexp_extract, when, rank, col, concat_ws, regexp_replace}

val get_cus_val = sqlContext.udf.register("get_cus_val", (filePath: String) => filePath.split("\\.")(3))

val rdd = sc.textFile("s3://trfsmallfffile/FinancialLineItem/MAIN")
val header = rdd.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
val schema = StructType(header.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
val data = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema)

val schemaHeader = StructType(header.map(cols => StructField(cols.replace(".", "."), StringType)).toSeq) // replace(".", ".") is a no-op; this keeps the original dotted header names
val dataHeader = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schemaHeader)

val df1resultFinal = data.withColumn("DataPartition", get_cus_val(input_file_name))

val rdd1 = sc.textFile("s3://trfsmallfffile/FinancialLineItem/INCR")
val header1 = rdd1.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
val schema1 = StructType(header1.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
val data1 = sqlContext.createDataFrame(rdd1.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema1)


import org.apache.spark.sql.expressions._
val windowSpec = Window.partitionBy("LineItem_organizationId", "LineItem_lineItemId").orderBy($"TimeStamp".cast(LongType).desc) 
val latestForEachKey = data1.withColumn("rank", rank().over(windowSpec)).filter($"rank" === 1).drop("rank", "TimeStamp")


val dfMainOutput = df1resultFinal.join(latestForEachKey, Seq("LineItem_organizationId", "LineItem_lineItemId"), "outer")
      .select($"LineItem_organizationId", $"LineItem_lineItemId",
        when($"DataPartition_1".isNotNull, $"DataPartition_1").otherwise($"DataPartition").as("DataPartition"),
        when($"StatementTypeCode_1".isNotNull, $"StatementTypeCode_1").otherwise($"StatementTypeCode").as("StatementTypeCode"),
        when($"LineItemName_1".isNotNull, $"LineItemName_1").otherwise($"LineItemName").as("LineItemName"),
        when($"LocalLanguageLabel_1".isNotNull, $"LocalLanguageLabel_1").otherwise($"LocalLanguageLabel").as("LocalLanguageLabel"),
        when($"FinancialConceptLocal_1".isNotNull, $"FinancialConceptLocal_1").otherwise($"FinancialConceptLocal").as("FinancialConceptLocal"),
        when($"FinancialConceptGlobal_1".isNotNull, $"FinancialConceptGlobal_1").otherwise($"FinancialConceptGlobal").as("FinancialConceptGlobal"),
        when($"IsDimensional_1".isNotNull, $"IsDimensional_1").otherwise($"IsDimensional").as("IsDimensional"),
        when($"InstrumentId_1".isNotNull, $"InstrumentId_1").otherwise($"InstrumentId").as("InstrumentId"),
        when($"LineItemSequence_1".isNotNull, $"LineItemSequence_1").otherwise($"LineItemSequence").as("LineItemSequence"),
        when($"PhysicalMeasureId_1".isNotNull, $"PhysicalMeasureId_1").otherwise($"PhysicalMeasureId").as("PhysicalMeasureId"),
        when($"FinancialConceptCodeGlobalSecondary_1".isNotNull, $"FinancialConceptCodeGlobalSecondary_1").otherwise($"FinancialConceptCodeGlobalSecondary").as("FinancialConceptCodeGlobalSecondary"),
        when($"IsRangeAllowed_1".isNotNull, $"IsRangeAllowed_1").otherwise($"IsRangeAllowed").as("IsRangeAllowed"),
        when($"IsSegmentedByOrigin_1".isNotNull, $"IsSegmentedByOrigin_1").otherwise($"IsSegmentedByOrigin".cast(DataTypes.StringType)).as("IsSegmentedByOrigin"),
        when($"SegmentGroupDescription_1".isNotNull, $"SegmentGroupDescription_1").otherwise($"SegmentGroupDescription").as("SegmentGroupDescription"),
        when($"SegmentChildDescription_1".isNotNull, $"SegmentChildDescription_1").otherwise($"SegmentChildDescription").as("SegmentChildDescription"),
        when($"SegmentChildLocalLanguageLabel_1".isNotNull, $"SegmentChildLocalLanguageLabel_1").otherwise($"SegmentChildLocalLanguageLabel").as("SegmentChildLocalLanguageLabel"),
        when($"LocalLanguageLabel_languageId_1".isNotNull, $"LocalLanguageLabel_languageId_1").otherwise($"LocalLanguageLabel_languageId").as("LocalLanguageLabel_languageId"),
        when($"LineItemName_languageId_1".isNotNull, $"LineItemName_languageId_1").otherwise($"LineItemName_languageId").as("LineItemName_languageId"),
        when($"SegmentChildDescription_languageId_1".isNotNull, $"SegmentChildDescription_languageId_1").otherwise($"SegmentChildDescription_languageId").as("SegmentChildDescription_languageId"),
        when($"SegmentChildLocalLanguageLabel_languageId_1".isNotNull, $"SegmentChildLocalLanguageLabel_languageId_1").otherwise($"SegmentChildLocalLanguageLabel_languageId").as("SegmentChildLocalLanguageLabel_languageId"),
        when($"SegmentGroupDescription_languageId_1".isNotNull, $"SegmentGroupDescription_languageId_1").otherwise($"SegmentGroupDescription_languageId").as("SegmentGroupDescription_languageId"),
        when($"SegmentMultipleFundbDescription_1".isNotNull, $"SegmentMultipleFundbDescription_1").otherwise($"SegmentMultipleFundbDescription").as("SegmentMultipleFundbDescription"),
        when($"SegmentMultipleFundbDescription_languageId_1".isNotNull, $"SegmentMultipleFundbDescription_languageId_1").otherwise($"SegmentMultipleFundbDescription_languageId").as("SegmentMultipleFundbDescription_languageId"),
        when($"IsCredit_1".isNotNull, $"IsCredit_1").otherwise($"IsCredit").as("IsCredit"),
        when($"FinancialConceptLocalId_1".isNotNull, $"FinancialConceptLocalId_1").otherwise($"FinancialConceptLocalId").as("FinancialConceptLocalId"),
        when($"FinancialConceptGlobalId_1".isNotNull, $"FinancialConceptGlobalId_1").otherwise($"FinancialConceptGlobalId").as("FinancialConceptGlobalId"),
        when($"FinancialConceptCodeGlobalSecondaryId_1".isNotNull, $"FinancialConceptCodeGlobalSecondaryId_1").otherwise($"FinancialConceptCodeGlobalSecondaryId").as("FinancialConceptCodeGlobalSecondaryId"),
        when($"FFAction_1".isNotNull, $"FFAction_1").otherwise($"FFAction|!|").as("FFAction|!|"))
        .filter(!$"FFAction|!|".contains("D|!|"))

val dfMainOutputFinal = dfMainOutput.na.fill("").select($"DataPartition",$"StatementTypeCode",concat_ws("|^|", dfMainOutput.schema.fieldNames.filter(_ != "DataPartition").map(c => col(c)): _*).as("concatenated"))

val headerColumn = dataHeader.columns.toSeq

val header = headerColumn.mkString("", "|^|", "|!|").dropRight(3) // join the header names with |^|; dropRight(3) strips the trailing |!| again

val dfMainOutputFinalWithoutNull = dfMainOutputFinal.withColumn("concatenated", regexp_replace(col("concatenated"), "|^|null", "")).withColumnRenamed("concatenated", header)


dfMainOutputFinalWithoutNull.write.partitionBy("DataPartition","StatementTypeCode")
  .format("csv")
  .option("nullValue", "")
  .option("delimiter", "\t")
  .option("quote", "\u0000")
  .option("header", "true")
  .option("codec", "gzip")
  .save("s3://trfsmallfffile/FinancialLineItem/output")

Now I have to write the when condition for every column explicitly. Is there any way to avoid repeating the when condition for all columns?

In my data the missing values arrive as the literal string "null" rather than as real nulls, so applying coalesce directly can be difficult.
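One way around this (a sketch, assuming the literal string "null" should count as missing) is to rewrite those values to real nulls first, after which coalesce behaves as expected:

import org.apache.spark.sql.functions.{col, lit, when}

// Turn the literal string "null" into a real null in every column of
// latestForEachKey, so isNotNull/coalesce treat those values as missing.
val latestCleaned = latestForEachKey.columns.foldLeft(latestForEachKey) { (df, c) =>
  df.withColumn(c, when(col(c) === "null", lit(null)).otherwise(col(c)))
}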

Here is data frame one:

LineItem.organizationId|^|LineItem.lineItemId|^|StatementTypeCode|^|LineItemName|^|LocalLanguageLabel|^|FinancialConceptLocal|^|FinancialConceptGlobal|^|IsDimensional|^|InstrumentId|^|LineItemSequence|^|PhysicalMeasureId|^|FinancialConceptCodeGlobalSecondary|^|IsRangeAllowed|^|IsSegmentedByOrigin|^|SegmentGroupDescription|^|SegmentChildDescription|^|SegmentChildLocalLanguageLabel|^|LocalLanguageLabel.languageId|^|LineItemName.languageId|^|SegmentChildDescription.languageId|^|SegmentChildLocalLanguageLabel.languageId|^|SegmentGroupDescription.languageId|^|SegmentMultipleFundbDescription|^|SegmentMultipleFundbDescription.languageId|^|IsCredit|^|FinancialConceptLocalId|^|FinancialConceptGlobalId|^|FinancialConceptCodeGlobalSecondaryId|^|FFAction|!|
4295879842|^|1246|^|CUS|^|Net Sales-Customer Segment|^|相手先別の販売高(相手先別)|^|JCSNTS|^|REXM|^|False|^||^||^||^||^|False|^|False|^|CUS_JCSNTS|^||^||^|505126|^|505074|^|505074|^|505126|^|505126|^||^|505074|^|True|^|3020155|^|3015249|^||^|I|!|

Here is my data frame 2:

DataPartition_1|^|TimeStamp|^|LineItem.organizationId|^|LineItem.lineItemId|^|StatementTypeCode_1|^|LineItemName_1|^|LocalLanguageLabel_1|^|FinancialConceptLocal_1|^|FinancialConceptGlobal_1|^|IsDimensional_1|^|InstrumentId_1|^|LineItemSequence_1|^|PhysicalMeasureId_1|^|FinancialConceptCodeGlobalSecondary_1|^|IsRangeAllowed_1|^|IsSegmentedByOrigin_1|^|SegmentGroupDescription_1|^|SegmentChildDescription_1|^|SegmentChildLocalLanguageLabel_1|^|LocalLanguageLabel.languageId_1|^|LineItemName.languageId_1|^|SegmentChildDescription.languageId_1|^|SegmentChildLocalLanguageLabel.languageId_1|^|SegmentGroupDescription.languageId_1|^|SegmentMultipleFundbDescription_1|^|SegmentMultipleFundbDescription.languageId_1|^|IsCredit_1|^|FinancialConceptLocalId_1|^|FinancialConceptGlobalId_1|^|FinancialConceptCodeGlobalSecondaryId_1|^|FFAction_1
SelfSourcedPublic|^|1511869196612|^|4295902451|^|10|^|BAL|^|Short term notes payable - related party|^|null|^|null|^|LSOD|^|false|^|null|^|null|^|null|^|null|^|false|^|false|^|null|^|null|^|null|^|null|^|505074|^|null|^|null|^|null|^|null|^|null|^|null|^|null|^|3019157|^|null|^|I|!|

This is what I have tried so far:

println("Enterin In to Spark Mode ")

    val conf = new SparkConf().setAppName("FinanicalLineItem").setMaster("local");
    val sc = new SparkContext(conf); //Creating spark context
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)


    val mainFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//MAIN"
    val incrFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//INCR"
    val outputFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//output"
    val descrFileURL = "C://Users//u6034690//Desktop//SPARK//trfsmallfffile//FinancialLineItem//Descr"

    val src = new Path(outputFileURL)
    val dest = new Path(mainFileURL)
    val hadoopconf = sc.hadoopConfiguration
    val fs = src.getFileSystem(hadoopconf)

    sc.hadoopConfiguration.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")

    sc.hadoopConfiguration.set("parquet.enable.summary-metadata", "false")

    myUtil.Utility.DeleteOuptuFolder(fs, outputFileURL)
    myUtil.Utility.DeleteDescrFolder(fs, descrFileURL)

    import sqlContext.implicits._

    val rdd = sc.textFile(mainFileURL)
    val header = rdd.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
    val schema = StructType(header.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
    val data = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema)

    val schemaHeader = StructType(header.map(cols => StructField(cols.replace(".", "."), StringType)).toSeq)
    val dataHeader = sqlContext.createDataFrame(rdd.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schemaHeader)

    val get_cus_val = sqlContext.udf.register("get_cus_val", (filePath: String) => filePath.split("\\.")(3))

    val columnsNameArray = schema.fieldNames

    val df1resultFinal = data.withColumn("DataPartition", get_cus_val(input_file_name))
    val rdd1 = sc.textFile(incrFileURL)
    val header1 = rdd1.filter(_.contains("LineItem.organizationId")).map(line => line.split("\\|\\^\\|")).first()
    val schema1 = StructType(header1.map(cols => StructField(cols.replace(".", "_"), StringType)).toSeq)
    val data1 = sqlContext.createDataFrame(rdd1.filter(!_.contains("LineItem.organizationId")).map(line => Row.fromSeq(line.split("\\|\\^\\|").toSeq)), schema1)

    val windowSpec = Window.partitionBy("LineItem_organizationId", "LineItem_lineItemId").orderBy($"TimeStamp".cast(LongType).desc)
    val latestForEachKey = data1.withColumn("rank", rank().over(windowSpec)).filter($"rank" === 1).drop("rank", "TimeStamp")

    val columnMap = latestForEachKey.columns
      .filter(_.endsWith("_1"))
      .map(c => c -> c.dropRight(2))
      .toMap + ("FFAction_1" -> "FFAction|!|")


        val exprs = columnMap.map(t => coalesce(col(s"${t._1}"), col(s"${t._2}")).as(s"${t._2}"))
        val exprsExtended = exprs ++ Array(col("LineItem_organizationId"), col("LineItem_lineItemId"))
        println(exprsExtended)
        val df2 = data.select(exprsExtended: _*) // This line has a compilation issue.

type mismatch; found : scala.collection.immutable.Iterable[org.apache.spark.sql.Column] required: Seq[?]

Also, when I printed exprsExtended, I am getting backticks in my output columns:

coalesce(LineItemSequence_1, LineItemSequence) AS `LineItemSequence`,
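As a side note, those backticks are only part of how Spark renders a Column as a string (identifiers are quoted); they do not end up in the actual column names. A quick check:

import org.apache.spark.sql.functions.{coalesce, col}

val e = coalesce(col("LineItemSequence_1"), col("LineItemSequence")).as("LineItemSequence")
println(e) // coalesce(LineItemSequence_1, LineItemSequence) AS `LineItemSequence`
// After a select, the resulting column is named plain "LineItemSequence".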

The first step would be to create a list of tuples with the column names used in all your when clauses. It can be done in many ways, but if all columns in the dataframe are to be used, it can be done as follows (with an example dataframe):

val df = Seq(("1", "2", null, "4", "5", "6"), 
    (null, "2", "3", "4", null, "6"), 
    ("1", "2", "3", "4", null, "6"))
  .toDF("col1_1", "col1", "col2_1", "col2", "col3_1", "col3|!|")

val columnMap = df.columns.grouped(2).map(a => (a(0), a(1))).toArray

Now the columnMap variable contains the columns to use as tuples:

("col1_1", "col1")
("col1_2", "col2")
("col1_3", "col3|!|")

The next step is to build the expressions that can be used in a select statement, using the columnMap variable:

val exprs = columnMap.map(t => coalesce(col(s"${t._1}"), col(s"${t._2}")).as(s"${t._2}"))

and apply the expression to the dataframe:

val df2 = df.select(exprs:_*)

The final result is as follows:

+----+----+-------+
|col1|col2|col3|!||
+----+----+-------+
|   1|   4|      5|
|   2|   3|      6|
|   1|   3|      6|
+----+----+-------+

Note: If there are other columns that should be selected in addition to those in the exprs variable, simply add them as follows:

val exprsExtended = exprs ++ Array(col("other_column1), col("other_column2"))
val df2 = df.select(exprsExtended :_*)

Edit: To create the columnMap in this specific case, with column names like these, starting from all the columns with the _1 suffix seems easiest. Building it as an array of tuples rather than a Map also resolves the type mismatch above: mapping over a Map yields an Iterable[Column], while the : _* expansion in select needs a Seq, which an array provides. Before the join, get these columns from the latestForEachKey dataframe:

val columnMap = latestForEachKey.columns
  .filter(c => c.endsWith("_1") && c != "FFAction_1")
  .map(c => c -> c.dropRight(2)) :+ ("FFAction_1" -> "FFAction|!|")

Then create and use exprs and exprsExtended as described above:

val exprs = columnMap.map(t => coalesce(col(s"${t._1}"), col(s"${t._2}")).as(s"${t._2}"))
val exprsExtended = exprs ++ Array(col("LineItem_organizationId"), col("LineItem_lineItemId"))
val df2 = df.select(exprsExtended: _*)
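For completeness, a sketch of how this replaces the long when chain, assuming the same join keys and FFAction filter as in the question's original code:

// columnMap, exprs and exprsExtended built as above, before the join
val dfMainOutput = df1resultFinal
  .join(latestForEachKey, Seq("LineItem_organizationId", "LineItem_lineItemId"), "outer")
  .select(exprsExtended: _*)
  .filter(!$"FFAction|!|".contains("D|!|"))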
