
PySpark: array&lt;string&gt; != string

I'm trying to extract from a dataframe the rows that contain words from a list. Below I'm pasting my code:

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession, SQLContext
from pyspark.ml.feature import Tokenizer, RegexTokenizer
from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType
import findspark
findspark.init()
conf = SparkConf().setAppName("nlp").setMaster("spark://192.168.1.51:7077")\
                  .set("spark.jars","/home/artur/elasticsearch-spark-20_2.11-7.8.1.jar")\
                  .set("spark.sql.repl.eagerEval.enabled",True)\
                  .set("spark.cores.max","8")\
                  .set("spark.executor.memory", "2g")\
                  .set("spark.driver.memory", "2g")
sc = SparkContext.getOrCreate(conf=conf)
sqlContext = SQLContext(sc)
spark = SparkSession(sc)
sentenceDataFrame = spark.createDataFrame([
    (0, ["Hi", "I" ,"heard" ,"about", "Spark"]),
    (1, ["I", "wish" ,"Java", "could", "use", "case", "classes"]),
    (2, ["Logistic","regression","models","are","neat"])
], ["id", "sentence"])
my_list = ['Java','Spark']
selected_df = sentenceDataFrame.filter(sentenceDataFrame.sentence.isin(my_list))

and I'm getting the following error:

Py4JJavaError: An error occurred while calling o262.filter.
: org.apache.spark.sql.AnalysisException: cannot resolve '(`sentence` IN ('Java', 'Spark'))' due to data type mismatch: Arguments must be same type but were: array<string> != string;;
'Filter sentence#27 IN (Java,Spark)
+- LogicalRDD [id#26L, sentence#27], false

Please give me advice on how to solve this issue.

Solutions depend on your Spark version:

Spark 2.4+

from pyspark.sql import functions as F

sentenceDataFrame.filter(
    F.size(
        F.array_intersect(
            F.col("sentence"), F.array(*(F.lit(item) for item in my_list))
        )
    )
    > 0
).show()

+---+--------------------+
| id|            sentence|
+---+--------------------+
|  0|[Hi, I, heard, ab...|
|  1|[I, wish, Java, c...|
+---+--------------------+

Spark 2.3 or below

# using a broadcast join; fast, but my_list needs to be small
from pyspark.sql import functions as F

keyword_df = spark.createDataFrame([(keyword,) for keyword in my_list], ["keyword"])

sentenceDataFrame.join(
    F.broadcast(keyword_df),
    how="leftsemi",
    on=F.expr("array_contains(sentence, keyword)"),
).show()
# using a UDF, the slowest solution
from pyspark.sql import functions as F, types as T

def intersect(keyword_list):
    @F.udf(T.BooleanType())
    def has_keyword(sentence):
        # true when the sentence shares at least one word with the keyword list
        return bool(set(sentence) & set(keyword_list))

    return has_keyword


sentenceDataFrame.filter(intersect(my_list)(sentenceDataFrame.sentence)).show()

.isin compares a column against scalar values, not against an array column, which is why Spark reports the array&lt;string&gt; != string type mismatch. Instead, you can use array_intersect, which is available from Spark 2.4.0 onwards.

For example,

from pyspark.sql.functions import array_intersect, array, lit

sentenceDataFrame = spark.createDataFrame([
    (0, ["Hi", "I" ,"heard" ,"about", "Spark"]),
    (1, ["I", "wish" ,"Java", "could", "use", "case", "classes"]),
    (2, ["Logistic","regression","models","are","neat"])
], ["id", "sentence"])

my_list = ['Java','Spark']

selected_df = sentenceDataFrame.filter(
    array_intersect(sentenceDataFrame.sentence, array(*[lit(x) for x in my_list])) != array()
)
selected_df.show(10, False)

+---+------------------------------------------+
|id |sentence                                  |
+---+------------------------------------------+
|0  |[Hi, I, heard, about, Spark]              |
|1  |[I, wish, Java, could, use, case, classes]|
+---+------------------------------------------+
