
Add distinct count of a column to each row in PySpark

I need to add the distinct count of a column to each row of a PySpark DataFrame.

For example, if the original DataFrame is:

+----+----+
|col1|col2|
+----+----+
|abc |   1|
|xyz |   1|
|dgc |   2|
|ydh |   3|
|ujd |   1|
|ujx |   3|
+----+----+

Then I want something like this:

+----+----+----+
|col1|col2|col3|
+----+----+----+
|abc |   1|   3|
|xyz |   1|   3|
|dgc |   2|   3|
|ydh |   3|   3|
|ujd |   1|   3|
|ujx |   3|   3|
+----+----+----+

I tried df.withColumn('total_count', f.countDistinct('col2')), but it raises an error.

You can count the distinct values in the column first and then create a new column holding that value:

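from pyspark.sql import functions as f
distincts = df.dropDuplicates(["col2"]).count()  # Spark job; returns a Python int
df = df.withColumn("col3", f.lit(distincts))  # same constant on every row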

Cross join the DataFrame with a one-row DataFrame that holds the distinct count:

from pyspark.sql import functions as F
df2 = df.crossJoin(df.select(F.countDistinct('col2').alias('col3')))  # one-row DataFrame joined to every row

df2.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
| abc|   1|   3|
| xyz|   1|   3|
| dgc|   2|   3|
| ydh|   3|   3|
| ujd|   1|   3|
| ujx|   3|   3|
+----+----+----+

You can use a Window together with collect_set and size:

from pyspark.sql import functions as F, Window

df = spark.createDataFrame([("abc", 1), ("xyz", 1), ("dgc", 2), ("ydh", 3), ("ujd", 1), ("ujx", 3)], ['col1', 'col2'])

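# One window spanning every row (no partitionBy, so all data lands in one partition)
window = Window.orderBy("col2").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

# collect_set gathers the distinct non-null col2 values in the window; size counts them
df.withColumn("col3", F.size(F.collect_set(F.col("col2")).over(window))).show()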

+----+----+----+
|col1|col2|col3|
+----+----+----+
| abc|   1|   3|
| xyz|   1|   3|
| dgc|   2|   3|
| ydh|   3|   3|
| ujd|   1|   3|
| ujx|   3|   3|
+----+----+----+
