
Pyspark Java.lang.OutOfMemoryError: Java heap space

I am trying to solve a problem with Spark running on my local machine.

I am reading a Parquet file from the local disk and storing it in a DataFrame.

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder\
    .config("spark.driver.memory","4g")\
    .config("spark.executor.memory","4g")\
    .config("spark.driver.maxResultSize","2g")\
    .getOrCreate()

content = spark.read.parquet('./files/file')

The content DataFrame has about 500k rows, i.e.

+-----------+----------+
|EMPLOYEE_ID|MANAGER_ID|
+-----------+----------+
|        100|         0|
|        101|       100|
|        102|       100|
|        103|       100|
|        104|       100|
|        105|       100|
|        106|       101|
|        101|       101|
|        101|       101|
|        101|       101|
|        101|       102|
|        101|       102|
|        ...|       ...|
+-----------+----------+

I wrote this code to assign an EMPLOYEE_LEVEL to each EMPLOYEE_ID based on its place in the hierarchy:

# Assign EMPLOYEE_LEVEL 1 when MANAGER_ID is 0, else NULL
content_df = content.withColumn("EMPLOYEE_LEVEL", when(col("MANAGER_ID") == 0, 1).otherwise(lit(None)))

level_df = content_df.select("*").filter("EMPLOYEE_LEVEL = 1")

level = 1
while True:
    ldf = level_df
    temp_df = content_df.join(
        ldf,
        ((ldf["EMPLOYEE_LEVEL"] == level) & 
         (ldf["EMPLOYEE_ID"] == content_df["MANAGER_ID"])),
        "left") \
    .withColumn("EMPLOYEE_LEVEL",ldf["EMPLOYEE_LEVEL"]+1)\
    .select("EMPLOYEE_ID","MANAGER_ID","EMPLOYEE_LEVEL")\
    .filter("EMPLOYEE_LEVEL IS NOT NULL")\
    .distinct()
    
    if temp_df.count() == 0:
        break
    level_df = level_df.union(temp_df)
    level += 1

It works, but execution is very slow and it fails with this error after a while.

Py4JJavaError: An error occurred while calling o383.count.
: java.lang.OutOfMemoryError: Java heap space
    at scala.collection.immutable.List.$colon$colon(List.scala:117)
    at scala.collection.immutable.List.$plus$colon(List.scala:220)
    at org.apache.spark.sql.catalyst.expressions.String2TrimExpression.children(stringExpressions.scala:816)
    at org.apache.spark.sql.catalyst.expressions.String2TrimExpression.children$(stringExpressions.scala:816)
    at org.apache.spark.sql.catalyst.expressions.StringTrim.children(stringExpressions.scala:948)
    at org.apache.spark.sql.catalyst.trees.TreeNode.withNewChildren(TreeNode.scala:351)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:595)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.TraversableLike$$Lambda$61/0x00000001001d2040.apply(Unknown Source)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:595)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1148)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1147)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:555)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
    at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1122)
    at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1121)
    at org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:467)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)

I have tried many solutions, including increasing the driver and executor memory; using cache() and persist() on the DataFrames did not work for me either.

I am using Spark 3.2.1.

Any help would be appreciated. Thank you.

One reason for the slowness is that your code loops through many iterations, so the number of rows the DataFrame has to compute grows with every pass; using persist/cache can make this more efficient. But that is not the cause of the error.
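As an illustration only, here is a rough sketch of that idea, reusing content_df / level_df from the question (so it assumes that setup): the per-iteration result is persisted so that count() and the following union() reuse one computation instead of recomputing the whole chain.

level = 1
while True:
    ldf = level_df
    temp_df = content_df.join(
        ldf,
        ((ldf["EMPLOYEE_LEVEL"] == level) &
         (ldf["EMPLOYEE_ID"] == content_df["MANAGER_ID"])),
        "left") \
        .withColumn("EMPLOYEE_LEVEL", ldf["EMPLOYEE_LEVEL"] + 1) \
        .select("EMPLOYEE_ID", "MANAGER_ID", "EMPLOYEE_LEVEL") \
        .filter("EMPLOYEE_LEVEL IS NOT NULL") \
        .distinct() \
        .persist()                       # cache this iteration's result

    if temp_df.count() == 0:             # count() materializes the cached result
        temp_df.unpersist()
        break

    level_df = level_df.union(temp_df)   # the union reuses the cached rows
    level += 1

This only avoids recomputing temp_df twice per iteration; as noted above, the growing plan is still what eventually causes the OutOfMemoryError.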

I wrote very similar code in pyspark and hit the same problem. Bro, have you found a solution yet?

Putting my code below:

epoch_cnt = 0
while True:
    print('hahaha1')
    
    print('cached df', len(spark.sparkContext._jsc.getPersistentRDDs().items()))
    
    singer_pairs_undirected_ungrouped = singer_pairs_undirected.join(old_song_group_kernel,
                                                                     on=singer_pairs_undirected['src'] ==
                                                                        old_song_group_kernel['id'],
                                                                     how='left').filter(F.col('id').isNull()) \
        .select('src', 'dst')

    windowSpec = Window.partitionBy("src").orderBy(F.col("song_group_id_cnt").desc())

    singer_pairs_vote = singer_pairs_undirected_ungrouped.join(old_song_group_kernel,
                                                               on=singer_pairs_undirected_ungrouped['dst'] ==
                                                                  old_song_group_kernel['id'], how='inner') \
        .groupBy('src', 'song_group_id') \
        .agg(F.count('song_group_id').alias('song_group_id_cnt')) \
        .withColumn('song_group_id_cnt_rnk', F.row_number().over(windowSpec)) \
        .filter(F.col('song_group_id_cnt_rnk') == 1)

    singer_pairs_vote_output = singer_pairs_vote.select('src', 'song_group_id') \
        .withColumnRenamed('src', 'id')
    
    print('hahaha5')

    new_song_group_kernel = old_song_group_kernel.union(singer_pairs_vote_output) \
        .select('id', 'song_group_id').dropDuplicates().persist()
    
    print('hahaha9')
    
    current_kernel_cnt = new_song_group_kernel.count()
    
    print('hahaha2')
    old_song_group_kernel.unpersist(True)
    print('hahaha3')
    old_song_group_kernel = new_song_group_kernel

    epoch_cnt += 1
    print('epoch rounds: ', epoch_cnt)
    print('previous kernel count: ', previous_kernel_cnt)
    print('current kernel count: ', current_kernel_cnt)

    if current_kernel_cnt <= previous_kernel_cnt:
        print('Iteration done !')
        break
        
    print('hahaha4')

    previous_kernel_cnt = current_kernel_cnt

And I got a similar error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-8-85a1837efd42> in <module>
     30 
     31 
---> 32     current_kernel_cnt = new_song_group_kernel.count()
     33 
     34     print('hahaha2')

/opt/tdw/spark-2.4.6/python/pyspark/sql/dataframe.py in count(self)
    522         2
    523         """
--> 524         return int(self._jdf.count())
    525 
    526     @ignore_unicode_prefix

/opt/tdw/spark-2.4.6/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/opt/tdw/spark-2.4.6/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/opt/tdw/spark-2.4.6/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o676.count.
: java.lang.OutOfMemoryError: Java heap space
    at java.util.Arrays.copyOf(Arrays.java:3332)
    at java.lang.AbstractStringBuilder.ensureCapacityInternal(AbstractStringBuilder.java:124)
    at java.lang.AbstractStringBuilder.append(AbstractStringBuilder.java:448)
    at java.lang.StringBuilder.append(StringBuilder.java:136)
    at java.lang.StringBuilder.append(StringBuilder.java:131)
    at scala.StringContext.standardInterpolator(StringContext.scala:125)
    at scala.StringContext.s(StringContext.scala:95)
    at org.apache.spark.sql.execution.QueryExecution.toString(QueryExecution.scala:251)
    at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:87)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:159)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:79)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3369)
    at org.apache.spark.sql.Dataset.count(Dataset.scala:2840)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

I figured out the problem. The error is related to the mechanism of Spark's DAG: Spark uses the DAG lineage to track a series of transformations, and when an algorithm needs to iterate, the lineage grows quickly and hits the memory limit. So it is necessary to break the lineage when implementing iterative algorithms.

There are mainly two ways to do that: 1. add a checkpoint; 2. recreate the DataFrame. I modified my code below; it simply adds a checkpoint to break the lineage, and that worked for me.

epoch_cnt = 0
while True:
    print('hahaha1')
    
    print('cached df', len(spark.sparkContext._jsc.getPersistentRDDs().items()))
    
    singer_pairs_undirected_ungrouped = singer_pairs_undirected.join(old_song_group_kernel,
                                                                     on=singer_pairs_undirected['src'] == old_song_group_kernel['id'],
                                                                     how='left').filter(F.col('id').isNull()) \
        .select('src', 'dst')

    windowSpec = Window.partitionBy("src").orderBy(F.col("song_group_id_cnt").desc())

    singer_pairs_vote = singer_pairs_undirected_ungrouped.join(old_song_group_kernel,
                                                               on=singer_pairs_undirected_ungrouped['dst'] ==
                                                                  old_song_group_kernel['id'], how='inner') \
        .groupBy('src', 'song_group_id') \
        .agg(F.count('song_group_id').alias('song_group_id_cnt')) \
        .withColumn('song_group_id_cnt_rnk', F.row_number().over(windowSpec)) \
        .filter(F.col('song_group_id_cnt_rnk') == 1)

    singer_pairs_vote_output = singer_pairs_vote.select('src', 'song_group_id') \
        .withColumnRenamed('src', 'id')
    
    print('hahaha5')

    new_song_group_kernel = old_song_group_kernel.union(singer_pairs_vote_output) \
        .select('id', 'song_group_id').dropDuplicates().persist().checkpoint()
    
    print('hahaha9')
    
    current_kernel_cnt = new_song_group_kernel.count()
    
    print('hahaha2')
    old_song_group_kernel.unpersist()
    print('hahaha3')
    old_song_group_kernel = new_song_group_kernel

    epoch_cnt += 1
    print('epoch rounds: ', epoch_cnt)
    print('previous kernel count: ', previous_kernel_cnt)
    print('current kernel count: ', current_kernel_cnt)

    if current_kernel_cnt <= previous_kernel_cnt:
        print('Iteration done !')
        break
        
    print('hahaha4')

    previous_kernel_cnt = current_kernel_cnt
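For reference, the same checkpoint idea could be applied to the EMPLOYEE_LEVEL loop from the question. This is only a sketch under the question's assumptions, and the checkpoint directory path is just an example; note that DataFrame.checkpoint() requires a checkpoint directory to be set via sparkContext.setCheckpointDir first.

spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoints")   # example path

level = 1
while True:
    ldf = level_df
    temp_df = content_df.join(
        ldf,
        ((ldf["EMPLOYEE_LEVEL"] == level) &
         (ldf["EMPLOYEE_ID"] == content_df["MANAGER_ID"])),
        "left") \
        .withColumn("EMPLOYEE_LEVEL", ldf["EMPLOYEE_LEVEL"] + 1) \
        .select("EMPLOYEE_ID", "MANAGER_ID", "EMPLOYEE_LEVEL") \
        .filter("EMPLOYEE_LEVEL IS NOT NULL") \
        .distinct()

    if temp_df.count() == 0:
        break

    # checkpoint() writes the union to the checkpoint directory and returns a
    # DataFrame with a truncated lineage, so the plan no longer grows every loop
    level_df = level_df.union(temp_df).checkpoint()
    level += 1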
