Converting entropy calculation from Scala Spark to PySpark

Environment: Spark 2.4.4

I'm trying to convert the following code from Scala Spark to PySpark:

test.registerTempTable("test")

val df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("_1").rowsBetween(Long.MinValue, Long.MaxValue)

import org.apache.spark.sql.functions.sum

val p = $"_2" / sum($"_2").over(w)
val withP = df.withColumn("p", p)

import org.apache.spark.sql.functions.log2

val result = withP.groupBy($"_1").agg((-sum($"p" * log2($"p"))).alias("entropy"))

result.collect()

It is working and outputs the desired result:

Array[org.apache.spark.sql.Row] = Array([179,0.1091158547868134], [178,0.181873874177682], [177,-0.0], [176,0.9182958340544896], [175,-0.0], [174,-0.0], [173,0.04848740692447222], [172,-0.0], [171,-0.0], [170,-0.0], [169,-...
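In other words (reading off the code above), each cluster's entropy is the Shannon entropy of its label distribution: with p_i the share of the cluster's rows in label group i (the p column), the per-cluster result is -sum_i p_i * log2(p_i). The -0.0 entries come from clusters containing a single label, where p = 1 and log2(1) = 0.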

The PySpark version works up to the very last line, but then fails with an AnalysisException:

df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")

from pyspark.sql import Window

w = Window.partitionBy("_1").rowsBetween(-9223372036854775808L, 9223372036854775807L)

from pyspark.sql.functions import sum 

p = df['_2'] / sum(df['_2']).over(w)
withP = df.withColumn("p", p)

from pyspark.sql.functions import log2 

result = withP.groupBy("_1").agg((-sum(p * log2(p))).alias("entropy"))

The exception:

Fail to execute line 19: result = withP.groupBy("_1").agg(sum(p * log2(p)).alias("entropy"))
Traceback (most recent call last):
  File "/tmp/zeppelin_pyspark-6317327282796051870.py", line 380, in <module>
    exec(code, _zcUserQueryNameSpace)
  File "<stdin>", line 19, in <module>
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/group.py", line 115, in agg
    _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 69, in deco
    raise AnalysisException(s.split(': ', 1)[1], stackTrace)
AnalysisException: u'It is not allowed to use a window function inside an aggregate function. Please use the inner window function in a sub-query.;'

A sample of the original DataFrame:

df = spark.createDataFrame([(1, 10), (1, 1), (2, 10), (3, 1), (3, 100)])
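For reproducibility, the sample can be registered as the test table used above roughly as follows; the column names cluster and label are an assumption here, since createDataFrame without a schema only yields _1 and _2:

# Assumed mapping of the sample onto the (cluster, label) schema used by the SQL query
df = spark.createDataFrame([(1, 10), (1, 1), (2, 10), (3, 1), (3, 100)], ["cluster", "label"])
df.createOrReplaceTempView("test")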

Why does the Scala version work but the PySpark version, with exactly the same logic, doesn't?

The problem is a conflict between the column name p and the Column object p: in Scala, $"p" resolves by name to the p column of withP, so the aggregate only sees a plain column, whereas in Python the variable p still holds the window expression df['_2'] / sum(df['_2']).over(w), so passing it to agg puts a window function inside an aggregate function. You should use col("p") inside the sum aggregation instead. This should work fine:

from pyspark.sql.functions import col
result = withP.groupBy("_1").agg((-sum(col("p") * log2(col("p")))).alias("entropy"))
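For completeness, a minimal self-contained sketch of the fixed pipeline, run against the sample DataFrame above (assumptions: the sample columns map to cluster and label, a SparkSession named spark is available, and Window.unboundedPreceding/unboundedFollowing stand in for the raw long literals):

from pyspark.sql import Window
from pyspark.sql.functions import col, log2, sum as sum_

# Sample data; the column names "cluster" and "label" are an assumption for this sketch
df = spark.createDataFrame([(1, 10), (1, 1), (2, 10), (3, 1), (3, 100)],
                           ["cluster", "label"])
df.createOrReplaceTempView("test")

# Per-(cluster, label) counts, keyed by cluster as _1
counts = spark.sql(
    "select cluster as _1, count(*) as _2 from test "
    "group by cluster, label order by cluster desc")

# Unbounded window over each cluster (avoids the Python 2 long literals)
w = Window.partitionBy("_1").rowsBetween(Window.unboundedPreceding,
                                         Window.unboundedFollowing)

# p = share of the cluster's rows in each (cluster, label) group
withP = counts.withColumn("p", col("_2") / sum_("_2").over(w))

# Reference the materialised column by name, so no window function
# ends up inside the aggregate
result = withP.groupBy("_1").agg((-sum_(col("p") * log2(col("p")))).alias("entropy"))
result.show()

Importing sum as sum_ avoids shadowing Python's built-in sum. The error message's other suggestion, moving the window computation into a sub-query, would also work, but referencing the already-materialised p column is the smaller change.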
