
Spark performance issue (likely caused by “basic” mistakes)

I'm relatively new to Apache Spark (version 1.6), and I feel I've hit a wall: I've looked through most of the Spark-related questions on SE, but found nothing that has helped so far. I believe I'm doing something fundamentally wrong at a basic level, but I can't pinpoint what it is, especially since other code I've written runs just fine.

I'll try to be as specific as possible in explaining my situation, although I'll simplify my task for clarity. Keep in mind that, as I'm still learning, I'm running this code in Spark's local mode; also worth noting is that I've been using DataFrames (not RDDs). Lastly, the following code is written in Python using PySpark, but I welcome solutions in Scala or Java, as I believe the issue is a very basic one.

I have a generic JSON file whose structure resembles the following:

{"events":[ 
    {"Person":"Alex","Shop":"Burger King","Timestamp":"100"},
    {"Person":"Alex","Shop":"McDonalds","Timestamp":"101"},
    {"Person":"Alex","Shop":"McDonalds","Timestamp":"104"},
    {"Person":"Nathan","Shop":"KFC","Timestamp":"100"},
    {"Person":"Nathan","Shop":"KFC","Timestamp":"120"},
    {"Person":"Nathan","Shop":"Burger King","Timestamp":"170"}]}

What I need to do is measure how much time has passed between two visits by the same person to the same shop. The output should be the list of shops that have had at least one customer visit them at least once every 5 seconds, along with the number of customers who meet this requirement. For the example above, the output should look something like this:

{"Shop":"McDonalds","PeopleCount":1}

My idea was to assign the same identifier to every row belonging to a given (Person, Shop) pair, and then check whether that pair met the requirement. The identifier can be assigned with the window function ROW_NUMBER(), which in Spark 1.6 requires a HiveContext. This is how the file above should look after the identifier has been assigned:

{"events":[ 
    {"Person":"Alex","Shop":"Burger King","Timestamp":"100","ID":1},
    {"Person":"Alex","Shop":"McDonalds","Timestamp":"101", "ID":2},
    {"Person":"Alex","Shop":"McDonalds","Timestamp":"104", "ID":2},
    {"Person":"Nathan","Shop":"KFC","Timestamp":"100","ID":3},
    {"Person":"Nathan","Shop":"KFC","Timestamp":"120","ID":3},
    {"Person":"Nathan","Shop":"Burger King","Timestamp":"170","ID":4}]}

As I need to perform several steps for each pair (some of which require self-joins) before reaching a conclusion, I used temporary tables.

The code I've written looks something like this (I've included only the relevant parts; "df" stands for "DataFrame"):

t1_df = hiveContext.read.json(inputFileName)
t1_df.registerTempTable("events")
t2_df = hiveContext.sql("SELECT Person, Shop, ROW_NUMBER() OVER (order by Person asc, Shop asc) as ID FROM events group by Person, Shop HAVING count(*)>1") #if there are less than 2 entries for the same pair, then we can discard this pair
t2_df.write.mode("overwrite").saveAsTable("orderedIDs")
n_pairs = t2_df.count() #used to determine how many pairs I need to inspect
i=1
while i<=n_pairs:
    #now I perform several operations, each one displaying this structure
    #first operation...
    query="SELECT ... FROM orderedIDs WHERE ID=%d" %i
    t3_df = hiveContext.sql(query)
    t3_df.write.mode("overwrite").saveAsTable("table1")
    #...second operation...
    query2="SELECT ... FROM table1 WHERE ..."
    t4_df = hiveContext.sql(query2)
    t4_df.write.mode("overwrite").saveAsTable("table2")
    #...and so on. Let us skip to the last operation in this loop, which consists of the "saving" of the shop if it met the requirements:
    t8_df = hiveContext.sql("SELECT Shop from table7")
    t8_df.write.mode("append").saveAsTable("goodShops")
    i=i+1

#then we only need to write the table to a proper file
output_df = hiveContext.sql("SELECT Shop, count(*) as PeopleCount from goodShops group by Shop")
output_df.write.json('output')

Now here is the problem. The output is correct: I've tried several inputs, and the program works fine in that regard. However, it is tremendously slow: it takes around 15-20 seconds to analyze each pair, regardless of how many entries the pair has. For example, 10 pairs take around 3 minutes and 100 pairs take around 30 minutes. I ran this code on several machines with relatively decent hardware, but nothing changed. I also tried caching some (or even all) of the tables I used, but the problem persisted (in some cases it even took longer).

More specifically, the last operation of the loop (the one that uses "append") takes 5 to 10 seconds to complete, whereas each of the first six operations takes only 1-2 seconds (still a lot for such a small task, but definitely more manageable).
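The reported timings scale linearly with the number of pairs, which is what you would expect if each loop iteration has a roughly fixed cost (presumably the Spark jobs launched by each query and each saveAsTable call) rather than a cost that depends on the data. This small sketch just multiplies out the reported 15-20 seconds per pair; `estimated_runtime_seconds` is a hypothetical helper, not a measurement.

```python
def estimated_runtime_seconds(n_pairs, per_pair_low=15, per_pair_high=20):
    """Total time if every loop iteration costs a fixed 15-20 s,
    independent of how many events the pair has."""
    return n_pairs * per_pair_low, n_pairs * per_pair_high


for n in (10, 100):
    low, high = estimated_runtime_seconds(n)
    print(f"{n} pairs: {low / 60:.1f}-{high / 60:.1f} minutes")
# 10 pairs: 2.5-3.3 minutes
# 100 pairs: 25.0-33.3 minutes
```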

I believe the issue may lie in one (or more) of the following:

  1. the use of a loop, which might hurt parallelism;
  2. the use of the "saveAsTable" method, which forces a write to disk on every call;
  3. bad or poor use of caching.

These three are the only things that come to mind, because the other Spark programs I've written (which had no performance issues) don't use these techniques: they mostly perform simple JOIN operations and use the registerTempTable method for temporary tables (which, as far as I understand, cannot be used in a loop) instead of the saveAsTable method.

I tried to be as clear as possible, but if you need more details I'm happy to provide additional information.

EDIT: I managed to solve my issue thanks to zero323's answer. Indeed, the LAG function was all I really needed. I've also learnt that the "saveAsTable" method should be avoided, especially in loops, as every call causes a major drop in performance. I'll avoid it from now on unless it is absolutely necessary.

The task is to find "how much time has passed between two visits by the same person to the same shop", and the output should be "the list of shops which have had at least one customer visit them at least once every 5 seconds, alongside the number of customers that meet this requirement."

How about a simple lag with aggregation:

from pyspark.sql.window import Window
from pyspark.sql.functions import col, lag, sum as sql_sum  # avoid shadowing built-in sum

# Assumes a SparkContext `sc` and an active HiveContext (as in the pyspark shell):
# in Spark 1.6, window functions such as lag require a HiveContext.
df = (sc
    .parallelize([
        ("Alex", "Burger King", "100"), ("Alex", "McDonalds", "101"),
        ("Alex", "McDonalds", "104"), ("Nathan", "KFC", "100"),
        ("Nathan", "KFC", "120"), ("Nathan", "Burger King", "170")
    ]).toDF(["Person", "Shop", "Timestamp"])
    .withColumn("Timestamp", col("Timestamp").cast("long")))

# One window per (Person, Shop) pair, rows ordered by visit time
w = (Window
    .partitionBy("Person", "Shop")
    .orderBy("Timestamp"))

ind = ((
    # Difference between the current and previous timestamp is <= 5;
    # lag returns null for the first row of each partition
    col("Timestamp") - lag("Timestamp", 1).over(w)) <= 5
 ).cast("long")  # Cast so we can sum (nulls are ignored by sum)

(df
    .withColumn("ind", ind)
    .groupBy("Shop")
    .agg(sql_sum("ind").alias("events"))
    .where(col("events") > 0)
    .show())

## +---------+------+
## |     Shop|events|
## +---------+------+
## |McDonalds|     1|
## +---------+------+
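To see what the window expression computes, here is a plain-Python emulation of it (a sketch for illustration, not Spark code; the helper names are hypothetical). `lag` yields no previous value for the first row of each (Person, Shop) partition, so that row's indicator is null and `sum` skips it. Note also that summing `ind` counts qualifying consecutive visits, not people: on the sample data both give 1 for McDonalds, but if Alex had also visited McDonalds at 106 the `events` column would be 2 while only one customer qualifies. Since the question asks for the number of customers, counting distinct `Person` values among rows with `ind == 1` (in Spark, for example with `countDistinct`) may be closer to the requirement.

```python
from collections import defaultdict

# Sample rows: (Person, Shop, Timestamp as an integer)
ROWS = [
    ("Alex", "Burger King", 100), ("Alex", "McDonalds", 101),
    ("Alex", "McDonalds", 104), ("Nathan", "KFC", 100),
    ("Nathan", "KFC", 120), ("Nathan", "Burger King", 170),
]


def with_lag_indicator(rows, max_gap=5):
    """Emulate lag over a window partitioned by (Person, Shop) and ordered
    by Timestamp. The first row of a partition gets None (like SQL NULL)."""
    partitions = defaultdict(list)
    for person, shop, ts in rows:
        partitions[(person, shop)].append(ts)
    out = []
    for (person, shop), times in partitions.items():
        prev = None
        for ts in sorted(times):
            ind = None if prev is None else int(ts - prev <= max_gap)
            out.append((person, shop, ts, ind))
            prev = ts
    return out


def events_per_shop(rows):
    """sum(ind) per Shop, skipping None as SQL SUM skips NULL,
    keeping only shops whose total is positive."""
    totals = defaultdict(int)
    for _, shop, _, ind in rows:
        if ind is not None:
            totals[shop] += ind
    return {shop: n for shop, n in totals.items() if n > 0}


def people_per_shop(rows):
    """Number of distinct people per Shop with at least one qualifying gap."""
    people = defaultdict(set)
    for person, shop, _, ind in rows:
        if ind:
            people[shop].add(person)
    return {shop: len(p) for shop, p in people.items()}


tagged = with_lag_indicator(ROWS)
print(events_per_shop(tagged))  # {'McDonalds': 1}
print(people_per_shop(tagged))  # {'McDonalds': 1}
```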

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 