How to create a new column in PySpark
I have a Spark DataFrame as below:
| a1 | a2 | a3 | a4 |
|---|---|---|---|
| A | 12 | 9 | 1 |
| B | 14 | 13 | 1 |
| C | 7 | 3 | 0 |
I want to create a new column a5 with conditions such as:
- if a1 = 'A' then a5 = 'Car'
- if a2 > 0 then a5 = 'Bus'
- if a3 > 0 and a4 = 1 then a5 = 'Bike'
The desired output should be as below:
| a1 | a2 | a3 | a4 | a5 |
|---|---|---|---|---|
| A | 12 | 9 | 1 | Car |
| A | 12 | 9 | 1 | Bus |
| A | 12 | 9 | 1 | Bike |
| B | 14 | 13 | 1 | Bus |
| B | 14 | 13 | 1 | Bike |
| C | 7 | 3 | 0 | Bus |
Please help on how to add this new column. Thank you in advance.
You can define your own function:
```python
from pyspark.sql import functions as f
from pyspark.sql.types import StringType

def myfunction(a1, a2, a3, a4):
    if a1 == "A":
        return "Car"
    elif a2 > 0:
        return "Bus"
    elif a3 > 0 and a4 == 1:
        return "Bike"

# udf() wraps the function itself; the wrapped UDF is then called on the columns
df = df.withColumn("a5", f.udf(myfunction, StringType())(df.a1, df.a2, df.a3, df.a4))
```
```python
df = spark.createDataFrame(
    [
        ('A', '12', '9', '1'),
        ('B', '14', '13', '1'),
        ('C', '7', '3', '0')
    ],
    ['a1', 'a2', 'a3', 'a4']
)
```
```python
from pyspark.sql.functions import when, lit, col

res = df.withColumn("a5", when(df.a1 == 'A', lit('Car')))\
    .unionByName(df.withColumn("a5", when(df.a2 > 0, lit('Bus'))))\
    .unionByName(df.withColumn("a5", when((df.a3 > 0) & (df.a4 == 1), lit('Bike'))))\
    .filter(col('a5').isNotNull())\
    .orderBy(col('a1').asc())

res.show()
# +---+---+---+---+----+
# | a1| a2| a3| a4| a5|
# +---+---+---+---+----+
# | A| 12| 9| 1| Bus|
# | A| 12| 9| 1| Car|
# | A| 12| 9| 1|Bike|
# | B| 14| 13| 1| Bus|
# | B| 14| 13| 1|Bike|
# | C| 7| 3| 0| Bus|
# +---+---+---+---+----+
```
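This builds a5 by scanning df once per condition and stacking the matches with unionByName, so a row appears once for every condition it satisfies. Since orderBy only sorts by a1, the order of labels within each a1 group is arbitrary (Bus before Car for A above).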
The (probably) best approach performance-wise is to create temporary columns according to your logic, add them to an array, then explode the array to get more rows.
```python
from pyspark.sql import functions as F

(df
    .withColumn('a5_1', F.when(F.col('a1') == 'A', 'Car'))
    .withColumn('a5_2', F.when(F.col('a2') > 0, 'Bus'))
    .withColumn('a5_3', F.when((F.col('a3') > 0) & (F.col('a4') == 1), 'Bike'))
    .withColumn('a5', F.array_except(F.array('a5_1', 'a5_2', 'a5_3'), F.array(F.lit(None))))
    .select(df.columns + [F.explode('a5').alias('a5')])
    .show()
)
```
```
+---+---+---+---+----+
| a1| a2| a3| a4| a5|
+---+---+---+---+----+
| A| 12| 9| 1| Car|
| A| 12| 9| 1| Bus|
| A| 12| 9| 1|Bike|
| B| 14| 13| 1| Bus|
| B| 14| 13| 1|Bike|
| C| 7| 3| 0| Bus|
+---+---+---+---+----+
```
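One caveat: with this sample data every row satisfies at least one condition, but explode drops rows whose array is empty, so a row matching none of the conditions would disappear from the result. If such rows should be kept with a null a5 instead, explode_outer is a drop-in replacement; a sketch with the same temporary columns:

```python
from pyspark.sql import functions as F

# explode_outer keeps rows whose a5 array is empty,
# emitting null for a5 instead of dropping the row
(df
    .withColumn('a5_1', F.when(F.col('a1') == 'A', 'Car'))
    .withColumn('a5_2', F.when(F.col('a2') > 0, 'Bus'))
    .withColumn('a5_3', F.when((F.col('a3') > 0) & (F.col('a4') == 1), 'Bike'))
    .withColumn('a5', F.array_except(F.array('a5_1', 'a5_2', 'a5_3'), F.array(F.lit(None))))
    .select(df.columns + [F.explode_outer('a5').alias('a5')])
    .show()
)
```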