[英]Multiplying rows of Sparse vectors in pyspark SQL DataFrame
我在將SQL數據框中的列元素相乘時遇到困難。
sv1 = Vectors.sparse(3, [0, 2], [1.0, 3.0])
sv2 = Vectors.sparse(3, [0, 1], [2.0, 4.0])
def xByY(x,y):
return np.multiply(x,y)
print(xByY(sv1, sv2))
以上作品。
但是下面沒有。
xByY_udf = udf(xByY)
tempDF = sqlContext.createDataFrame([(sv1, sv2), (sv1, sv2)], ('v1', 'v2'))
tempDF.show()
print(tempDF.select(xByY_udf('v1', 'v2')).show())
非常感謝!
如果你希望你的udf
返回一個SparseVector
,我們首先需要修改你的函數的輸出,其次設定的輸出模式udf
到VectorUDT()
要聲明SparseVector
,我們需要原始數組的大小 ,以及索引和非零元素的值 。 如果乘法的中間結果是一個list
我們可以使用len()
和list comprehensions找到它們:
from pyspark.ml.linalg import Vectors, VectorUDT
def xByY(x,y):
res = np.multiply(x,y).tolist()
vec_args = len(res), [i for i,x in enumerate(res) if x != 0], [x for x in res if x != 0]
return Vectors.sparse(*vec_args)
現在我們可以聲明我們的udf
並對其進行測試:
xByY_udf = udf(xByY, VectorUDT())
tempDF.select(xByY_udf('v1', 'v2')).show()
+-------------+
| xByY(v1, v2)|
+-------------+
|(3,[0],[2.0])|
|(3,[0],[2.0])|
+-------------+
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.