繁体   English   中英

无法从pyspark RDD的map方法访问类方法

[英]Not able to access class methods from pyspark RDD's map method

在将pyspark集成到应用程序的代码库中时,我无法在RDD的map方法中引用类的方法。 我用一个简单的例子重复了这个问题,如下所示

这是我定义的一个虚拟类,该类仅向从RDD派生的RDD的每个元素添加一个数字,该RDD是一个类属性:

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def add(self, a, b):
        return a + b

    def test_func(self, b):
        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: self.add(l[1], b))
        v_c = v.collect()
        return v_c

test_func()在RDD v上调用map()方法,然后依次在v每个元素上调用add()方法。 调用test_func()会引发以下错误:

pickle.PicklingError: Could not serialize object: Exception: It appears that you are attempting to reference SparkContext from a broadcast variable, action, or transformation. SparkContext can only be used on the driver, not in code that it run on workers. For more information, see SPARK-5063.

现在,当我将add()方法移出类时,例如:

def add(self, a, b):
    return a + b

class Test:

    def __init__(self):
        self.sc = SparkContext()
        a = [('a', 1), ('b', 2), ('c', 3)]
        self.a_r = self.sc.parallelize(a)

    def test_func(self, b):

        c_r = self.a_r.map(lambda l: (l[0], l[1] * 2))
        v = c_r.map(lambda l: add(l[1], b))
        v_c = v.collect()

        return v_c

调用test_func()现在可以正常工作。

[7, 9, 11]

为什么会发生这种情况,如何将类方法传递给RDD的map()方法?

发生这种情况的原因是,当pyspark尝试序列化您的函数(以将其发送给worker)时,它还需要序列化Test类的实例(因为要传递给map的函数在self具有对该实例的引用)。 该实例引用了spark上下文。 您需要确保SparkContextRDD未被任何序列化并发送给worker的对象引用。 SparkContext仅需要存在于驱动程序中。

这应该工作:

在文件testspark.py

class Test(object):
    def add(self, a, b):
        return a + b

    def test_func(self, a_r, b):
        c_r = a_r.map(lambda l: (l[0], l[1] * 2))
        # now `self` has no reference to the SparkContext()
        v = c_r.map(lambda l: self.add(l[1], b)) 
        v_c = v.collect()
        return v_c

在您的主脚本中:

from pyspark import SparkContext
from testspark import Test

sc = SparkContext()
a = [('a', 1), ('b', 2), ('c', 3)]
a_r = sc.parallelize(a)

test = Test()
test.test_func(a_r, 5) # should give [7, 9, 11]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM