Pyspark UDF to compare Sparse Vectors

I am trying to write a pyspark UDF that will compare two SparseVectors for me. What I would like to write is:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, FloatType

def compare(req_values, values):
    return [req for req in req_values.indices if req not in values.indices]

compare_udf = udf(compare, ArrayType(IntegerType()))

display(data.limit(5).select('*', compare_udf('req_values', 'values').alias('missing')))

However, when I run this code I get the following error message:

SparkException: Job aborted due to stage failure: Task 0 in stage 129.0 failed 4 times, most recent failure: Lost task 0.3 in stage 129.0 (TID 1256, 10.139.64.15, executor 2): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)

I have run into similar problems before, relating to the DataFrame type being unable to handle numpy data types. Previously I was able to solve these issues by coercing the numpy array into a list before returning it, but in this case it seems that I cannot even pull the data out of the SparseVector. For example, even the following does not work:

def compare(req_values, values):
    return req_values.indices[0]   

compare_udf = udf(compare, IntegerType())

I have been able to circumvent the issue using an RDD, but I still find this a frustrating limitation of pyspark UDFs. Any advice or help appreciated!
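For reference, here is a minimal sketch of the RDD workaround (the `data` DataFrame and column names follow the example above; the exact approach used is not shown in the question, so this is one plausible version):

# One plausible version of the RDD workaround: mapping over the rows
# sidesteps the UDF return-type serialization entirely, and the numpy
# integers are cast to plain Python ints before being collected.
missing = data.rdd.map(
    lambda row: [int(i) for i in row['req_values'].indices
                 if i not in row['values'].indices]
)
print(missing.take(5))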

I seem to have solved this problem myself - the issue comes down to the fact that the constituent components of the mllib SparseVector are numpy types, which are themselves not supported by the pyspark DataFrame. The following adjusted code works:

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, FloatType

def compare(req_values, values):
    # Cast each numpy integer to a plain Python int so the UDF's return
    # value can be serialized back into the DataFrame.
    return [int(req) for req in req_values.indices if req not in values.indices]

compare_udf = udf(compare, ArrayType(IntegerType()))

display(data.limit(5).select('*', compare_udf('req_values', 'values').alias('missing')))

While this works, it seems somewhat counter-intuitive to me that the pyspark DataFrame supports a constructed datatype (SparseVector) but not its constituent parts (numpy integers), and does not provide a more enlightening error message explaining the problem.
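The root cause is easy to see in isolation. A quick illustration (using pyspark.mllib.linalg, since the vectors above come from mllib):

from pyspark.mllib.linalg import SparseVector

sv = SparseVector(5, [1, 3], [1.0, 1.0])
print(type(sv.indices))          # <class 'numpy.ndarray'>
print(type(sv.indices[0]))       # <class 'numpy.int32'>
print(type(int(sv.indices[0])))  # <class 'int'> - what the UDF must return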
