
PySpark java.lang.OutOfMemoryError: Java heap space

I am solving a problem using Spark running on my local machine.

I am reading a Parquet file from the local disk and storing it in a DataFrame.

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

spark = SparkSession.builder\
    .config("spark.driver.memory","4g")\
    .config("spark.executor.memory","4g")\
    .config("spark.driver.maxResultSize","2g")\
    .getOrCreate()

content = spark.read.parquet('./files/file')

So the content DataFrame contains around 500k rows, like this:

+-----------+----------+
|EMPLOYEE_ID|MANAGER_ID|
+-----------+----------+
|        100|         0|
|        101|       100|
|        102|       100|
|        103|       100|
|        104|       100|
|        105|       100|
|        106|       101|
|        101|       101|
|        101|       101|
|        101|       101|
|        101|       102|
|        101|       102|
       .           .
       .           .
       .           .

I wrote this code to assign each EMPLOYEE_ID an EMPLOYEE_LEVEL according to its position in the hierarchy.

# Assign EMPLOYEE_LEVEL 1 WHEN MANAGER_ID is 0 ELSE NULL
content_df = content.withColumn("EMPLOYEE_LEVEL", when(col("MANAGER_ID") == 0, 1).otherwise(lit(None).cast("int")))

level_df = content_df.select("*").filter("EMPLOYEE_LEVEL = 1")

level = 1
while True:
    ldf = level_df
    temp_df = content_df.join(
        ldf,
        ((ldf["EMPLOYEE_LEVEL"] == level) & 
         (ldf["EMPLOYEE_ID"] == content_df["MANAGER_ID"])),
        "left") \
    .withColumn("EMPLOYEE_LEVEL",ldf["EMPLOYEE_LEVEL"]+1)\
    .select("EMPLOYEE_ID","MANAGER_ID","EMPLOYEE_LEVEL")\
    .filter("EMPLOYEE_LEVEL IS NOT NULL")\
    .distinct()
    
    if temp_df.count() == 0:
        break
    level_df = level_df.union(temp_df)
    level += 1
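The level assignment described above is a breadth-first walk down the hierarchy: roots (MANAGER_ID == 0) get level 1, their direct reports get level 2, and so on. As a plain-Python sketch of that logic only (no Spark), using the sample rows shown above, and assuming each employee should get a single level equal to its shortest distance from a root (employee_levels is a hypothetical helper name):

```python
from collections import defaultdict, deque


def employee_levels(pairs):
    """Assign each employee a level: 1 for roots (manager_id == 0),
    otherwise one more than its manager's level.

    pairs: iterable of (employee_id, manager_id). Duplicate rows are ignored,
    and an employee that already has a level is never revisited, so
    self-references such as (101, 101) cannot cause an endless loop.
    """
    reports = defaultdict(set)  # manager_id -> set of direct reports
    roots = set()
    for emp, mgr in pairs:
        if mgr == 0:
            roots.add(emp)
        else:
            reports[mgr].add(emp)

    levels = {emp: 1 for emp in roots}
    queue = deque(roots)
    while queue:  # breadth-first: finishes one level before the next
        mgr = queue.popleft()
        for emp in reports[mgr]:
            if emp not in levels:
                levels[emp] = levels[mgr] + 1
                queue.append(emp)
    return levels


sample = [(100, 0), (101, 100), (102, 100), (103, 100), (104, 100),
          (105, 100), (106, 101), (101, 101), (101, 102)]
print(employee_levels(sample))
```

Each Spark loop iteration in the question corresponds to one "ring" of this queue, which is why the number of iterations equals the depth of the hierarchy.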

My Spark job runs, but execution is very slow, and after a while it fails with this error:

Py4JJavaError: An error occurred while calling o383.count.
: java.lang.OutOfMemoryError: Java heap space
    at scala.collection.immutable.List.$colon$colon(List.scala:117)
    at scala.collection.immutable.List.$plus$colon(List.scala:220)
    at org.apache.spark.sql.catalyst.expressions.String2TrimExpression.children(stringExpressions.scala:816)
    at org.apache.spark.sql.catalyst.expressions.String2TrimExpression.children$(stringExpressions.scala:816)
    at org.apache.spark.sql.catalyst.expressions.StringTrim.children(stringExpressions.scala:948)
    at org.apache.spark.sql.catalyst.trees.TreeNode.withNewChildren(TreeNode.scala:351)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:595)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
    at scala.collection.TraversableLike$$Lambda$61/0x00000001001d2040.apply(Unknown Source)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at scala.collection.TraversableLike.map(TraversableLike.scala:286)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:595)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1148)
    at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1147)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:555)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:486)
    at org.apache.spark.sql.catalyst.trees.TreeNode$$Lambda$1822/0x0000000100d21040.apply(Unknown Source)
    at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1122)
    at org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1121)
    at org.apache.spark.sql.catalyst.expressions.UnaryExpression.mapChildren(Expression.scala:467)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:486)

I have tried several fixes, including increasing the driver and executor memory and calling cache() and persist() on the DataFrame, but none of them worked for me.

I am using Spark 3.2.1.

Any help will be appreciated. Thank you.

One reason it is slow is that your code loops through many iterations, so the DataFrame's lineage grows with every loop; using persist/cache can improve efficiency, but it is not the cause of the error.

I wrote very similar code in PySpark and ran into the same problem. Have you found a solution yet, bro?

Here is my code:

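from pyspark.sql import functions as F
from pyspark.sql.window import Window

# singer_pairs_undirected (src, dst) and old_song_group_kernel (id, song_group_id)
# are built earlier (not shown); the starting count below is an assumed initialization
previous_kernel_cnt = old_song_group_kernel.count()
epoch_cnt = 0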
while True:
    print('hahaha1')
    
    print('cached df', len(spark.sparkContext._jsc.getPersistentRDDs().items()))
    
    singer_pairs_undirected_ungrouped = singer_pairs_undirected.join(old_song_group_kernel,
                                                                     on=singer_pairs_undirected['src'] ==
                                                                        old_song_group_kernel['id'],
                                                                     how='left').filter(F.col('id').isNull()) \
        .select('src', 'dst')

    windowSpec = Window.partitionBy("src").orderBy(F.col("song_group_id_cnt").desc())

    singer_pairs_vote = singer_pairs_undirected_ungrouped.join(old_song_group_kernel,
                                                               on=singer_pairs_undirected_ungrouped['dst'] ==
                                                                  old_song_group_kernel['id'], how='inner') \
        .groupBy('src', 'song_group_id') \
        .agg(F.count('song_group_id').alias('song_group_id_cnt')) \
        .withColumn('song_group_id_cnt_rnk', F.row_number().over(windowSpec)) \
        .filter(F.col('song_group_id_cnt_rnk') == 1)

    singer_pairs_vote_output = singer_pairs_vote.select('src', 'song_group_id') \
        .withColumnRenamed('src', 'id')
    
    print('hahaha5')

    new_song_group_kernel = old_song_group_kernel.union(singer_pairs_vote_output) \
        .select('id', 'song_group_id').dropDuplicates().persist()
    
    print('hahaha9')
    
    current_kernel_cnt = new_song_group_kernel.count()
    
    print('hahaha2')
    old_song_group_kernel.unpersist(True)
    print('hahaha3')
    old_song_group_kernel = new_song_group_kernel

    epoch_cnt += 1
    print('epoch rounds: ', epoch_cnt)
    print('previous kernel count: ', previous_kernel_cnt)
    print('current kernel count: ', current_kernel_cnt)

    if current_kernel_cnt <= previous_kernel_cnt:
        print('Iteration done !')
        break
        
    print('hahaha4')

    previous_kernel_cnt = current_kernel_cnt

I get a similar error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-8-85a1837efd42> in <module>
     30 
     31 
---> 32     current_kernel_cnt = new_song_group_kernel.count()
     33 
     34     print('hahaha2')

/opt/tdw/spark-2.4.6/python/pyspark/sql/dataframe.py in count(self)
    522         2
    523         """
--> 524         return int(self._jdf.count())
    525 
    526     @ignore_unicode_prefix

/opt/tdw/spark-2.4.6/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1255         answer = self.gateway_client.send_command(command)
   1256         return_value = get_return_value(
-> 1257             answer, self.gateway_client, self.target_id, self.name)
   1258 
   1259         for temp_arg in temp_args:

/opt/tdw/spark-2.4.6/python/pyspark/sql/utils.py in deco(*a, **kw)
     61     def deco(*a, **kw):
     62         try:
---> 63             return f(*a, **kw)
     64         except py4j.protocol.Py4JJavaError as e:
     65             s = e.java_exception.toString()

/opt/tdw/spark-2.4.6/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o676.count.
: java.lang.OutOfMemoryError: Java heap space
    at java.util.Arrays.copyOf(Arrays.java:3332)
    at java.lang.AbstractStringBuilder.ensureCapacityInternal(AbstractStringBuilder.java:124)
    at java.lang.AbstractStringBuilder.append(AbstractStringBuilder.java:448)
    at java.lang.StringBuilder.append(StringBuilder.java:136)
    at java.lang.StringBuilder.append(StringBuilder.java:131)
    at scala.StringContext.standardInterpolator(StringContext.scala:125)
    at scala.StringContext.s(StringContext.scala:95)
    at org.apache.spark.sql.execution.QueryExecution.toString(QueryExecution.scala:251)
    at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:87)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:159)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:79)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3369)
    at org.apache.spark.sql.Dataset.count(Dataset.scala:2840)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)

I figured out the problem. The error is related to how Spark tracks a DataFrame's lineage: every transformation is recorded in the query plan (the DAG), so when an algorithm iterates, the lineage grows quickly until the driver runs out of memory while analyzing, optimizing, or even just printing the plan (note QueryExecution.toString in the trace above). So breaking the lineage is necessary when implementing iterative algorithms.

There are mainly two ways: 1. add a checkpoint; 2. recreate the DataFrame. I modified my code as follows, adding only a checkpoint to break the lineage, and it works for me.

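from pyspark.sql import functions as F
from pyspark.sql.window import Window

# checkpoint() requires a checkpoint directory (hypothetical path; use HDFS or
# other shared storage on a cluster)
spark.sparkContext.setCheckpointDir('/tmp/spark-checkpoints')

# singer_pairs_undirected and old_song_group_kernel are built earlier (not shown);
# the starting count below is an assumed initialization
previous_kernel_cnt = old_song_group_kernel.count()
epoch_cnt = 0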
while True:
    print('hahaha1')
    
    print('cached df', len(spark.sparkContext._jsc.getPersistentRDDs().items()))
    
    singer_pairs_undirected_ungrouped = singer_pairs_undirected.join(old_song_group_kernel,
                                                                     on=singer_pairs_undirected['src'] == old_song_group_kernel['id'],
                                                                     how='left').filter(F.col('id').isNull()) \
        .select('src', 'dst')

    windowSpec = Window.partitionBy("src").orderBy(F.col("song_group_id_cnt").desc())

    singer_pairs_vote = singer_pairs_undirected_ungrouped.join(old_song_group_kernel,
                                                               on=singer_pairs_undirected_ungrouped['dst'] ==
                                                                  old_song_group_kernel['id'], how='inner') \
        .groupBy('src', 'song_group_id') \
        .agg(F.count('song_group_id').alias('song_group_id_cnt')) \
        .withColumn('song_group_id_cnt_rnk', F.row_number().over(windowSpec)) \
        .filter(F.col('song_group_id_cnt_rnk') == 1)

    singer_pairs_vote_output = singer_pairs_vote.select('src', 'song_group_id') \
        .withColumnRenamed('src', 'id')
    
    print('hahaha5')

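    # checkpoint() saves the data and returns a new DataFrame whose lineage starts
    # from the saved data, so the plan no longer grows with every iteration
    new_song_group_kernel = old_song_group_kernel.union(singer_pairs_vote_output) \
        .select('id', 'song_group_id').dropDuplicates().persist().checkpoint()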
    
    print('hahaha9')
    
    current_kernel_cnt = new_song_group_kernel.count()
    
    print('hahaha2')
    old_song_group_kernel.unpersist()
    print('hahaha3')
    old_song_group_kernel = new_song_group_kernel

    epoch_cnt += 1
    print('epoch rounds: ', epoch_cnt)
    print('previous kernel count: ', previous_kernel_cnt)
    print('current kernel count: ', current_kernel_cnt)

    if current_kernel_cnt <= previous_kernel_cnt:
        print('Iteration done !')
        break
        
    print('hahaha4')

    previous_kernel_cnt = current_kernel_cnt

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 