
Efficient rolling sum (window aggregate) in SAS

I have two tables:

  • tb_payments : contract_id, payment_date, payment_value
  • tb_reference : contract_id, reference_date

For each (contract_id, reference_date) in tb_reference, I want to create a column sum_payments as the 90-day rolling sum of payment_value from tb_payments. I can accomplish this (very inefficiently) with the query below:

%let window=90;
proc sql;
    create index contract_id on tb_payments;
quit;
proc sql;
    create table tb_rolling as
    select a.contract_id,
           a.reference_date,
           (select sum(b.payment_value)
            from tb_payments as b
            where a.contract_id = b.contract_id
                  and a.reference_date - &window. < b.payment_date
                  and b.payment_date <= a.reference_date
           ) as sum_payments
    from tb_reference as a;
quit;

How can I rewrite this to reduce the time complexity, using proc sql or SAS data step?
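For concreteness, here is a minimal Python sketch of the computation being asked for (the tiny data set is made up for illustration): for each (contract_id, reference_date) pair, sum the payments satisfying reference_date - window < payment_date <= reference_date, exactly like the correlated subquery above.

```python
from datetime import date, timedelta

# Hypothetical miniature data: (contract_id, payment_date, payment_value)
payments = [
    (1, date(2016, 1, 10), 100.0),
    (1, date(2016, 2, 1), 50.0),
    (1, date(2016, 5, 1), 25.0),
    (2, date(2016, 1, 15), 10.0),
]
references = [(1, date(2016, 3, 1)), (2, date(2016, 1, 20))]
WINDOW = 90  # days

def rolling_sum(references, payments, window):
    """Brute force: for each reference row, scan every payment row,
    mirroring the correlated subquery (O(n*m))."""
    out = []
    for cid, ref in references:
        total = sum(v for pcid, pdate, v in payments
                    if pcid == cid
                    and ref - timedelta(days=window) < pdate <= ref)
        out.append((cid, ref, total))
    return out
```

This brute force is the baseline the answers below improve on; its cost is the product of the two table sizes.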

Edit with more info:

  • I chose 90 days as the window arbitrarily, but I will perform the calculation for several windows. A solution that can handle several windows at the same time would be ideal
  • Both tables can have 10+ million rows, and the data is completely arbitrary. My SAS server is quite powerful, though
  • Contract_ids can be repeated in both tables
  • The pairs (contract_id, reference_date) and (contract_id, payment_date) are unique

Edit with sample data:

%let seed=1111;
data tb_reference (drop=i);
    call streaminit(&seed.);
    do i = 1 to 10000;
        contract_id = round(rand('UNIFORM')*1000000,1);
        output;
    end;
run;
proc surveyselect data=tb_reference out=tb_payments n=5000 seed=&seed.; run;
data tb_reference(drop=i);
    format reference_date date9.;
    call streaminit(&seed.);
    set tb_reference;
    do i = 1 to 1+round(rand('UNIFORM')*4,1);
        reference_date = '01jan2016'd + round(rand('UNIFORM')*1000,1);
        output;
    end;
run;
proc sort data=tb_reference nodupkey; by contract_id reference_date; run;
data tb_payments(drop=i);
    format payment_date date9. payment_value comma20.2;
    call streaminit(&seed.);
    set tb_payments;
    do i = 1 to 1+round(rand('UNIFORM')*20,1);
        payment_date = '01jan2015'd + round(rand('UNIFORM')*1365,1);
        payment_value = round(rand('UNIFORM')*3333,0.01);
        output;
    end;
run;
proc sort data=tb_payments nodupkey; by contract_id payment_date; run;

Update: I compared my naive solution to both proposals from Quentin and Tom.

  • The merge method is quite fast and achieved over 10x speedup for n=10000. It is also very powerful, as beautifully demonstrated by Tom in his answer.
  • Hash tables are insanely fast and achieved over 500x speedup. Because my datasets are large, this is the way to go, but there's a catch: the payment data must fit in RAM.

If anyone needs the full testing code, feel free to send me a message.

It is probably possible to do this all with PROC EXPAND if you have it licensed. But let's look at how to do it without that.

It shouldn't be that hard if all of the dates are present in the PAYMENTS table. Just merge the two tables by ID and DATE. Calculate the running sum, but with the wrinkle of also subtracting out the value that is rolling out the back of the window. Then just keep the dates that are in the reference file.

One issue is the need to find the full range of dates for each CONTRACT_ID so that the LAG() function can be used. That is easy to do with PROC MEANS.

proc summary data=tb_payments nway ;
  by contract_id ;
  var payment_date;
  output out=tb_id_dates(drop=_:) min=date1 max=date2 ;
run;

And a data step to expand each contract to one observation per day. This step could also be defined as a view.

data tb_id_dates_all ;
  set tb_id_dates ;
  do date=date1 to date2 ;
    output;
  end;
  format date date9.;
  keep contract_id date ;
run;

Now just merge the three datasets and calculate the cumulative sums. Note that I included a do loop to accumulate multiple payments on a single day (remove the nodupkey in your sample data generation code to test it).

If you want to generate multiple windows then you will need multiple actual LAG() function calls.

data want ;
  do until (last.contract_id);
    do until (last.date);
      merge tb_id_dates_all tb_payments(rename=(payment_date=date))
            tb_reference(rename=(reference_date=date) in=in2)
      ;
      by contract_id date ;
      payment=sum(0,payment,payment_value);
    end;
    day_num=sum(day_num,1);

    array lag_days(5) _temporary_ (7 30 60 90 180) ;
    array lag_payment(5) _temporary_ ;
    array cumm(5) cumm_7 cumm_30 cumm_60 cumm_90 cumm_180 ;
    lag_payment(1) = lag7(payment);
    lag_payment(2) = lag30(payment);
    lag_payment(3) = lag60(payment);
    lag_payment(4) = lag90(payment);
    lag_payment(5) = lag180(payment);

    do i=1 to dim(cumm) ;
       cumm(i)=sum(cumm(i),payment);
       if day_num > lag_days(i) then cumm(i)=sum(cumm(i),-lag_payment(i));
       if .z < abs(cumm(i)) < 1e-5 then cumm(i)=0;
    end;
    if in2 then output ;
  end;
  keep contract_id date cumm_: ;
  format cumm_: comma20.2 ;
  rename date=reference_date ;
run;
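The subtract-the-lag idea in the step above can be checked outside SAS. This Python sketch (hypothetical per-day totals for a single contract) keeps one running sum per window, adding today's value and subtracting the value that falls out the back once the window is full:

```python
def windowed_sums(daily, windows):
    """daily: per-day payment totals for one contract, one entry per
    calendar day (zero for days with no payment).
    Returns, per window length w, the trailing w-day sum for each day
    (inclusive of the current day), matching the SAS LAGxx() logic."""
    out = {w: [] for w in windows}
    run = {w: 0.0 for w in windows}
    for i, v in enumerate(daily):
        for w in windows:
            run[w] += v                 # add today's payment
            if i >= w:                  # window full: subtract the value
                run[w] -= daily[i - w]  # rolling out the back
            out[w].append(run[w])
    return out
```

Note the window is inclusive of the current day, matching the question's `payment_date <= reference_date` bound.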

If you want to make the code flexible for the number of windows you will need to add some code generation to create the LAGxx() function calls. For example you could use this macro:

%macro lags(windows);
%local i n lag ;
%let n=%sysfunc(countw(&windows));

array lag_days(&n) _temporary_ (&windows) ;
array lag_payment(&n) _temporary_ ;
array cumm(&n)
%do i=1 %to &n ;
  %let lag=%scan(&windows,&i);
 cumm_&lag
%end;
;
%do i=1 %to &n ;
  %let lag=%scan(&windows,&i);
lag_payment(&i) = lag&lag(payment);
%end;
%mend lags;

Then replace the ARRAY statements and the LAGxx() assignment statements with this macro call:

%lags(7 30 60 90 180)

Here's an example of a hash approach. Since your data are already sorted, I don't think there is much benefit to the hash approach over Tom's merge approach.

General idea is to read all of the payment data into a hash table (you may run out of memory if your real data is too big), then read through the data set of reference dates. For each reference date, you look up all of the payments for that contract_id, and iterate through them, testing to see if payment date is <90 days before the reference_date, and conditionally incrementing sum_payments.

It should be noticeably faster than the SQL approach in your question, but it could lose to the MERGE approach. If the data were not sorted in advance, it might beat the time needed to sort both big datasets and then merge. It also handles multiple payments on the same date.

data want;
  *initialize variables for hash table ;
  call missing(payment_date,payment_value) ;

  *Load a hash table with all of the payment data ;
  if _n_=1 then do ;
    declare hash h(dataset:"tb_payments", multidata: "yes");
    h.defineKey("contract_ID");
    h.defineData("payment_date","payment_value");
    h.defineDone() ;
  end ;

  *read in the reference dates ;
  set tb_reference (keep=contract_id reference_date) ;

  *for each reference date, look up all the payments for that contract_id ;
  *and iterate through them.  If the payment date is < 90 days before reference date then ;
  *increment sum_payments ;

  sum_payments=0 ;
  rc=h.find();  
  do while (rc = 0); *found a record;
    if 0<=(reference_date-payment_date)<90 then sum_payments = sum_payments + payment_value ;
    rc=h.find_next();
  end;
run ;
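The same lookup logic can be sketched in Python (made-up miniature data): a dict plays the role of the hash object, and the inner loop mirrors the find()/find_next() iteration over the multidata entries for one contract_id:

```python
from collections import defaultdict
from datetime import date, timedelta

# Hypothetical miniature data: (contract_id, payment_date, payment_value)
payments = [
    (1, date(2016, 1, 10), 100.0),
    (1, date(2016, 2, 1), 50.0),
    (2, date(2016, 1, 15), 10.0),
]
references = [(1, date(2016, 3, 1)), (2, date(2016, 1, 20))]

# Load the "hash table": contract_id -> all of its payments
by_contract = defaultdict(list)
for cid, pdate, value in payments:
    by_contract[cid].append((pdate, value))

# One pass over the reference rows; each probe is O(payments per contract)
results = []
for cid, ref in references:
    total = 0.0
    for pdate, value in by_contract.get(cid, []):  # find()/find_next() loop
        if timedelta(0) <= ref - pdate < timedelta(days=90):
            total += value
    results.append((cid, ref, total))
```

As in the SAS step, the whole payment table lives in memory, which is where the RAM constraint mentioned in the question's update comes from.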
