
How to calculate the number of duplicated elements in a nested list in PySpark?

I have the following DataFrame in PySpark:

+----------+------------------------+
|        id|              codes_list|
+----------+------------------------+
|      FF10|   [[1049, 1683], [108]]|
|      AB36|        [[1507], [1005]]|
|      8266|[[1049], [1049], [1049]]|
+----------+------------------------+

This is the schema:

root
 |-- id: string (nullable = true)
 |-- codes_list: array (nullable = true)
 |    |-- element: string (containsNull = true)

How can I calculate the number of duplicated codes in codes_list?

This is the expected result:

+----------+----+
|        id| qty|
+----------+----+
|      FF10|   0|
|      AB36|   0|
|      8266|   1|
+----------+----+

One simple way is to explode those codes and count each occurrence of (id, code). Then group by id and use a conditional sum to get the number of duplicated values.

Since the array contains strings rather than sub-arrays, you can first remove the square brackets and split on , to get the individual codes.

data = [("FF10", ["[1049, 1683]", "[108]"]),
        ("FAB36", ["[1507]", "[1005]"]),
        ("8266", ["[1049]", "[1049]", "[1049]"])]

df = spark.createDataFrame(data, ["id", "codes_list"])


df.withColumn("codes", explode("codes_list")) \
  .withColumn("codes", explode(split(regexp_replace("codes", "[\\[\\]]", ""), ","))) \
  .groupBy("id", "codes").count() \
  .groupBy("id").agg(sum((col("count") > lit(1)).cast("int")).alias("qty")) \
  .show()

Gives:

+----+---+
|  id|qty|
+----+---+
|FF10|  0|
|8266|  1|
|AB36|  0|
+----+---+
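
If you prefer to avoid the double groupBy, you can compute the same count with array expressions, assuming Spark 2.4+ for higher-order functions. The sketch below is an alternative, not the original answer: it joins the bracketed strings into one flat array of codes and counts the distinct codes that appear more than once. The column name all_codes is just illustrative.

from pyspark.sql.functions import concat_ws, expr, regexp_replace, split

# Alternative sketch (assumes Spark 2.4+ for higher-order functions).
# `all_codes` is an illustrative column name, not from the original answer.
alt = (
    df
    # join the bracketed strings, drop brackets/spaces, split into one flat array of codes
    .withColumn("all_codes",
                split(regexp_replace(concat_ws(",", "codes_list"), "[\\[\\] ]", ""), ","))
    # count the distinct codes that occur more than once within the array
    .withColumn("qty", expr(
        "size(filter(array_distinct(all_codes), c -> size(filter(all_codes, x -> x = c)) > 1))"))
    .select("id", "qty")
)
alt.show()

For the sample data this gives the same qty values as the explode-based version; the exploded version is usually easier to read and extend, while the array-expression version avoids the intermediate shuffle of the first groupBy.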
