简体   繁体   中英

Pyspark: sum error with SparseVector

Suppose I have a SparseVector and I want to sum its values, e.g.:

# NOTE(review): assumes `numpy` is imported as `np` and `SparseVector` comes from
# PySpark's linalg module; neither import is shown in the question.
v = SparseVector(15557, [3, 40, 45, 103, 14356], np.ones(5))
v.values.sum()  # sums the 5 explicitly stored values (np.ones(5)) -> 5.0

5.0

This works well. Now I want to do the same thing by means of a udf, because I have a DataFrame with a column of SparseVectors. Here I get an error I don't understand:

from pyspark.sql import functions as f

def sum_vector(vector):
    """Return the sum of the vector's stored values.

    NOTE: ``vector.values.sum()`` yields a NumPy scalar (numpy.float64),
    not a built-in Python float.
    """
    return vector.values.sum()

# NOTE(review): no returnType is passed here; PySpark's documented default is
# StringType, not a numeric type -- confirm for your PySpark version.
sum_vector_udf = f.udf(lambda x: sum_vector(x))

# A udf must be applied to a DataFrame Column (or a column name), not to a
# local Python object: PySpark tries to convert the SparseVector `v` into a
# Java column reference, which raises the AttributeError shown below.
sum_vector_udf(v)

----

AttributeError                            Traceback (most recent call last)
<ipython-input-38-b4d44c2ef561> in <module>()
      1 v = SparseVector(15557, [3, 40, 45, 103, 14356], np.ones(5))
      2 
----> 3 sum_vector_udf(v)
      4 #v.values.sum()

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/functions.py in wrapper(*args)
   1955         @functools.wraps(f)
   1956         def wrapper(*args):
-> 1957             return udf_obj(*args)
   1958 
   1959         wrapper.func = udf_obj.func

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/functions.py in __call__(self, *cols)
   1916         judf = self._judf
   1917         sc = SparkContext._active_spark_context
-> 1918         return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))
   1919 
   1920 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in _to_seq(sc, cols, converter)
     58     """
     59     if converter:
---> 60         cols = [converter(c) for c in cols]
     61     return sc._jvm.PythonUtils.toSeq(cols)
     62 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in <listcomp>(.0)
     58     """
     59     if converter:
---> 60         cols = [converter(c) for c in cols]
     61     return sc._jvm.PythonUtils.toSeq(cols)
     62 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in _to_java_column(col)
     46         jcol = col._jc
     47     else:
---> 48         jcol = _create_column_from_name(col)
     49     return jcol
     50 

~/anaconda3/lib/python3.6/site-packages/pyspark/sql/column.py in _create_column_from_name(name)
     39 def _create_column_from_name(name):
     40     sc = SparkContext._active_spark_context
---> 41     return sc._jvm.functions.col(name)
     42 
     43 

~/anaconda3/lib/python3.6/site-packages/py4j/java_gateway.py in __call__(self, *args)
   1122 
   1123     def __call__(self, *args):
-> 1124         args_command, temp_args = self._build_args(*args)
   1125 
   1126         command = proto.CALL_COMMAND_NAME +\

~/anaconda3/lib/python3.6/site-packages/py4j/java_gateway.py in _build_args(self, *args)
   1092 
   1093         args_command = "".join(
-> 1094             [get_command_part(arg, self.pool) for arg in new_args])
   1095 
   1096         return args_command, temp_args

~/anaconda3/lib/python3.6/site-packages/py4j/java_gateway.py in <listcomp>(.0)
   1092 
   1093         args_command = "".join(
-> 1094             [get_command_part(arg, self.pool) for arg in new_args])
   1095 
   1096         return args_command, temp_args

~/anaconda3/lib/python3.6/site-packages/py4j/protocol.py in get_command_part(parameter, python_proxy_pool)
    287             command_part += ";" + interface
    288     else:
--> 289         command_part = REFERENCE_TYPE + parameter._get_object_id()
    290 
    291     command_part += "\n"

AttributeError: 'SparseVector' object has no attribute '_get_object_id'

I really don't understand, I'm writing exactly the same thing in two different ways. Any tips?

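The traceback itself comes from calling `sum_vector_udf(v)` directly on a local `SparseVector`. A udf can only be applied to a DataFrame column (or a column name), so PySpark tries to turn `v` into a Java column reference and fails. Once the udf is applied to a column, there is a second problem: udf doesn't support NumPy types as a return type.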

>>> type(v.values.sum())
<class 'numpy.float64'>

You should cast the result to a standard Python type (and apply the udf to a DataFrame column rather than to a local object):

# NOTE(review): assumes an active SparkSession named `spark`,
# `from pyspark.sql.functions import udf`, and the SparseVector `v` from the
# question -- none of these are defined in this snippet.
df = spark.createDataFrame([(v, )], ["v"])

@udf("double")
def sum_vector(vector):
    """Return the sum of the vector's stored values as a plain Python float.

    ``.tolist()`` converts the numpy.float64 scalar into a standard Python
    type, so the udf does not return a NumPy type.
    """
    return vector.values.sum().tolist()

or

@udf("double")
def sum_vector(vector):
    """Return the sum of the vector's stored values as a plain Python float.

    ``float()`` converts the numpy.float64 sum into a standard Python type,
    so the udf does not return a NumPy type.
    """
    return float(vector.values.sum())

In both cases you'll get the expected result:

# The udf is applied to the column name "v", not to a local SparseVector.
df.select(sum_vector("v")).show()
+-------------+
|sum_vector(v)|
+-------------+
|          5.0|
+-------------+

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM