
DataFrame equality in Apache Spark

Assume df1 and df2 are two DataFrames in Apache Spark, computed using two different mechanisms, e.g., Spark SQL vs. the Scala/Java/Python API.

Is there an idiomatic way to determine whether the two data frames are equivalent (equal, isomorphic), where equivalence is determined by the data (column names and column values for each row) being identical save for the ordering of rows and columns?

The motivation for the question is that there are often many ways to compute some big data result, each with its own trade-offs. As one explores these trade-offs, it is important to maintain correctness, and hence the need to check for equivalence/equality on a meaningful test data set.

Scala (see below for PySpark)

The spark-fast-tests library has two methods for making DataFrame comparisons (I'm the creator of the library):

The assertSmallDataFrameEquality method collects DataFrames on the driver node and makes the comparison:

def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  // Compare schemas first so that a schema mismatch produces a clear error
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  // Collect both DataFrames to the driver and compare the rows element-wise
  if (!actualDF.collect().sameElements(expectedDF.collect())) {
    throw new DataFrameContentMismatch(contentMismatchMessage(actualDF, expectedDF))
  }
}

The assertLargeDataFrameEquality method compares DataFrames spread across multiple machines (the code is basically copied from spark-testing-base):

def assertLargeDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  try {
    actualDF.rdd.cache
    expectedDF.rdd.cache

    val actualCount = actualDF.rdd.count
    val expectedCount = expectedDF.rdd.count
    if (actualCount != expectedCount) {
      throw new DataFrameContentMismatch(countMismatchMessage(actualCount, expectedCount))
    }

    val expectedIndexValue = zipWithIndex(actualDF.rdd)
    val resultIndexValue = zipWithIndex(expectedDF.rdd)

    val unequalRDD = expectedIndexValue
      .join(resultIndexValue)
      .filter {
        case (idx, (r1, r2)) =>
          !(r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, 0.0))
      }

    val maxUnequalRowsToShow = 10
    assertEmpty(unequalRDD.take(maxUnequalRowsToShow))

  } finally {
    actualDF.rdd.unpersist()
    expectedDF.rdd.unpersist()
  }
}

assertSmallDataFrameEquality is faster for small DataFrame comparisons and I've found it sufficient for my test suites.

PySpark

Here's a simple function that returns true if the DataFrames are equal:

def are_dfs_equal(df1, df2):
    if df1.schema != df2.schema:
        return False
    if df1.collect() != df2.collect():
        return False
    return True

or, simplified:

def are_dfs_equal(df1, df2): 
    return (df1.schema == df2.schema) and (df1.collect() == df2.collect())

You'll typically perform DataFrame equality comparisons in a test suite and will want a descriptive error message when the comparisons fail (a True / False return value doesn't help much when debugging).

Use the chispa library to access the assert_df_equality method that returns descriptive error messages for test suite workflows.
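
For example, here is a minimal test sketch (the test function name and the spark fixture are illustrative; the ignore_row_order flag is assumed to be available in your chispa version, so check its documentation):

from chispa import assert_df_equality

def test_my_transformation(spark):
    expected = spark.createDataFrame([("nick", 30), ("bob", 40)], ["name", "age"])
    actual = spark.createDataFrame([("bob", 40), ("nick", 30)], ["name", "age"])
    # Raises an AssertionError with a readable row-by-row diff when the DataFrames differ
    assert_df_equality(actual, expected, ignore_row_order=True)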

There are some standard ways in the Apache Spark test suites; however, most of these involve collecting the data locally, and if you want to do equality testing on large DataFrames then that is likely not a suitable solution.

Check the schema first, then intersect the two DataFrames into df3 and verify that the counts of df1, df2, and df3 are all equal (note that this only works if there are no duplicate rows; if the DataFrames contain different duplicate rows, this method could still return true).
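
A minimal PySpark sketch of that idea (the function name is illustrative; it assumes both DataFrames share the same schema, including column order, and contain no duplicate rows):

def are_dfs_equal_via_intersect(df1, df2):
    if df1.schema != df2.schema:
        return False
    # intersect() deduplicates, so matching counts only prove equality when there are no duplicates
    df3 = df1.intersect(df2)
    return df1.count() == df2.count() == df3.count()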

Another option would be to get the underlying RDDs of both DataFrames, map each row to (Row, 1), do a reduceByKey to count the occurrences of each Row, cogroup the two resulting RDDs, and then do a regular aggregate, returning false if any of the iterators are not equal.
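
A rough PySpark sketch of this row-counting approach (the function name is illustrative; it assumes rows are hashable, i.e. they contain no map or array columns):

def are_dfs_equal_via_rdd(df1, df2):
    if df1.schema != df2.schema:
        return False
    # Count how many times each distinct row occurs on each side
    counts1 = df1.rdd.map(lambda row: (tuple(row), 1)).reduceByKey(lambda a, b: a + b)
    counts2 = df2.rdd.map(lambda row: (tuple(row), 1)).reduceByKey(lambda a, b: a + b)
    # After cogrouping, every row must have identical counts on both sides
    mismatches = counts1.cogroup(counts2).filter(
        lambda kv: list(kv[1][0]) != list(kv[1][1])
    )
    return mismatches.isEmpty()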

I don't know about idiomatic, but I think you can get a robust way to compare DataFrames as you describe, as follows. (I'm using PySpark for illustration, but the approach carries across languages.)

a = spark.range(5)
b = spark.range(5)

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0

This approach correctly handles cases where the DataFrames may have duplicate rows, rows in different orders, and/or columns in different orders.

For example:

a = spark.createDataFrame([('nick', 30), ('bob', 40)], ['name', 'age'])
b = spark.createDataFrame([(40, 'bob'), (30, 'nick')], ['age', 'name'])
c = spark.createDataFrame([('nick', 30), ('bob', 40), ('nick', 30)], ['name', 'age'])

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()
c_prime = c.groupBy(sorted(c.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
assert a_prime.subtract(c_prime).count() != 0

This approach is quite expensive, but most of the expense is unavoidable given the need to perform a full diff. And this should scale fine, as it doesn't require collecting anything locally. If you relax the constraint that the comparison should account for duplicate rows, then you can drop the groupBy() and just do the subtract(), which would probably speed things up notably, as in the sketch below.
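
A sketch of that relaxed check, reusing a and b from the example above (it assumes both DataFrames have the same column order and ignores how many times each row appears):

# Only checks that the sets of distinct rows match, ignoring duplicate-row multiplicity
assert a.subtract(b).count() == b.subtract(a).count() == 0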

Java:

assert resultDs.union(answerDs).distinct().count() == resultDs.intersect(answerDs).count();

Try doing the following:

df1.except(df2).isEmpty

A scalable and easy way is to diff the two DataFrames and count the non-matching rows:

df1.diff(df2).where($"diff" =!= "N").count

If that number is not zero, then the two DataFrames are not equivalent.

The diff transformation is provided by spark-extension.

It identifies Inserted (I), Changed (C), Deleted (D) and uNchanged (N) rows.

There are 4 options, depending on whether you have duplicate rows or not.

Let's say we have two DataFrames, z1 and z2. Options 1/2 are good for rows without duplicates. You can try these in spark-shell.

  • Option 1: use except directly
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column

def isEqual(left: DataFrame, right: DataFrame): Boolean = {
   if (left.columns.length != right.columns.length) return false // column counts don't match
   if (left.count != right.count) return false // row counts don't match
   return left.except(right).isEmpty && right.except(left).isEmpty
}
  • Option 2: generate a row hash from the columns
def createHashColumn(df: DataFrame) : Column = {
   val colArr = df.columns
   md5(concat_ws("", (colArr.map(col(_))) : _*))
}

val z1SigDF = z1.select(col("index"), createHashColumn(z1).as("signature_z1"))
val z2SigDF = z2.select(col("index"), createHashColumn(z2).as("signature_z2"))
val joinDF = z1SigDF.join(z2SigDF, z1SigDF("index") === z2SigDF("index")).where($"signature_z1" =!= $"signature_z2").cache
// should be 0
joinDF.count
  • Option 3: use GroupBy (for DataFrames with duplicate rows)
val z1Grouped = z1.groupBy(z1.columns.map(c => z1(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")
val z2Grouped = z2.groupBy(z2.columns.map(c => z2(c)).toSeq : _*).count().withColumnRenamed("count", "recordRepeatCount")

val inZ1NotInZ2 = z1Grouped.except(z2Grouped).toDF()
val inZ2NotInZ1 = z2Grouped.except(z1Grouped).toDF()
// both should be size 0
inZ1NotInZ2.show
inZ2NotInZ1.show
  • Option 4: use exceptAll, which should also work for data with duplicate rows
// Source Code: https://github.com/apache/spark/blob/50538600ec/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L2029
val inZ1NotInZ2 = z1.exceptAll(z2).toDF()
val inZ2NotInZ1 = z2.exceptAll(z1).toDF()
// same here, both should be size 0
inZ1NotInZ2.show
inZ2NotInZ1.show

You can do this using a little bit of deduplication in combination with a full outer join. The advantage of this approach is that it does not require you to collect results to the driver, and that it avoids running multiple jobs.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._

// Generate some random data.
def random(n: Int, s: Long) = {
  spark.range(n).select(
    (rand(s) * 10000).cast("int").as("a"),
    (rand(s + 5) * 1000).cast("int").as("b"))
}
val df1 = random(10000000, 34)
val df2 = random(10000000, 17)

// Move all the keys into a struct (to make handling nulls easy), deduplicate the given dataset
// and count the rows per key.
def dedup(df: Dataset[Row]): Dataset[Row] = {
  df.select(struct(df.columns.map(col): _*).as("key"))
    .groupBy($"key")
    .agg(count(lit(1)).as("row_count"))
}

// Deduplicate the inputs and join them using a full outer join. The result can contain
// the following things:
// 1. Both keys are not null (and thus equal), and the row counts are the same. The dataset
//    is the same for the given key.
// 2. Both keys are not null (and thus equal), and the row counts are not the same. The dataset
//    contains the same keys.
// 3. Only the right key is not null.
// 4. Only the left key is not null.
val joined = dedup(df1).as("l").join(dedup(df2).as("r"), $"l.key" === $"r.key", "full")

// Summarize the differences.
val summary = joined.select(
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" === $"l.row_count", 1)).as("left_right_same_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNotNull && $"r.row_count" =!= $"l.row_count", 1)).as("left_right_different_rc"),
  count(when($"l.key".isNotNull && $"r.key".isNull, 1)).as("left_only"),
  count(when($"l.key".isNull && $"r.key".isNotNull, 1)).as("right_only"))
summary.show()

try {
  return ds1.union(ds2)
          .groupBy(columns(ds1, ds1.columns()))
          .count()
          .filter("count % 2 > 0")
          .count()
      == 0;
} catch (Exception e) {
  return false;
}

Column[] columns(Dataset<Row> ds, String... columnNames) {
  List<Column> l = new ArrayList<>();
  for (String cn : columnNames) {
    l.add(ds.col(cn));
  }
  return l.stream().toArray(Column[]::new);
}

The columns method is supplementary and can be replaced by any method that returns a Seq of Columns.

Logic:

  1. Union both datasets; if the columns don't match, it will throw an exception and hence return false.
  2. If the columns match, then groupBy on all columns and add a count column. Now, all rows have a count that is a multiple of 2 (even for duplicate rows).
  3. Check whether any row has a count not divisible by 2; those are the extra rows.
