
Counting rows of a DataFrame with a condition in Spark

I am trying this, where df is a DataFrame built from the following JSON (df = dfFromJson):
{"class":"name 1","stream":"science"}
{"class":"name 1","stream":"arts"}
{"class":"name 1","stream":"science"}
{"class":"name 1","stream":"law"}
{"class":"name 1","stream":"law"}
{"class":"name 2","stream":"science"}
{"class":"name 2","stream":"arts"}
{"class":"name 2","stream":"law"}
{"class":"name 2","stream":"science"}
{"class":"name 2","stream":"arts"}
{"class":"name 2","stream":"law"}


df.groupBy("class").agg(count(col("stream")==="science") as "stream_science", count(col("stream")==="arts") as "stream_arts", count(col("stream")==="law") as "stream_law")

This is not giving the expected output. How can I achieve it in the fastest way?

It is not exactly clear what the expected output is, but I guess you want something like this:

import org.apache.spark.sql.functions.{count, col, when}
import spark.implicits._ // or sqlContext.implicits._ on Spark 1.x; needed for the $"..." syntax

// Collect the distinct stream values, then build one conditional count per value
val streams = df.select($"stream").distinct.collect.map(_.getString(0))
val exprs = streams.map(s => count(when($"stream" === s, 1)).alias(s"stream_$s"))

df
  .groupBy("class")
  .agg(exprs.head, exprs.tail: _*)

// +------+--------------+----------+-----------+
// | class|stream_science|stream_law|stream_arts|
// +------+--------------+----------+-----------+
// |name 1|             2|         2|          1|
// |name 2|             2|         2|          2|
// +------+--------------+----------+-----------+
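
If you are on Spark 1.6 or later, pivoting on the stream column should give the same shape without collecting the distinct values first. A minimal sketch; note the result columns are named after the raw values, without the stream_ prefix:

// Alternative (Spark >= 1.6): pivot instead of building the count expressions by hand
df
  .groupBy("class")
  .pivot("stream")
  .count()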

If you don't care about the column names and have only one grouping column, you can simply use DataFrameStatFunctions.crosstab:

df.stat.crosstab("class", "stream")

// +------------+---+----+-------+
// |class_stream|law|arts|science|
// +------------+---+----+-------+
// |      name 1|  2|   1|      2|
// |      name 2|  2|   2|      2|
// +------------+---+----+-------+
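
If the combined class_stream header bothers you, it can be renamed afterwards; a small follow-up sketch:

// Optional: give the combined crosstab key column a cleaner name
df.stat.crosstab("class", "stream").withColumnRenamed("class_stream", "class")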

You can just group by both columns instead of grouping by a single column and then filtering. Since I am not fluent enough in Scala, the snippet below is in Python. Note that I have changed your column names from "stream" and "class" to "dept" and "name" to avoid conflicts ("class" is a reserved word in Python, so Row(class=...) would not parse).

from pyspark.sql import HiveContext, Row

hc = HiveContext(sc)

obj = [
    {"class": "name 1", "stream": "science"},
    {"class": "name 1", "stream": "arts"},
    {"class": "name 1", "stream": "science"},
    {"class": "name 1", "stream": "law"},
    {"class": "name 1", "stream": "law"},
    {"class": "name 2", "stream": "science"},
    {"class": "name 2", "stream": "arts"},
    {"class": "name 2", "stream": "law"},
    {"class": "name 2", "stream": "science"},
    {"class": "name 2", "stream": "arts"},
    {"class": "name 2", "stream": "law"}
]

# Rename the fields to dept/name, since "class" is reserved in Python
rdd = sc.parallelize(obj).map(lambda i: Row(dept=i['stream'], name=i['class']))
df = hc.createDataFrame(rdd)

# One count per (dept, name) pair
df.groupby(df.dept, df.name).count().collect()

This results in the following output:

[
    Row(dept='science', name='name 1', count=2), 
    Row(dept='science', name='name 2', count=2), 
    Row(dept='arts', name='name 1', count=1), 
    Row(dept='arts', name='name 2', count=2), 
    Row(dept='law', name='name 1', count=2), 
    Row(dept='law', name='name 2', count=2)
]
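
For reference, the same two-column groupBy in Scala; a minimal sketch, assuming df is the DataFrame from the question with its original class and stream columns:

// Same idea in Scala: group on both columns and count the pairs
df.groupBy("class", "stream").count().show()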
