
Spark: compare two DataFrames and find the match count

I have two Spark SQL DataFrames, neither of which has a unique column. The first DataFrame contains n-grams, the second contains long text strings (blog posts). I want to find the matches in df2 and add the count to df1.

DF1
------------
words
------------
Stack
Stack Overflow
users
spark scala

DF2

--------
POSTS
--------
Hello, Stack overflow users , Do you know spark scala
Spark scala is very fast
Users in stack are good in spark, users


Expected output

--------------   -----------
words            match_count
--------------   -----------
Stack                 2
Stack Overflow        1
users                 3
spark scala           1

It seems that a join-groupBy-count will do:

import org.apache.spark.sql.functions.{expr, count}

df1
    .join(df2, expr("lower(posts) rlike lower(words)"))
    .groupBy("words")
    .agg(count("*").as("match_count"))
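For readers without a Spark shell handy, the row semantics of that rlike join can be sketched in plain Python (using the question's sample data): each (word, post) pair that matches contributes one joined row, so the per-word count is the number of matching posts, not the number of occurrences.

```python
import re

posts = [
    "Hello, Stack overflow users , Do you know spark scala",
    "Spark scala is very fast",
    "Users in stack are good in spark, users",
]
words = ["Stack", "Stack Overflow", "users", "spark scala"]

# One joined row per (word, post) match, so count matching posts per word.
# re.escape keeps the comparison a literal match; rlike itself would treat
# the word as a regex pattern.
match_count = {
    w: sum(1 for p in posts if re.search(re.escape(w.lower()), p.lower()))
    for w in words
}
```

Note this yields 2 for "users" (it appears in two posts), while the expected output above asks for 3 (counting both occurrences in the third post); counting occurrences rather than matching posts needs a different approach.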

You can use pandas features in PySpark. Here is my solution:

>>> from pyspark.sql import Row
>>> import pandas as pd
>>> 
>>> rdd1 = sc.parallelize(['Stack','Stack Overflow','users','spark scala'])
>>> data1 = rdd1.map(lambda x: Row(x))
>>> df1=spark.createDataFrame(data1,['words'])
>>> df1.show()
+--------------+
|         words|
+--------------+
|         Stack|
|Stack Overflow|
|         users|
|   spark scala|
+--------------+

>>> rdd2 = sc.parallelize([
...     'Hello, Stack overflow users , Do you know spark scala',
...     'Spark scala is very fast',
...     'Users in stack are good in spark'
...     ])
>>> data2 = rdd2.map(lambda x: Row(x))
>>> df2=spark.createDataFrame(data2,['posts'])
>>> df2.show()
+--------------------+
|               posts|
+--------------------+
|Hello, Stack over...|
|Spark scala is ve...|
|Users in stack ar...|
+--------------------+

>>> dfPd1 = df1.toPandas()
>>> dfPd2 = df2.toPandas().apply(lambda x: x.str.lower())
>>> 
>>> words = dict((x,0) for x in dfPd1['words'])
>>> 
>>> for i in words:
...     words[i] = dfPd2['posts'].str.contains(i.lower()).sum()
... 
>>> 
>>> words
{'Stack': 2, 'Stack Overflow': 1, 'users': 2, 'spark scala': 2}
>>> 
>>> data = pd.DataFrame.from_dict(words, orient='index').reset_index()
>>> data.columns = ['words','match_count']
>>> 
>>> df = spark.createDataFrame(data)
>>> df.show()
+--------------+-----------+
|         words|match_count|
+--------------+-----------+
|         Stack|          2|
|Stack Overflow|          1|
|         users|          2|
|   spark scala|          2|
+--------------+-----------+
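The answer above reports users → 2 rather than the expected 3, for two reasons: its sample data drops the trailing ", users" from the question's third post, and pandas' str.contains flags each post at most once. If the goal is to count every occurrence, a per-occurrence tally (sketched below in plain Python with the question's original posts; `Series.str.count` would do the same in pandas) gives 3:

```python
posts = [
    "Hello, Stack overflow users , Do you know spark scala",
    "Spark scala is very fast",
    "Users in stack are good in spark, users",
]
words = ["Stack", "Stack Overflow", "users", "spark scala"]

# Count every case-insensitive substring occurrence, not just matching posts.
occurrences = {
    w: sum(p.lower().count(w.lower()) for p in posts) for w in words
}
```

With occurrence counting, "users" comes out as 3 and "spark scala" as 2 (the question's expected 1 for "spark scala" only holds if matching is case-sensitive).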

A brute-force approach in Scala follows. It does not work across lines and treats everything as lowercase; both could be addressed, but that is for another day. It relies not on examining strings but on defining the problem as what it really is, n-grams against n-grams: generate the n-grams on both sides, then JOIN and count, whereby only the inner join is relevant. Some extra data is added to prove the matching.

import org.apache.spark.ml.feature._
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions._  
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructField,StructType,IntegerType,ArrayType,LongType,StringType}
import spark.implicits._

// Sample data, duplicates and items to check it works.
val dfPostsInit = Seq(
                  ( "Hello!!, Stack overflow users, Do you know spark scala users."),
                  ( "Spark scala is very fast,"),
                  ( "Users in stack are good in spark"),
                  ( "Users in stack are good in spark"),
                  ( "xy z"),
                  ( "x yz"),
                  ( "ABC"),
                  ( "abc"),
                  ( "XYZ,!!YYY@#$ Hello Bob..."))
                 .toDF("posting")

val dfWordsInit = Seq(("Stack"), ("Stack Overflow"),("users"), ("spark scala"), ("xyz"), ("xy"), ("not found"), ("abc")).toDF("words")
val dfWords = dfWordsInit
  .withColumn("words_perm", regexp_replace(dfWordsInit("words"), " ", "^"))
  .withColumn("lower_words_perm", lower(regexp_replace(dfWordsInit("words"), " ", "^")))

val dfPostsTemp = dfPostsInit.map(r => (r.getString(0), r.getString(0).split("\\W+").toArray )) 
// Tidy Up
val columnsRenamed = Seq("posting", "posting_array") 
val dfPosts = dfPostsTemp.toDF(columnsRenamed: _*)

// Generate n-grams up to some limit N (needs to be set) so that we can count properly
// via a direct JOIN comparison. N can be parametrized in the calls below.
// It is not easy to do string matching over an Array column directly, and no other answer presented a way.
def buildNgrams(inputCol: String = "posting_array", n: Int = 3) = {
  val ngrams = (1 to n).map(i =>
      new NGram().setN(i)
        .setInputCol(inputCol).setOutputCol(s"${i}_grams")
  )
  new Pipeline().setStages((ngrams).toArray)
}

val suffix:String = "_grams"
var i_grams_Cols:List[String] = Nil
for(i <- 1 to 3) {
   val iGCS = i.toString.concat(suffix)
   i_grams_Cols = i_grams_Cols ::: List(iGCS)
}     
// Generate data for checking against later via rows only, not via columns; positional dependency counts, hence the permutations.
val dfPostsNGrams = buildNgrams().fit(dfPosts).transform(dfPosts)

val dummySchema = StructType(
    StructField("phrase", StringType, true) :: Nil)
var dfPostsNGrams2 = spark.createDataFrame(sc.emptyRDD[Row], dummySchema)
for (i <- i_grams_Cols) {
  val nameCol = col({i})
  dfPostsNGrams2 = dfPostsNGrams2.union (dfPostsNGrams.select(explode({nameCol}).as("phrase")).toDF )
 }

val dfPostsNGrams3     = dfPostsNGrams2.withColumn("lower_phrase_concatenated",lower(regexp_replace(dfPostsNGrams2("phrase"), " ", "^"))) 

val result = dfPostsNGrams3
  .join(dfWords, col("lower_phrase_concatenated") === col("lower_words_perm"), "inner")
  .groupBy("words_perm", "words")
  .agg(count("*").as("match_count"))

result.select("words", "match_count").show(false)

returns:

+--------------+-----------+
|words         |match_count|
+--------------+-----------+
|spark scala   |2          |
|users         |4          |
|abc           |2          |
|Stack Overflow|1          |
|xy            |1          |
|Stack         |3          |
|xyz           |1          |
+--------------+-----------+
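The whole generate-explode-join-count pipeline above amounts to tallying every lowercased n-gram produced from the posts and looking each search word up in that tally. A plain-Python sketch of the same idea, assuming whitespace/punctuation tokenization as the split("\\W+") above and using the question's original three posts (the Scala run shows higher counts only because its sample data adds duplicates and extra words):

```python
import re
from collections import Counter


def build_ngrams(tokens, n=3):
    """All 1..n-grams of a token list, space-joined like Spark's NGram output."""
    return [
        " ".join(tokens[i:i + k])
        for k in range(1, n + 1)
        for i in range(len(tokens) - k + 1)
    ]


posts = [
    "Hello, Stack overflow users , Do you know spark scala",
    "Spark scala is very fast",
    "Users in stack are good in spark, users",
]
words = ["Stack", "Stack Overflow", "users", "spark scala"]

# Tokenize like split("\\W+"), generate 1..3-grams, and tally them lowercased
# (the lowercasing plays the role of the lower(...) permutation columns).
tally = Counter(
    g.lower()
    for p in posts
    for g in build_ngrams(re.split(r"\W+", p.strip()), 3)
)

# The inner join + groupBy + count is then just a dictionary lookup per word.
match_count = {w: tally[w.lower()] for w in words}
```

Because the n-grams are compared for exact equality, "users" is counted per occurrence (3), matching the question's expected output for that row.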

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 