Converting entropy calculation from Scala Spark to PySpark
Environment: Spark 2.4.4
I'm trying to convert the following code from Scala Spark to PySpark:
test.registerTempTable("test")
val df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy("_1").rowsBetween(Long.MinValue, Long.MaxValue)
import org.apache.spark.sql.functions.sum
val p = $"_2" / sum($"_2").over(w)
val withP = df.withColumn("p", p)
import org.apache.spark.sql.functions.log2
val result = withP.groupBy($"_1").agg((-sum($"p" * log2($"p"))).alias("entropy"))
result.collect()
It works and outputs the desired result:
Array[org.apache.spark.sql.Row] = Array([179,0.1091158547868134], [178,0.181873874177682], [177,-0.0], [176,0.9182958340544896], [175,-0.0], [174,-0.0], [173,0.04848740692447222], [172,-0.0], [171,-0.0], [170,-0.0], [169,-...
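As a sanity check on those numbers: the `0.9182958340544896` seen above is exactly the base-2 Shannon entropy of a cluster whose label counts are in a 1:2 ratio. A minimal plain-Python reference implementation of the same `-sum(p * log2(p))` formula:

```python
import math

def entropy(counts):
    """Base-2 Shannon entropy of a distribution given as raw counts."""
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts)

# Counts [1, 2] reproduce the 0.9182958340544896 value from the Scala output.
print(entropy([1, 2]))
```

A cluster with a single label gives `-1 * log2(1) = -0.0`, which is why several clusters in the output show `-0.0`.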
The PySpark version works up to the very last line, but then raises an AnalysisException:
df = sqlContext.sql("select cluster as _1, count(*) as _2 from test group by cluster, label order by cluster desc")
from pyspark.sql import Window
w = Window.partitionBy("_1").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
from pyspark.sql.functions import sum
p = df['_2'] / sum(df['_2']).over(w)
withP = df.withColumn("p", p)
from pyspark.sql.functions import log2
result = withP.groupBy("_1").agg((-sum(p * log2(p))).alias("entropy"))
The exception:
Fail to execute line 19: result = withP.groupBy("_1").agg(sum(p * log2(p)).alias("entropy"))
Traceback (most recent call last):
File "/tmp/zeppelin_pyspark-6317327282796051870.py", line 380, in <module>
exec(code, _zcUserQueryNameSpace)
File "<stdin>", line 19, in <module>
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/group.py", line 115, in agg
_to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
answer, self.gateway_client, self.target_id, self.name)
File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 69, in deco
raise AnalysisException(s.split(': ', 1)[1], stackTrace)
AnalysisException: u'It is not allowed to use a window function inside an aggregate function. Please use the inner window function in a sub-query.;'
A sample of the original DataFrame:
df = spark.createDataFrame([(1, 10), (1, 1), (2, 10), (3, 1), (3, 100)])
Why does the Scala version work, while the PySpark version, with exactly the same logic, does not?
The problem is a conflict between the column name `p` and the Python column object `p`: in the `agg`, `p` still refers to the expression containing the window function, so Spark sees a window function nested inside an aggregate. You should refer to the materialized column with `col("p")` inside the sum aggregation instead. This should work fine:
from pyspark.sql.functions import col
result = withP.groupBy("_1").agg((-sum(col("p") * log2(col("p")))).alias("entropy"))