
Multiplying rows of SparseVectors in a PySpark SQL DataFrame

I'm having difficulty multiplying the elements of two columns in a SQL DataFrame.

import numpy as np
from pyspark.ml.linalg import Vectors

sv1 = Vectors.sparse(3, [0, 2], [1.0, 3.0])
sv2 = Vectors.sparse(3, [0, 1], [2.0, 4.0])

def xByY(x, y):
  return np.multiply(x, y)

print(xByY(sv1, sv2))

The above works.
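
(np.multiply densifies the SparseVectors, so for these inputs the element-wise product of [1.0, 0.0, 3.0] and [2.0, 4.0, 0.0] should print as something like [2. 0. 0.].)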

But the code below doesn't:

from pyspark.sql.functions import udf

xByY_udf = udf(xByY)

tempDF = sqlContext.createDataFrame([(sv1, sv2), (sv1, sv2)], ('v1', 'v2'))
tempDF.show()

print(tempDF.select(xByY_udf('v1', 'v2')).show())

Many thanks!

If you want your udf to return a SparseVector, we'll first need to modify the output of your function, and second, set the output schema of the udf to VectorUDT():

To declare a SparseVector, we need the size of the vector and both the indices and the values of its non-zero elements. We can find these using len() and list comprehensions, provided the intermediate result of the multiplication is a list:

from pyspark.ml.linalg import Vectors, VectorUDT

def xByY(x, y):
  # Element-wise product as a plain Python list
  res = np.multiply(x, y).tolist()
  # Size of the vector, indices of the non-zero entries, and the non-zero values
  vec_args = len(res), [i for i, x in enumerate(res) if x != 0], [x for x in res if x != 0]
  return Vectors.sparse(*vec_args)
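
Tracing the sample rows through the function: res works out to [2.0, 0.0, 0.0], so vec_args is (3, [0], [2.0]): the size, the non-zero indices, and the non-zero values that Vectors.sparse expects.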

Now we can declare our udf and test it:

xByY_udf = udf(xByY, VectorUDT())
tempDF.select(xByY_udf('v1', 'v2')).show()
+-------------+
| xByY(v1, v2)|
+-------------+
|(3,[0],[2.0])|
|(3,[0],[2.0])|
+-------------+
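
If you'd rather not keep the auto-generated column name xByY(v1, v2), you can alias the udf call ('product' here is just an example name):

tempDF.select(xByY_udf('v1', 'v2').alias('product')).show()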
