Given two datasets, S and R, both with a time column t, as described below:
//snapshot with id at t
case class S(id: String, t: Int)
//reference data at t
case class R(t: Int, fk: String)
//Example test case
val ss: Dataset[S] = Seq(S("a", 1), S("a", 3), S("b", 5), S("b", 7))
.toDS
val rs: Dataset[R] = Seq(R(0, "a"), R(2, "a"), R(6, "b"))
.toDS
val srs: Dataset[(S, Option[R])] = ss
.asOfJoin(rs)
srs.collect() must contain theSameElementsAs
Seq((S("a", 1), Some(R(0, "a"))), (S("a", 3), Some(R(2, "a"))), (S("b", 5), None), (S("b", 7), Some(R(6, "b"))))
The goal is to find the most recent row in R whose fk matches S's id, if one exists, i.e. R is optional in the output.
asOfJoin is defined as follows:
implicit class SOps(ss: Dataset[S]) {
def asOfJoin(rs: Dataset[R])(implicit spark: SparkSession): Dataset[(S, Option[R])] = ???
}
One solution using the Dataset API is as follows:
def asOfJoin(rs: Dataset[R])(implicit spark: SparkSession): Dataset[(S, Option[R])] = {
  import spark.implicits._
  ss
    // theta join: keep only reference rows at or before the snapshot time
    .joinWith(
      rs,
      ss("id") === rs("fk") && ss("t") >= rs("t"),
      "left_outer")
    .map { case (l, r) => (l, Option(r)) }
    // one group per snapshot row
    .groupByKey { case (s, _) => s }
    // keep the match with the largest R.t; a group holds either all
    // Some matches or a single (s, None) row from the outer join
    .reduceGroups { (x, y) =>
      (x, y) match {
        case ((_, Some(R(tx, _))), (_, Some(R(ty, _)))) => if (tx > ty) x else y
        case _ => x
      }
    }
    .map { case (_, r) => r }
}
I'm not sure about the sizes of dataset S and dataset R, but from your code I can see that the join with a non-equi condition will perform badly. I can give some suggestions based on different scenarios:
Either dataset R or dataset S does not contain much data.
I suggest broadcasting the smaller dataset and implementing the business logic in a Spark UDF with the help of a broadcast variable. This way you avoid the shuffle of the join entirely, which saves a lot of time and resources.
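For example, here is a minimal sketch of that idea, assuming rs is the smaller side and fits in driver memory (it resolves each row with a typed map over the broadcast value rather than an untyped UDF, but the effect is the same; the name asOfJoinBroadcast is illustrative):

def asOfJoinBroadcast(ss: Dataset[S], rs: Dataset[R])(implicit spark: SparkSession): Dataset[(S, Option[R])] = {
  import spark.implicits._
  // collect R on the driver, grouped by fk and sorted by t descending,
  // so the first entry with r.t <= s.t is the most recent match
  val byFk: Map[String, Array[R]] =
    rs.collect().groupBy(_.fk).map { case (fk, xs) => fk -> xs.sortBy(-_.t) }
  val bc = spark.sparkContext.broadcast(byFk)
  // no shuffle: each snapshot row is resolved locally against the broadcast map
  ss.map(s => (s, bc.value.getOrElse(s.id, Array.empty[R]).find(_.t <= s.t)))
}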
For every unique id, count(distinct t) is small.
I suggest doing a pre-aggregation, grouping by id and collecting the t values, like this:
select id, collect_set(t) as t_set from S group by id
(and similarly for R, grouped by fk). This removes the non-equi expression (ss("t") >= rs("t")) from the join, and you can then write the business logic against the two t sets from dataset S and dataset R.
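A rough sketch of that idea in the Dataset API, assuming the per-id groups stay small (the names sAgg/rAgg/asOfJoinPreAgg are illustrative; note it reconstructs R from its two fields, which only works because R carries no other columns):

def asOfJoinPreAgg(ss: Dataset[S], rs: Dataset[R])(implicit spark: SparkSession): Dataset[(S, Option[R])] = {
  import spark.implicits._
  // pre-aggregate both sides: one row per id/fk carrying all its t values
  val sAgg = ss.groupByKey(_.id).mapGroups((id, xs) => (id, xs.map(_.t).toSeq))
  val rAgg = rs.groupByKey(_.fk).mapGroups((fk, xs) => (fk, xs.map(_.t).toSeq.sorted.reverse))
  sAgg
    .joinWith(rAgg, sAgg("_1") === rAgg("_1"), "left_outer")  // plain equi-join
    .flatMap { case ((id, sts), matched) =>
      val rts = Option(matched).map(_._2).getOrElse(Seq.empty)
      // rts is sorted descending, so find returns the most recent rt <= st
      sts.map(st => (S(id, st), rts.find(_ <= st).map(rt => R(rt, id))))
    }
}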
For other scenarios:
I suggest optimizing your code with an equi-join and a window function. Since I'm more familiar with SQL, I'll write SQL here, which can be transformed to the Dataset API:
select
  sid,
  st,
  rt
from
  (
    select
      S.id as sid,
      S.t as st,
      case when R.t <= S.t then R.t end as rt,
      row_number() over (
        partition by S.id, S.t
        order by case when R.t <= S.t then R.t end desc nulls last
      ) as rn
    from
      S
      left join R on S.id = R.fk
  ) tbl
where tbl.rn = 1
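For reference, one possible translation of that SQL into the DataFrame API might look like this (an untested sketch; the column renames just mirror the SQL aliases):

// assumes an implicit SparkSession in scope and import spark.implicits._ for $
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{row_number, when}

val joined = ss.toDF("id", "st")
  .join(rs.toDF("rt", "fk"), $"id" === $"fk", "left_outer")
  // null out reference rows that are newer than the snapshot
  .withColumn("rt", when($"rt" <= $"st", $"rt"))
val w = Window.partitionBy($"id", $"st").orderBy($"rt".desc_nulls_last)
val result = joined
  .withColumn("rn", row_number().over(w))
  .where($"rn" === 1)
  .select($"id", $"st", $"rt")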
I took @bupt_ljy's comment about avoiding a theta join, and the following seems to scale really well:
def asOfJoin(rs: Dataset[R])(implicit spark: SparkSession): Dataset[(S, Option[R])] = {
  import spark.implicits._
  ss
    .joinWith(
      // sort descending by t so the first reference row with r.t <= s.t
      // is the most recent one; note this relies on the equi-join
      // preserving that order within each key
      rs.sort(rs("fk"), rs("t").desc),
      ss("id") === rs("fk"),
      "left_outer")
    .map { case (l, r) => (l, Option(r)) }
    .groupByKey { case (s, _) => s }
    // emit exactly one row per snapshot: the first (i.e. most recent)
    // match at or before s.t, or (s, None) if there is none
    .flatMapGroups { (k, vs) =>
      Iterator.single(
        vs.find {
          case (l, Some(r)) => l.t >= r.t
          case _            => false
        }.getOrElse((k, None)))
    }
}
However, this is still very imperative code, and there must be a better way...