Converting entropy calculation from Scala Spark to PySpark

Environment: Spark 2.4.4

I'm trying to convert the following code from Scala Spark to PySpark:

test.registerTempTable("test")

val df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("_1").rowsBetween(Long.MinValue, Long.MaxValue)

import org.apache.spark.sql.functions.sum

val p = $"_2" / sum($"_2").over(w)
val withP = df.withColumn("p", p)

import org.apache.spark.sql.functions.log2

val result = withP.groupBy($"_1").agg((-sum($"p" * log2($"p"))).alias("entropy"))

result.collect()

It is working and outputs the desired result:

Array[org.apache.spark.sql.Row] = Array([179,0.1091158547868134], [178,0.181873874177682], [177,-0.0], [176,0.9182958340544896], [175,-0.0], [174,-0.0], [173,0.04848740692447222], [172,-0.0], [171,-0.0], [170,-0.0], [169,-...

The PySpark version works up to the very final line, but then fails with an AnalysisException:

df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")

from pyspark.sql import Window

w = Window.partitionBy("_1").rowsBetween(-9223372036854775808L, 9223372036854775807L)

from pyspark.sql.functions import sum 

p = df['_2'] / sum(df['_2']).over(w)
withP = df.withColumn("p", p)

from pyspark.sql.functions import log2 

result = withP.groupBy("_1").agg((-sum(p * log2(p))).alias("entropy"))

The exception:

Fail to execute line 19: result = withP.groupBy("_1").agg(sum(p * log2(p)).alias("entropy"))
Traceback (most recent call last):
  File "/tmp/zeppelin_pyspark-6317327282796051870.py", line 380, in <module>
    exec(code, _zcUserQueryNameSpace)
  File "<stdin>", line 19, in <module>
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/group.py", line 115, in agg
    _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 69, in deco
    raise AnalysisException(s.split(': ', 1)[1], stackTrace)
AnalysisException: u'It is not allowed to use a window function inside an aggregate function. Please use the inner window function in a sub-query.;'

A sample of the original DataFrame:

df = spark.createDataFrame([(1, 10), (1, 1), (2, 10), (3, 1), (3, 100)])

Why does the Scala version work, but the PySpark version, with exactly the same logic, doesn't?

The conflict is between the column named p on withP and the Python expression object p. In the Scala version, $"p" refers to the column p of withP, but in the PySpark version the variable p still holds the window expression df['_2'] / sum(df['_2']).over(w), so the aggregate ends up containing a window function, which is what the exception complains about. Refer to the materialized column with col("p") inside the sum aggregation instead. This should work fine:

from pyspark.sql.functions import col

result = withP.groupBy("_1").agg((-sum(col("p") * log2(col("p")))).alias("entropy"))
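For completeness, here is a minimal end-to-end sketch of the corrected pipeline, assuming the same test temp table with cluster and label columns as in the question. It also uses Window.unboundedPreceding/Window.unboundedFollowing rather than raw Long literals, and aliases sum on import so it doesn't shadow Python's built-in sum:

from pyspark.sql import Window
from pyspark.sql.functions import col, log2, sum as sum_

# Per-(cluster, label) counts, as in the original query
df = sqlContext.sql(
    "select cluster as _1, count(*) as _2 from test "
    "group by cluster, label order by cluster desc"
)

# Unbounded window per cluster (equivalent to rowsBetween(Long.MinValue, Long.MaxValue))
w = Window.partitionBy("_1").rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing
)

# Probability of each label within its cluster
withP = df.withColumn("p", col("_2") / sum_("_2").over(w))

# Aggregate over the materialized column "p", not the window expression object
result = withP.groupBy("_1").agg(
    (-sum_(col("p") * log2(col("p")))).alias("entropy")
)
result.collect()

The key point is that by the time the groupBy runs, "p" is already a plain column of withP, so no window function appears inside the aggregate.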
