
Convert spark DataFrame column to python list

I work on a dataframe with two columns, mvv and count.

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |
+---+-----+

I would like to obtain two lists containing the mvv values and the count values. Something like:

mvv = [1,2,3,4]
count = [5,9,3,1]

So, I tried the following code: the first line should return a Python list of Rows. I wanted to see the first value:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

But I get an error message on the second line:

AttributeError: getInt

Let's see why this approach is not working. First, you are trying to get an integer from a Row type; the output of your collect looks like this:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

If you do something like this:

>>> firstvalue = mvv_list[0].mvv
Out: 1

you will get the mvv value. If you want all the values of the column, you can do something like this:

>>> mvv_array = [int(row.mvv) for row in mvv_count_df.collect()]
>>> mvv_array
Out: [1,2,3,4]

But if you try the same for the other column, you get:

>>> mvv_count = [int(row.count) for row in mvv_count_df.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

This happens because count is a built-in method of Row, and the column has the same name. A workaround is to rename the count column to _count:

>>> renamed_df = mvv_count_df.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in renamed_df.collect()]

But this workaround is not needed, as you can access the column using dictionary syntax:

>>> mvv_array = [int(row['mvv']) for row in mvv_count_df.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_count_df.collect()]

And it will finally work!
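
As a minimal sketch (assuming the same mvv_count_df as above), both lists can also be built from a single collect using the dictionary syntax:

# Collect both columns once, then split the Rows into two Python lists
rows = mvv_count_df.select('mvv', 'count').collect()
mvv = [int(r['mvv']) for r in rows]      # dictionary access avoids the Row.count method clash
count = [int(r['count']) for r in rows]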

The following one-liner gives the list you want:

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

This will give you all the elements as a list.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

The following code will help you:

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

I ran a benchmarking analysis and list(mvv_count_df.select('mvv').toPandas()['mvv']) is the fastest method. I'm very surprised.

I ran the different approaches on 100 thousand / 100 million row datasets using a 5 node i3.xlarge cluster (each node has 30.5 GBs of RAM and 4 cores) with Spark 2.4.5. Data was evenly distributed across 20 snappy compressed Parquet files with a single column.

Here are the benchmarking results (runtimes in seconds):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in df.select('col_name').toLocalIterator()]     |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Golden rules to follow when collecting data on the driver node:

  • Try to solve the problem with other approaches first. Collecting data to the driver node is expensive, doesn't harness the power of the Spark cluster, and should be avoided whenever possible.
  • Collect as few rows as possible. Aggregate, deduplicate, filter, and prune columns before collecting the data, and send as little data to the driver node as you can (see the sketch after this list).
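
For example, a hypothetical sketch (the column names come from the question; the filter threshold is illustrative) that reduces the data on the cluster before collecting:

# Illustrative only: aggregate/filter/prune on the cluster, then collect the small result
small_df = (
    mvv_count_df
    .filter('count > 1')   # filter early
    .select('mvv')         # prune columns
    .distinct()            # deduplicate
)
mvv = [row['mvv'] for row in small_df.collect()]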

toPandas was significantly improved in Spark 2.3. It's probably not the best approach if you're using a Spark version earlier than 2.3.
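
Much of that improvement comes from Arrow-based conversion, which may need to be turned on explicitly. A hedged sketch (the exact config key depends on your Spark version, and pyarrow must be installed on the driver):

# Spark 2.3/2.4 key; Spark 3.x uses spark.sql.execution.arrow.pyspark.enabled instead
spark.conf.set('spark.sql.execution.arrow.enabled', 'true')

mvv_list = list(mvv_count_df.select('mvv').toPandas()['mvv'])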

See here for more details / benchmarking results.

On my data I got these benchmarks:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0.52 sec

>>> [row[col] for row in data.collect()]

0.271 sec

>>> list(data.select(col).toPandas()[col])

0.427 sec

The result is the same.

If you get the error below:

AttributeError: 'list' object has no attribute 'collect'

This code will solve your issue:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]

A possible solution is using the collect_list() function from pyspark.sql.functions. This will aggregate all column values into a pyspark array that is converted into a python list when collected:

from pyspark.sql.functions import collect_list

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0]
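
If you need both lists, a possible variation (a sketch along the same lines, with the same df and column names assumed) collects them in a single job:

# Both arrays in one pass; first() returns a single Row holding two Python lists
row = df.select(collect_list("mvv"), collect_list("count")).first()
mvv_list, count_list = row[0], row[1]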

Let's create the dataframe in question:

df_test = spark.createDataFrame(
    [
        (1, 5),
        (2, 9),
        (3, 3),
        (4, 1),
    ],
    ['mvv', 'count']
)
df_test.show()

Which gives

+---+-----+
|mvv|count|
+---+-----+
|  1|    5|
|  2|    9|
|  3|    3|
|  4|    1|
+---+-----+

and then apply rdd.flatMap(f).collect() to get the list:

test_list = df_test.select("mvv").rdd.flatMap(list).collect()
print(type(test_list))
print(test_list)

which gives

<type 'list'>
[1, 2, 3, 4]

Despite the many answers, some of them won't work when you need the list to be used in combination with the when and isin commands. The simplest yet effective approach resulting in a flat list of values is to use a list comprehension with [0] to avoid Row names:

flatten_list_from_spark_df=[i[0] for i in df.select("your column").collect()]

The other approach is to use a pandas data frame and then use the list function, but it is not as convenient and effective as this one.
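
For instance, a hedged sketch of combining such a flattened list with isin (other_df and the column names are hypothetical):

from pyspark.sql.functions import col

# Hypothetical: keep only the rows of other_df whose mvv appears in the flattened list
allowed = [i[0] for i in df.select("mvv").collect()]
filtered_df = other_df.filter(col("mvv").isin(allowed))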
