How can I apply aggregate functions element-wise over arrays in PostgreSQL, e.g. weighted array sums over a group?

I have a table such as the following (see db<>fiddle):

grp n vals
0 2 {1,2,3,4}
1 5 {3,2,1,2}
1 3 {0,5,4,3}

where, for each group (defined by grp), I would like to perform some arithmetic involving the group's scalars n and arrays vals. I'm interested in a kind of weighted sum: each row's vals are multiplied by its n, and the resulting arrays are summed element-wise within each group, producing one array per group:

grp result
0 {2,4,6,8}
1 {15,25,17,19}
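
For reference, the sample table above can be recreated with something like the following sketch. The column types are an assumption, since the question shows only the values; real[] is used here because the function-based answer further down is declared on real[].

```sql
-- Assumed schema: the question shows only the data, not the column types.
CREATE TABLE tbl (
    grp  int,
    n    int,
    vals real[]
);

-- The three sample rows from the question.
INSERT INTO tbl (grp, n, vals) VALUES
    (0, 2, '{1,2,3,4}'),
    (1, 5, '{3,2,1,2}'),
    (1, 3, '{0,5,4,3}');
```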

Here's what I've tried. This fails with an error (aggregate function calls cannot contain set-returning function calls):

SELECT
    grp,
    ARRAY(SELECT SUM(n * UNNEST(vals)))
FROM
    tbl
GROUP BY
    grp

The error includes a hint, but I am unable to make sense of it for my use case.
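
For context, the hint attached to this error suggests moving the set-returning function into a LATERAL FROM item. A minimal sketch of what that could look like, assuming the table tbl from above (the alias u(v) and the column name total are made up for illustration):

```sql
-- unnest() moves out of the aggregate and into FROM, where it is joined
-- laterally to each row of tbl; SUM() then sees one plain value per row.
SELECT
    grp,
    SUM(n * v) AS total
FROM
    tbl
    CROSS JOIN LATERAL unnest(vals) AS u(v)
GROUP BY
    grp;
```

This avoids the error, but it yields a single scalar per group, because nothing records which array position each value came from.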

The following runs, but it collapses each group's arrays down to a single scalar, because the index i is not part of the grouping:

SELECT
    grp,
    SUM(n * vals[i])
FROM
    tbl,
    generate_series(1, 4) i
GROUP BY
    grp

Only this sort of works:

SELECT
    grp,
    SUM(n * vals[1]),
    SUM(n * vals[2]),
    SUM(n * vals[3]),
    SUM(n * vals[4])
FROM
    tbl
GROUP BY
    grp

but it doesn't result in an array, and it involves writing out each element of the array separately. In my case the arrays are much longer than four elements, so this is too awkward.

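WITH flattened AS (
    -- one row per (group, array position), holding the weighted sum at that position
    SELECT grp, position, SUM(val * n) AS s
    FROM tbl, unnest(vals) WITH ORDINALITY AS f(val, position)
    GROUP BY grp, position
)
-- reassemble the per-position sums into one array per group, in positional order
SELECT grp, array_agg(s ORDER BY position)
FROM flattened
GROUP BY grp;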

+---+-------------------------------------------------------------------------------------+
|grp|array_agg                                                                            |
+---+-------------------------------------------------------------------------------------+
|0  |{2.00000000000000000,4.00000000000000000,6.00000000000000000,8.00000000000000000}    |
|1  |{15.00000000000000000,25.00000000000000000,17.00000000000000000,19.00000000000000000}|
+---+-------------------------------------------------------------------------------------+

Explanation:

You can use unnest(...) WITH ORDINALITY to keep track of the position of each value within its array:

SELECT grp, position, val, n
FROM tbl, unnest(vals) WITH ORDINALITY AS f(val, position);

+---+--------+---+-+
|grp|position|val|n|
+---+--------+---+-+
|0  |1       |1  |2|
|0  |2       |2  |2|
|0  |3       |3  |2|
|0  |4       |4  |2|
|1  |1       |3  |5|
|1  |2       |2  |5|
|1  |3       |1  |5|
|1  |4       |2  |5|
|1  |1       |0  |3|
|1  |2       |5  |3|
|1  |3       |4  |3|
|1  |4       |3  |3|
+---+--------+---+-+
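
As a standalone illustration of WITH ORDINALITY, independent of tbl (the literal array and the aliases t, val, and idx are made up for this sketch):

```sql
-- WITH ORDINALITY appends a 1-based counter column to the function's output.
SELECT val, idx
FROM unnest(ARRAY[10, 20, 30]) WITH ORDINALITY AS t(val, idx);
-- val | idx
--  10 |   1
--  20 |   2
--  30 |   3
```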

Then GROUP BY the original group and each position:

SELECT grp, position, SUM(val * n) AS s
FROM tbl, unnest(vals) WITH ORDINALITY AS f(val, position)
GROUP BY grp, position
ORDER BY grp, position;

+---+--------+--+
|grp|position|s |
+---+--------+--+
|0  |1       |2 |
|0  |2       |4 |
|0  |3       |6 |
|0  |4       |8 |
|1  |1       |15|
|1  |2       |25|
|1  |3       |17|
|1  |4       |19|
+---+--------+--+

Finally, ARRAY_AGG(s ORDER BY position) reassembles the per-position sums into one array per group, as in the full query above.

I would write functions for this; otherwise the SQL gets messy quickly.

One function to multiply all elements by a given value:

create function array_mul(p_input real[], p_mul int)
  returns real[]
as
$$
  select array(select i * p_mul
               from unnest(p_input) with ordinality as t(i,idx)
               order by idx);
$$
language sql
immutable;
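
A quick check of the function in isolation, assuming array_mul has been created as above (the input array is arbitrary):

```sql
-- Every element is multiplied by the scalar; element order is preserved.
select array_mul(array[1, 2, 3, 4]::real[], 2);
-- {2,4,6,8}
```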

And one function to be used as an aggregate that sums up the elements with the same index:

create or replace function array_add(p_one real[], p_two real[])
  returns real[]
as
$$
declare
  l_idx int;
  l_result real[];
begin
  if p_one is null or p_two is null then
    return coalesce(p_one, p_two);
  end if;
  
  for l_idx in 1..greatest(cardinality(p_one), cardinality(p_two)) loop
    l_result[l_idx] := coalesce(p_one[l_idx],0) + coalesce(p_two[l_idx], 0);
  end loop;
  
  return l_result;  
end;  
$$
language plpgsql
immutable;
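
Two quick checks of the edge cases this function handles, assuming array_add has been created as above (the input arrays are arbitrary):

```sql
-- Arrays of unequal length: the missing third element of the second array counts as 0.
select array_add(array[1, 2, 3]::real[], array[10, 20]::real[]);
-- {11,22,3}

-- A NULL argument contributes nothing, so the other array is returned unchanged.
select array_add(null::real[], array[1, 2]::real[]);
-- {1,2}
```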

That can be used to define a custom aggregate:

create aggregate array_element_sum(real[]) (
  sfunc = array_add,
  stype = real[],
  initcond = '{}'
);

And then your query is as simple as:

select grp, array_element_sum(array_mul(vals, n))
from tbl
group by grp;
