
Databricks/Python: what is a best-practice approach to creating a robust long-running job?

I couldn't find a good overview of how to create a job that has a moderate probability of failure.

I am an experienced developer, but relatively new to Databricks/Spark. While I could program my way out of this problem, I'm looking for a best-practice solution.

My scenario is reading a large number of rows from a web API. The job takes about 36 hours to run. During those 36 hours, there is a high probability that I will hit a fatal error while interacting with the API (timeouts, disconnects while reading, invalid or unexpected return values, etc.). While I can make the job increasingly robust to these errors, ideally I would not have to run the entire job again to recover; I would only rerun the failed cases.

My basic flow is like this:

  • Read in a curated set of IDs (hundreds of thousands)
  • For each ID, call the web API to get details
  • Write the resulting output into a new table (ID + Details)

Approaches I have evaluated:

  1. Catch all errors in a blanket fashion and write the failures into the result table. Recovery then means patching whatever caused the failures and reading the failed rows back in as the source of IDs.
  2. Partition the initial dataset into multiple files and build something that schedules work on individual partitions, then rerun a single partition if any of its items fails. Once all partitions succeed, aggregate the results. I think this is doable, but with my limited understanding of Databricks it looks messy: I'd be doing my own partitioning and task scheduling. I'm hoping there is a better way.
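Approach 1 can be sketched in plain Python as follows. This is a minimal illustration, assuming the API call is a function that either returns a details string or raises; `fetch_all`, `failed_ids` and `ResultRow` are hypothetical names, and `get_details` stands in for the real web API call.

```python
from typing import Callable, Iterable, List, Optional, TypedDict


class ResultRow(TypedDict):
    # One output row: the ID plus either its details or the error it raised
    id: str
    details: Optional[str]
    error: Optional[str]


def fetch_all(ids: Iterable[str], get_details: Callable[[str], str]) -> List[ResultRow]:
    """Call get_details for every ID, turning any exception into an error row."""
    rows: List[ResultRow] = []
    for id_ in ids:
        try:
            rows.append({"id": id_, "details": get_details(id_), "error": None})
        except Exception as exc:  # deliberately broad: approach 1 catches everything
            rows.append({"id": id_, "details": None, "error": f"{type(exc).__name__}: {exc}"})
    return rows


def failed_ids(rows: Iterable[ResultRow]) -> List[str]:
    """IDs whose rows recorded an error, to use as the input of a recovery run."""
    return [row["id"] for row in rows if row["error"] is not None]
```

In Spark, the same try/except would live inside the UDF, so one bad response produces an error row instead of failing the whole task.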

The solution I imagine in my head is something like:

# Hypothetical API (what I imagine; this does not exist in Spark):
# Split the source table into 100 equal buckets
# Run only buckets 10, 20, 21 (presumably, those are the failed buckets)
# For each bucket, run the UDF get_details
# If the bucket succeeds, put its rows into aggregate_df; otherwise, into error_df
aggregate_df, error_df = (
    df.split_table_evenly(bucket_count=100)
    .options(continue_on_task_failure=True)
    .filter(bucket=[10, 20, 21])
    .run_task_on_bucket(udf=get_details)
)
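The behaviour imagined above (equal buckets, run only selected buckets, keep going when a bucket fails) can be pinned down in plain Python. This is a sketch of the semantics only; `split_evenly` and `run_buckets` are hypothetical helpers, not Spark APIs, and `get_details` stands in for the web API call.

```python
from typing import Callable, Dict, Iterable, List, Sequence, Tuple


def split_evenly(rows: Sequence[str], bucket_count: int) -> List[List[str]]:
    """Split rows into bucket_count contiguous buckets whose sizes differ by at most 1."""
    size, extra = divmod(len(rows), bucket_count)
    buckets: List[List[str]] = []
    start = 0
    for b in range(bucket_count):
        end = start + size + (1 if b < extra else 0)
        buckets.append(list(rows[start:end]))
        start = end
    return buckets


def run_buckets(
    buckets: List[List[str]],
    only: Iterable[int],
    get_details: Callable[[str], str],
) -> Tuple[Dict[int, List[Tuple[str, str]]], Dict[int, str]]:
    """Run get_details over the selected buckets; each bucket succeeds or fails as a unit."""
    aggregate: Dict[int, List[Tuple[str, str]]] = {}
    errors: Dict[int, str] = {}
    for b in only:
        try:
            aggregate[b] = [(id_, get_details(id_)) for id_ in buckets[b]]
        except Exception as exc:  # "continue_on_task_failure": record and move on
            errors[b] = repr(exc)
    return aggregate, errors
```

A recovery run would pass the keys of `errors` as `only`. Contiguous buckets assume the source order is stable between runs; if it is not, assigning buckets by hashing the ID is safer.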

The solution I settled on is Spark Structured Streaming, which supports checkpointing of streaming queries. The Structured Streaming Programming Guide describes its use in many scenarios in detail, but it doesn't explicitly cover my use case (breaking up a dataset by rows), so I describe it here.

The basic approach is to break the stream into what Spark calls micro-batches. For the scenario above, the important thing to understand is that Spark batches by time, whereas I want to batch by rows. One Spark source does sort of support batching by data volume: the Kafka source, which can cap each batch by offsets (maxOffsetsPerTrigger). Since I didn't want to route my data through Kafka, I chose not to use this approach.

The file source does have one tool we can use: it can limit the number of files in each batch with the maxFilesPerTrigger option (the Delta source used below supports the same option). I also set latestFirst, which tells the file source to process the newest files first; it is not required.

source_dataframe = (
    self.sparksession.readStream
    .format('delta')
    .option('maxFilesPerTrigger', 1)  # at most one new file per micro-batch
    .option('latestFirst', True)      # optional
    .load(path)
)

Since this limit applies to whole files, I needed to limit the size of each file when generating the input. To do this, I partitioned my dataset by a key that produces a comfortable bucket size.
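One way to derive such a key is to pick a bucket count from a target number of rows per bucket and then map each ID to a bucket with a stable hash. The sketch below assumes string IDs; `bucket_count_for` and `bucket_of` are hypothetical helpers. Inside Spark, a comparable column could be computed with the SQL expression `pmod(xxhash64(id), n)` (the column name `id` is an assumption) before writing with partitionBy.

```python
import math
import zlib


def bucket_count_for(total_rows: int, rows_per_bucket: int) -> int:
    """Smallest bucket count that keeps buckets at or below rows_per_bucket on average."""
    return max(1, math.ceil(total_rows / rows_per_bucket))


def bucket_of(id_: str, bucket_count: int) -> int:
    """Stable bucket number for an ID.

    zlib.crc32 returns the same value in every process, unlike Python's
    built-in hash() on strings, which is randomized per interpreter run.
    """
    return zlib.crc32(id_.encode("utf-8")) % bucket_count
```

With an illustrative 300,000 IDs and 2,000 rows per bucket, `bucket_count_for(300_000, 2_000)` returns 150. The stable hash matters because a rerun must put each ID in the same bucket as before.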

Because many small files are inefficient for most uses, I chose to write this dataset twice, so I don't pay the cost of reading many small files unless I am checkpointing. This is entirely optional.

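dataframe = ...  # some query that produces the curated set of IDs
# Copy 1: the curated output, written normally for everyday use
dataframe.write.format(my_format).save(path=my_curated_output_path)
# Copy 2: the job input; partitionBy is a writer method, and repartitioning on the
# same column first writes each 'somecolumn' value as a single file
dataframe.repartition('somecolumn') \
    .write.format(my_format).partitionBy('somecolumn').save(path=my_job_data_path)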

Then, you can read and write the dataset using streaming functions and you are almost good to go.

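source_dataframe = self.sparksession.readStream.option('maxFilesPerTrigger', 1).format(my_format).load(my_job_data_path)
result_dataframe = source_dataframe  # apply the per-ID web API call (e.g. the get_details UDF) here
query = result_dataframe.writeStream.format(my_format).start(output_path)
query.awaitTermination()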

However, there are a few more things to take care of before this runs robustly.

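You need to set a checkpoint location if you want to be able to recover your job. It is an option of the stream writer, not the reader:

query = result_dataframe.writeStream \
    .format(my_format) \
    .option('checkpointLocation', '/_checkpoints/some_unique_directory') \
    .start(output_path)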
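What the checkpoint buys you: Spark records which source offsets (for file and Delta sources, which files) each micro-batch read and which batches committed, so a restarted query resumes after the last committed batch. The plain-Python sketch below shows the same idea for buckets. It is an analogy, not how Spark stores its checkpoint; `run_with_checkpoint`, `process_bucket` and the JSON file layout are hypothetical.

```python
import json
import os
from typing import Callable, Dict, List


def run_with_checkpoint(
    buckets: Dict[int, List[str]],
    process_bucket: Callable[[List[str]], None],
    checkpoint_file: str,
) -> List[int]:
    """Process buckets in order, persisting each finished bucket number.

    On restart, buckets already listed in checkpoint_file are skipped.
    Returns the bucket numbers processed by this call.
    """
    done = set()
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file) as f:
            done = set(json.load(f))
    processed = []
    for b in sorted(buckets):
        if b in done:
            continue
        process_bucket(buckets[b])  # may raise; earlier buckets stay committed
        done.add(b)
        tmp = checkpoint_file + ".tmp"
        with open(tmp, "w") as f:
            json.dump(sorted(done), f)
        os.replace(tmp, checkpoint_file)  # atomic rename: never a half-written checkpoint
        processed.append(b)
    return processed
```

If the API fails in bucket 1, rerunning the same call after the fix processes only buckets 1 onward.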

Spark will be greedy and read all your input files if you don't change some settings.
See this answer for additional details.

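You should estimate your rate and set the starting target rate to it. I set mine very low to start with; Spark adjusts from this rate to whatever it thinks is sustainable. (Caveat: the spark.streaming.backpressure.* settings are documented for the older DStream-based Spark Streaming API. Structured Streaming sizes micro-batches from source options such as maxFilesPerTrigger, so these settings may have no effect on this job.)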

sparksession.conf.set('spark.streaming.backpressure.enabled', True)
sparksession.conf.set('spark.streaming.backpressure.initialRate', target_rate_per_second)
sparksession.conf.set('spark.streaming.backpressure.rateEstimator', 'pid')
sparksession.conf.set('spark.streaming.backpressure.pid.minRate', 1)

With the default trigger, a streaming query runs forever unless you monitor it: in that mode Spark treats the stream as unbounded and does not stop when the input is exhausted. You have to code the stop yourself, and unfortunately the code is fragile because in some cases you have to inspect the query's status message.
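Newer Spark versions offer a built-in alternative: `trigger(availableNow=True)` (Spark 3.3+) processes all data available when the query starts, in possibly multiple micro-batches that respect source options such as maxFilesPerTrigger, and then stops on its own. (`trigger(once=True)` also stops, but processes everything in a single batch.) A minimal sketch, assuming Spark 3.3 or later (or a Databricks Runtime that supports this trigger); `run_until_exhausted` is a hypothetical helper name.

```python
def run_until_exhausted(result_dataframe, output_path, checkpoint_path, fmt="delta"):
    """Process every input file available at start-up, then stop.

    Assumes Spark 3.3+, where DataStreamWriter.trigger accepts availableNow.
    """
    query = (
        result_dataframe.writeStream
        .format(fmt)
        .option("checkpointLocation", checkpoint_path)
        .trigger(availableNow=True)  # drain the backlog in micro-batches, then terminate
        .start(output_path)
    )
    query.awaitTermination()  # returns once the backlog has been processed
    return query
```

Because the checkpoint is still used, a failed run restarted with the same checkpoint path picks up from the last committed batch.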

See this answer for more details. The code in that answer didn't fully work for me; I had to handle additional status messages. This is the code I am currently using (simplified, with logging removed):

import time

poll_interval_seconds = 30  # assumed value; tune to your trigger interval
while query.isActive:
    status = query.status  # one snapshot, so the three fields are consistent
    if not status['isDataAvailable'] and not status['isTriggerActive']:
        if 'Initializing' not in status['message']:
            query.stop()
    time.sleep(poll_interval_seconds)
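A variation, not from the original answer: factor the idle check into a function and require it to hold on several consecutive polls, which guards against catching the query in the short gap between two micro-batches. `looks_idle`, `stop_when_drained` and `idle_polls_required` are hypothetical names; `query` is assumed to be a StreamingQuery, and whether two polls are enough depends on your trigger interval.

```python
import time
from typing import Any, Mapping


def looks_idle(status: Mapping[str, Any]) -> bool:
    """No data waiting, no trigger running, and past initialization."""
    return (
        not status["isDataAvailable"]
        and not status["isTriggerActive"]
        and "Initializing" not in status["message"]
    )


def stop_when_drained(query, poll_interval_seconds: float = 30.0, idle_polls_required: int = 2) -> None:
    """Stop the query only after it looked idle on several consecutive polls."""
    idle_polls = 0
    while query.isActive:
        # Reset the counter whenever the query shows any activity
        idle_polls = idle_polls + 1 if looks_idle(query.status) else 0
        if idle_polls >= idle_polls_required:
            query.stop()
            break
        time.sleep(poll_interval_seconds)
```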

Other items of note: these aren't required to checkpoint micro-batches, but are otherwise useful.

Spark produces useful logs; these are the two entries I found most helpful when building trust in streaming execution:

22/05/18 23:22:32 INFO MicroBatchExecution: Streaming query made progress:
22/05/18 23:24:44 WARN ProcessingTimeExecutor: Current batch is falling behind. The trigger interval is 500 milliseconds, but spent 5593 milliseconds
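The "made progress" entries are also available programmatically: a StreamingQuery exposes `lastProgress` (the latest progress report as a dict) and `recentProgress` (a bounded list of recent reports), with keys such as `batchId`, `numInputRows` and `processedRowsPerSecond`. A small sketch that summarizes them; `summarize_progress` is a hypothetical helper, and it only reads keys it guards for.

```python
from typing import Any, Dict, Iterable, Mapping


def summarize_progress(progress: Iterable[Mapping[str, Any]]) -> Dict[str, Any]:
    """Summarize progress dicts such as those in query.recentProgress."""
    batches = list(progress)
    total_rows = sum(p.get("numInputRows", 0) for p in batches)
    # Skip missing or zero rates (e.g. empty batches) when averaging
    rates = [p["processedRowsPerSecond"] for p in batches if p.get("processedRowsPerSecond")]
    return {
        "batches": len(batches),
        "last_batch_id": batches[-1]["batchId"] if batches else None,
        "total_input_rows": total_rows,
        "avg_processed_rows_per_sec": sum(rates) / len(rates) if rates else 0.0,
    }
```

Calling `summarize_progress(query.recentProgress)` from the polling loop gives a rough throughput figure without scraping the driver log.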

Logging for the PID estimator can be enabled via

sparksession.conf.set('log4j.logger.org.apache.spark.streaming.scheduler.rate.PIDRateEstimator', 'TRACE')

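Spark recently added a RocksDB state store for streaming state management (Spark 3.2), which you need to enable explicitly. It has worked smoothly for me. Note that the state store only matters for stateful operations such as aggregations, deduplication, and stream-stream joins.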

sparksession.conf.set(
    'spark.sql.streaming.stateStore.providerClass',
    'org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider')
