簡體   English   中英

將組計數列添加到 PySpark 數據框

[英]Adding a group count column to a PySpark dataframe

由於其出色的 Spark 處理能力,我從 R 和tidyverse轉到 PySpark,並且我正在努力將某些概念從一個上下文映射到另一個上下文。

特別是,假設我有一個如下所示的數據集

x | y
--+--
a | 5
a | 8
a | 7
b | 1

我想添加一個包含每個x值的行數的列,如下所示:

x | y | n
--+---+---
a | 5 | 3
a | 8 | 3
a | 7 | 3
b | 1 | 1

在 dplyr 中,我只想說:

import(tidyverse)

df <- read_csv("...")
df %>%
    group_by(x) %>%
    mutate(n = n()) %>%
    ungroup()

就是這樣。 如果我想按行數進行總結,我可以在 PySpark 中做一些幾乎同樣簡單的事情:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .count() \
    .show()

我想我明白withColumn相當於 dplyr 的mutate 然而,當我這樣做,PySpark告訴我, withColumn無定義groupBy數據:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

spark.read.csv("...") \
    .groupBy(col("x")) \
    .withColumn("n", count("x")) \
    .show()

在短期內,我可以簡單地創建包含計數的第二個數據幀並將其連接到原始數據幀。 但是,在大表的情況下,這似乎會變得效率低下。 實現此目的的規范方法是什么?

執行groupBy() ,必須先指定聚合,然后才能顯示結果。 例如:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
]
df = sqlCtx.createDataFrame(data, ["x", "y"])
df.groupBy('x').count().select('x', f.col('count').alias('n')).show()
#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+

這里我使用alias()來重命名列。 但這每組只返回一行。 如果您想要附加計數的所有行,您可以使用Window執行此操作:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.select('x', 'y', f.count('x').over(w).alias('n')).sort('x', 'y').show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+

或者,如果您更熟悉 SQL,則可以將數據pyspark-sql注冊為臨時表並利用pyspark-sql來執行相同的操作:

df.registerTempTable('table')
sqlCtx.sql(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+

我發現我們可以更接近 tidyverse 示例:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.withColumn('n', f.count('x').over(w)).sort('x', 'y').show()

作為@pault 附錄

import pyspark.sql.functions as F

...

(df
.groupBy(F.col('x'))
.agg(F.count('x').alias('n'))
.show())

#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+

享受

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM