簡體   English   中英

將 udf 應用於多個列並使用 numpy 操作

[英]apply udf to multiple columns and use numpy operations

我在 pyspark 中有一個名為 dataframe 的結果,我想應用一個 udf 來創建一個新列,如下所示:

result = sqlContext.createDataFrame([(138,5,10), (128,4,10), (112,3,10), (120,3,10), (189,1,10)]).withColumnRenamed("_1","count").withColumnRenamed("_2","df").withColumnRenamed("_3","docs")
@udf("float")
def newFunction(arr):
    return (1 + np.log(arr[0])) * np.log(arr[2]/arr[1])

result=result.withColumn("new_function_result",newFunction_udf(array("count","df","docs")))

列數、df、docs 都是 integer 列。但這會返回

Py4JError:調用 z:org.apache.spark.sql.functions.col 時出錯。 Trace: py4j.Py4JException: Method col([class java.util.ArrayList]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339 ) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run (GatewayConnection.java:214) 在 java.lang.Thread.run(Thread.java:748)

當我嘗試通過一列並獲得其中的正方形時,它工作正常。

任何幫助表示贊賞。

該錯誤消息具有誤導性,但試圖告訴您您的 function 不返回浮點數。 您的 function 返回numpy.float64類型的值,您可以使用 VectorUDT 類型獲取該值(函數:下面示例中的newFunctionVector )。 Another way to make use of numpy is by casting the numpy type numpy.float64 to the python type float (Function: newFunctionWithArray in the example below).

最后但同樣重要的是,沒有必要調用數組,因為 udfs 可以使用多個參數(下例中的函數: newFunction )。

import numpy as np
from pyspark.sql.functions import udf, array
from pyspark.sql.types import FloatType
from pyspark.mllib.linalg import Vectors, VectorUDT

result = sqlContext.createDataFrame([(138,5,10), (128,4,10), (112,3,10), (120,3,10), (189,1,10)], ["count","df","docs"])

def newFunctionVector(arr):
    return (1 + np.log(arr[0])) * np.log(arr[2]/arr[1])

@udf("float")
def newFunctionWithArray(arr):
    returnValue = (1 + np.log(arr[0])) * np.log(arr[2]/arr[1])
    return returnValue.item()

@udf("float")
def newFunction(count, df, docs):
    returnValue = (1 + np.log(count)) * np.log(docs/df)
    return returnValue.item()


vector_udf = udf(newFunctionVector, VectorUDT())

result=result.withColumn("new_function_result", newFunction("count","df","docs"))

result=result.withColumn("new_function_result_WithArray", newFunctionWithArray(array("count","df","docs")))

result=result.withColumn("new_function_result_Vector", newFunctionWithArray(array("count","df","docs")))

result.printSchema()

result.show()

Output:

root 
|-- count: long (nullable = true) 
|-- df: long (nullable = true) 
|-- docs: long (nullable = true) 
|-- new_function_result: float (nullable = true) 
|-- new_function_result_WithArray: float (nullable = true) 
|-- new_function_result_Vector: float (nullable = true)

+-----+---+----+-------------------+-----------------------------+--------------------------+ 
|count| df|docs|new_function_result|new_function_result_WithArray|new_function_result_Vector|
+-----+---+----+-------------------+-----------------------------+--------------------------+ 
|  138|  5|  10|           4.108459|                     4.108459|                  4.108459| 
|  128|  4|  10|           5.362161|                     5.362161|                  5.362161|
|  112|  3|  10|          6.8849173|                    6.8849173|                 6.8849173|
|  120|  3|  10|           6.967983|                     6.967983|                  6.967983|
|  189|  1|  10|          14.372153|                    14.372153|                 14.372153|  
+-----+---+----+-------------------+-----------------------------+--------------------------+

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM