For loop Spark dataframe

I have a DataFrame df that has, among other columns, a groupID column; that is, each observation belongs to a specific group. In total there are 8 groups. I would like to sample a certain percentage of observations (say, 20%) from each groupID. Here is my approach:

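val sample_df = for (i <- Array.range(0, 8)) yield {  // groupIDs assumed to be 0 to 7; Array.range(0, 7) stops at 6
  val sel_df = df.filter($"groupID" === i)            // rows of one group
  sel_df.sample(false, 0.2, seed1)                    // ~20% of that group, without replacement
}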

The result of this code is:

Array[org.apache.spark.sql.DataFrame] = Array([text: string, groupID: int], [text: string, groupID: int])

I applied flatMap() to sample_df, but I got an error:

val flat_df = sample_df.flatMap(x => x)
         <console>:59: error: type mismatch;
         found: org.apache.spark.sql.DataFrame
         required: scala.collection.GenTraversableOnce[?]

How can I get a single sampled DataFrame?

As far as I understand, you are trying to get an RDD of Row. For that you can simply call:

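// sample_df is an Array[DataFrame], so convert each element to an RDD[Row] and union them
val rows: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] = sample_df.map(_.rdd).reduce(_ union _)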

To explain the error better: flatMap requires a function that returns something traversable, such as an Option or a Scala collection, but here each element is a DataFrame (as the error message says), and a DataFrame is not a Scala collection.
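To make the flatMap rule concrete, here is a minimal plain-Scala sketch (no Spark involved; the values are made up for illustration). flatMap accepts a function whose result is a collection and concatenates those results; an Option has to be turned into a collection first, which is what toList does here. A DataFrame cannot take either role, which is why the compiler rejects sample_df.flatMap(x => x).

```scala
object FlatMapDemo {
  // Each element is already a collection, so flatMap(x => x) simply flattens it.
  val nested: List[List[Int]] = List(List(1, 2), List(3))
  val flat: List[Int] = nested.flatMap(x => x)

  // An Option becomes a collection via toList: Some(v) yields one element, None yields none.
  val maybe: List[Option[Int]] = List(Some(1), None, Some(3))
  val present: List[Int] = maybe.flatMap(x => x.toList)

  def main(args: Array[String]): Unit = {
    println(flat)    // List(1, 2, 3)
    println(present) // List(1, 3)
  }
}
```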

Also, to collect all the data on the driver, you can call:

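// collect each sampled DataFrame on the driver and concatenate the rows
val rows: Array[org.apache.spark.sql.Row] = sample_df.flatMap(_.collect().toSeq)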

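I guess you want to sample evenly within each group; in that case, union the per-group samples back into a single DataFrame: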

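// fold the per-group samples into one DataFrame (on Spark 2.0+, union does the same as unionAll)
val sampled = sample_df.reduceLeft((result, df) => result.unionAll(df))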
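For intuition, here is a plain-Scala sketch (no Spark) of the per-group approach: each group is filtered and sampled on its own, and reduceLeft folds the pieces into one collection, with ++ standing in for DataFrame union. The Rec case class, the toy data, and the seeded keep-with-probability-0.2 rule are hypothetical stand-ins; this does not reproduce Spark's own sampler.

```scala
import scala.util.Random

object PerGroupSampleDemo {
  // Hypothetical record type standing in for a DataFrame row.
  final case class Rec(text: String, groupID: Int)

  // Toy data: 8 groups (0 to 7), 1000 records each.
  val data: Vector[Rec] = (0 until 8000).toVector.map(i => Rec(s"text$i", i % 8))

  // Keep each record independently with probability `fraction` (Bernoulli sampling).
  def bernoulli(recs: Vector[Rec], fraction: Double, rng: Random): Vector[Rec] =
    recs.filter(_ => rng.nextDouble() < fraction)

  val rng = new Random(42L)

  // One sample per group, like the for/yield loop over groupIDs.
  val perGroup: Vector[Vector[Rec]] =
    (0 until 8).toVector.map(g => bernoulli(data.filter(_.groupID == g), 0.2, rng))

  // Fold the pieces into one collection; ++ plays the role of DataFrame union.
  val combined: Vector[Rec] = perGroup.reduceLeft((result, part) => result ++ part)

  def main(args: Array[String]): Unit =
    println(s"kept ${combined.size} of ${data.size} records")
}
```

The fold is the same shape as the reduceLeft call above: the accumulator starts as the first group's sample and each later sample is appended to it.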

It seems to me you just want to take a 20% sample of the entire DataFrame? If so, there is no reason to create 8 different DataFrames and then union them back together.

df.sample(false, 0.2, seed)

will do the trick. Without replacement, each row is kept independently with probability 0.2, so every group ends up with roughly, but not exactly, 20% of its rows. If you want a different fraction for each groupID, check out df.stat.sampleBy. If you want to be sure that the sample contains exactly 20% of each class, you'll have to convert to a pair RDD and use stratified sampling, like:

df.rdd.map(row => (row.getAs[Int]("groupID"), row))  // Int key, so its type matches the fractions Map
  .sampleByKeyExact(false, Map(0 -> 0.2, 1 -> 0.2, ..., 8 -> 0.2), seed)
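To compare the two options above without a cluster, here is a plain-Scala sketch under stated assumptions: bernoulli mimics a whole-collection sample with fraction 0.2, and exactStratified shuffles each group and keeps ceil(0.2 × group size) records, the per-key size that sampleByKeyExact aims for. It illustrates exact stratified sampling in general and is not Spark's implementation; Rec, the toy data, and both helper names are hypothetical.

```scala
import scala.util.Random

object StratifiedDemo {
  // Hypothetical record type standing in for a DataFrame row.
  final case class Rec(text: String, groupID: Int)

  // Toy data: 8 groups (0 to 7) of different sizes: 100, 200, ..., 800 records.
  val data: Vector[Rec] =
    (0 until 8).toVector.flatMap(g => (0 until (g + 1) * 100).map(i => Rec(s"g$g-$i", g)))

  // Whole-collection Bernoulli sample: each group gets about `fraction` of its rows.
  def bernoulli(recs: Vector[Rec], fraction: Double, seed: Long): Vector[Rec] = {
    val rng = new Random(seed)
    recs.filter(_ => rng.nextDouble() < fraction)
  }

  // Exact stratified sample: shuffle each group, then take ceil(fraction * size) records.
  def exactStratified(recs: Vector[Rec], fraction: Double, seed: Long): Vector[Rec] = {
    val rng = new Random(seed)
    recs.groupBy(_.groupID).toVector.sortBy(_._1).flatMap { case (_, group) =>
      val k = math.ceil(fraction * group.size).toInt
      rng.shuffle(group).take(k)
    }
  }

  // Number of records per groupID.
  def countsByGroup(recs: Vector[Rec]): Map[Int, Int] =
    recs.groupBy(_.groupID).map { case (g, rs) => g -> rs.size }

  def main(args: Array[String]): Unit = {
    println("approximate: " + countsByGroup(bernoulli(data, 0.2, 42L)).toSeq.sorted)
    println("exact:       " + countsByGroup(exactStratified(data, 0.2, 42L)).toSeq.sorted)
  }
}
```

Running main shows the trade-off: the Bernoulli counts scatter around 20% of each group, while the exact version returns the same per-group count on every run. The price of exactness, in Spark as here, is extra work per group.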
