簡體   English   中英

如何有效地 select dataframe 列在 Spark 中包含某個值?

[英]How to efficiently select dataframe columns containing a certain value in Spark?

假設您在 spark(字符串類型)中有一個 dataframe,並且您想要刪除任何包含“foo”的列。 在下面的示例 dataframe 中,您將刪除列“c2”和“c3”,但保留“c1”。 但是,我希望將解決方案推廣到大量列和行。

    +-------------------+
    |   c1|   c2|     c3|
    +-------------------+
    | this|  foo|  hello|
    | that|  bar|  world|
    |other|  baz| foobar|
    +-------------------+

我的解決方案是掃描 dataframe 中的每一列,然后使用 dataframe API 和內置函數聚合結果。 因此,可以像這樣掃描每一列(我是 scala 的新手,請原諒語法錯誤):

df = df.select(df.columns.map(c => col(c).like("foo"))

從邏輯上講,我會有一個中間 dataframe 像這樣:

    +--------------------+
    |    c1|    c2|    c3|
    +--------------------+
    | false|  true| false|
    | false| false| false|
    | false| false|  true|
    +--------------------+

然后將其聚合為一行以讀取需要刪除的列。

exprs = df.columns.map( c => max(c).alias(c))

drop = df.agg(exprs.head, exprs.tail: _*)

    +--------------------+
    |    c1|    c2|    c3|
    +--------------------+
    | false|  true|  true|
    +--------------------+

現在可以刪除任何包含 true 的列。

我的問題是:有沒有更好的方法來做到這一點,性能明智? 在這種情況下,一旦找到“foo”,spark 是否會停止掃描列? 數據的存儲方式是否重要(鑲木地板有幫助嗎?)。

謝謝,我是新來的,所以請告訴我如何改進這個問題。

根據您的數據,例如,如果您有很多foo值,下面的代碼可能會更有效地執行:

val colsToDrop = df.columns.filter{ c =>
  !df.where(col(c).like("foo")).limit(1).isEmpty
}

df.drop(colsToDrop: _*)

更新:刪除了多余.limit(1)

val colsToDrop = df.columns.filter{ c =>
  !df.where(col(c).like("foo")).isEmpty
}

df.drop(colsToDrop: _*)

遵循您的邏輯的答案(正確計算),但我認為另一個答案更好,對於后代和您通過 Scala 提高的能力更是如此。 我不確定另一個答案實際上是否有效,但這也不是。 不確定鑲木地板是否有幫助,很難衡量。

另一種選擇是在驅動程序上編寫一個循環並訪問每一列,然后由於柱狀、統計和下推,鑲木地板將被使用。

import org.apache.spark.sql.functions._
def myUDF = udf((cols: Seq[String], cmp: String) => cols.map(code => if (code == cmp) true else false ))

val df = sc.parallelize(Seq(
   ("foo", "abc", "sss"),
   ("bar", "fff", "sss"),
   ("foo", "foo", "ddd"),
   ("bar", "ddd", "ddd")
   )).toDF("a", "b", "c")

val res = df.select($"*", array(df.columns.map(col): _*).as("colN"))
            .withColumn( "colres", myUDF( col("colN") , lit("foo") )  )

res.show()
res.printSchema()
val n = 3
val res2 = res.select( (0 until n).map(i => col("colres")(i).alias(s"c${i+1}")): _*)
res2.show(false)

val exprs = res2.columns.map( c => max(c).alias(c))
val drop = res2.agg(exprs.head, exprs.tail: _*)
drop.show(false)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM