Add new column to dataframe depending on intersection of existing columns with PySpark
I have the following code to build multiple conditional columns in a single dataframe.
small_list = ["INFY", "TCS", "SBIN", "ICICIBANK"]
frame = spark_frame.where(col("symbol") == small_list[0]).select('close')
# spark_frame is a pyspark.sql.dataframe.DataFrame
for single_stock in small_list[1:]:
    print(single_stock)
    current_stock = spark_frame.where(col("symbol") == single_stock).select('close')
    current_stock.collect()
    frame.collect()
    frame = frame.withColumn(single_stock, current_stock.close)
But when I call frame.collect(), I get:
[Row(close=736.85, TCS=736.85, SBIN=736.85, ICICIBANK=736.85),
Row(close=734.7, TCS=734.7, SBIN=734.7, ICICIBANK=734.7),
Row(close=746.0, TCS=746.0, SBIN=746.0, ICICIBANK=746.0),
Row(close=738.85, TCS=738.85, SBIN=738.85, ICICIBANK=738.85)]
This is wrong, because all the values belong to the first symbol. What am I doing wrong, and what is the best way to solve this?
EDIT: spark_frame looks like this:
[Row(SYMBOL='LINC', SERIES=' EQ', TIMESTAMP=datetime.datetime(2021, 12, 20, 0, 0), PREVCLOSE=235.6, OPEN=233.95, HIGH=234.0, LOW=222.15, LAST=222.15, CLOSE=224.2, AVG_PRICE=226.63, TOTTRDQTY=6447, TOTTRDVAL=14.61, TOTALTRADES=206, DELIVQTY=5507, DELIVPER=85.42),
Row(SYMBOL='LINC', SERIES=' EQ', TIMESTAMP=datetime.datetime(2021, 12, 21, 0, 0), PREVCLOSE=224.2, OPEN=243.85, HIGH=243.85, LOW=222.85, LAST=226.0, CLOSE=225.6, AVG_PRICE=227.0, TOTTRDQTY=8447, TOTTRDVAL=19.17, TOTALTRADES=266, DELIVQTY=3401, DELIVPER=40.26),
Row(SYMBOL='SCHAEFFLER', SERIES=' EQ', TIMESTAMP=datetime.datetime(2020, 8, 6, 0, 0), PREVCLOSE=3593.9, OPEN=3611.85, HIGH=3618.35, LOW=3542.5, LAST=3594.95, CLOSE=3573.1, AVG_PRICE=3580.73, TOTTRDQTY=12851, TOTTRDVAL=460.16, TOTALTRADES=1886, DELIVQTY=9649, DELIVPER=75.08),
Row(SYMBOL='SCHAEFFLER', SERIES=' EQ', TIMESTAMP=datetime.datetime(2020, 8, 7, 0, 0), PREVCLOSE=3573.1, OPEN=3591.0, HIGH=3591.0, LOW=3520.0, LAST=3548.95, CLOSE=3543.85, AVG_PRICE=3554.6, TOTTRDQTY=2406, TOTTRDVAL=85.52, TOTALTRADES=688, DELIVQTY=1452, DELIVPER=60.35)]
The expected result should look like this:
[Row(LINC=224.2, SCHAEFFLER=3573.1),
Row(LINC=225.6, SCHAEFFLER=3543.85)]
If I understand you correctly, you want to get the closing prices of the ticker symbols in the small_list list. The simplest way would be:
from pyspark.sql import functions as F

for single_stock in small_list:
    spark_frame = spark_frame.withColumn(
        single_stock,
        F.when(F.col("symbol") == single_stock, spark_frame.close).otherwise(None))
Let me know if I've misunderstood you!
I was finally able to work around the problem by doing the following. I still feel there is probably a more optimal solution, but this works too:
small_list = ["INFY", "TCS", "SBIN", "ICICIBANK"]
frame = spark_frame.filter(col('symbol') == small_list[0]).select('close', 'timestamp')
# frame = frame.withColumnRenamed('close', small_list[0])
for single_stock in small_list[1:]:
    print(single_stock)
    current_stock = spark_frame.filter(col('symbol') == single_stock) \
                               .select(col('close').alias(single_stock), 'timestamp')
    frame = frame.join(current_stock, "timestamp", "inner")
The result looks like:
+-------------------+------+-------+------+---------+
| timestamp| close| TCS| SBIN|ICICIBANK|
+-------------------+------+-------+------+---------+
|2020-01-01 00:00:00|736.85| 2167.6|334.45| 536.75|
|2020-01-02 00:00:00| 734.7|2157.65| 339.3| 540.6|
|2020-01-03 00:00:00| 746.0|2200.65| 333.7| 538.85|
+-------------------+------+-------+------+---------+