繁体   English   中英

如何在 pyspark DataFrame 上快速计算不同条件下的多个计数?

[英]How to compute multiple counts with different conditions on a pyspark DataFrame, fast?

假设我有这个 pyspark 数据框:

data = spark.createDataFrame(schema=['Country'], data=[('AT',), ('BE',), ('France',), ('Latvia',)])

假设我想收集有关此数据的各种统计信息。 例如,我可能想知道有多少行使用 2 个字符的国家/地区代码,有多少行使用更长的国家/地区名称:

count_short = data.where(F.length(F.col('Country')) == 2).count()
count_long = data.where(F.length(F.col('Country')) > 2).count()

这是有效的,但是当我想根据不同条件收集许多不同的计数时,即使对于很小的数据集,它也会变得非常慢。 在我工作的 Azure Synapse Studio 中,每次计数需要 1-2 秒来计算.

我需要进行 100 次以上的计数,计算 10 行的数据集需要几分钟的时间。 在有人问之前,这些计数的条件比我的例子更复杂。 我不能按长度分组或做其他类似的技巧。

我正在寻找一种在任意条件下快速进行多次计数的通用方法。

猜测性能缓慢的原因是对于每次计数调用,我的 pyspark 笔记本都会启动一些具有显着开销的 Spark 进程。 所以我假设如果有某种方法可以在单个查询中收集这些计数,我的性能问题将得到解决。

我想到的一个可能的解决方案是建立一个临时列,指示我的哪些条件已匹配,然后对其调用countDistinct 但是我会对条件匹配的所有组合进行单独计数。 我还注意到,根据情况,在计算统计数据之前执行data = data.localCheckpoint()时性能会好一些,但一般问题仍然存在。

有没有更好的办法?

函数“count”可以用条件(Scala)替换为“sum”:

data.select(
  sum(
    when(length(col("Country")) === 2, 1).otherwise(0)
  ).alias("two_characters"),
  sum(
    when(length(col("Country")) > 2, 1).otherwise(0)
  ).alias("more_than_two_characters")
)

一种方法是将多个查询合并为一个,另一种方法是缓存一次又一次查询的数据帧。 通过缓存数据帧,我们避免了每次调用 count() 时重新评估。

data.cache()

要记住的事情很少。 如果您在数据帧上应用多个操作并且有很多转换并且您正在从某个外部源读取该数据,那么您绝对应该在对该数据帧应用任何单个操作之前缓存该数据帧。

@pasha701 提供的答案有效,但您必须根据要分析的不同国家/地区代码长度值继续添加列。

您可以使用以下代码在一个数据框中获取不同国家/地区代码的计数。

//import statements
from pyspark.sql.functions import *
//sample Dataframe
data = spark.createDataFrame(schema=['Country'], data=[('AT',), ('ACE',), ('BE',), ('France',), ('Latvia',)])
//adding additional column that gives the length of the country codes
data1 = data.withColumn("CountryLength",length(col('Country')))
//creating columns list schema for the final output
outputcolumns = ["CountryLength","RecordsCount"]
//selecting the countrylength column and converting that to rdd and performing map reduce operation to count the occurrences of the same length 
countrieslength = data1.select("CountryLength").rdd.map(lambda word: (word, 1)).reduceByKey(lambda a,b:a +b).toDF(outputcolumns).select("CountryLength.CountryLength","RecordsCount")
//now you can do display or show on the dataframe to see the output
display(countrieslength)

请查看您可能获得的输出快照,如下所示: 在此处输入图片说明

如果要在此数据帧上应用多个过滤条件,则可以缓存此数据帧并根据国家/地区代码长度获取不同记录组合的计数。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM