简体   繁体   中英

conditional grouping by value and number of rows in R

Original data:

> dt = data.table(v1 = c(3,1,1,5,6,12,13,11,10,0,2,1,3))
> dt
    v1
 1:  3
 2:  1
 3:  1
 4:  5
 5:  6
 6: 12
 7: 13
 8: 11
 9: 10
10:  0
11:  2
12:  1
13:  3

I would like to put v1 into 3 groups based on value as follows:

> dt %>%  mutate(group = case_when(v1 <5 ~ 1,
+                               v1 >=5 & v1 <10 ~ 2,
+                               v1 >= 10 ~3))
   v1 group
1   3  1
2   1  1
3   1  1
4   5  2
5   6  2
6  12  3
7  13  3
8  11  3
9  10  3
10  0  1
11  2  1
12  1  1
13  3  1

But I would also like to add a rule where if the total number of rows in a group is under 3, it takes the mean of those rows, and compares it to the rows (of v1) immediately before and after that group, and whichever value is closest to the mean absorbs that group.

In the example above, group 2 only has 2 rows, so I take their mean (5.5) and compare to the value above (1) and below (12). Since the smaller value is closer to the mean, those rows become group 1, making the desired output look as follows:

   v1 group
1   3  1
2   1  1
3   1  1
4   5  1
5   6  1
6  12  3
7  13  3
8  11  3
9  10  3
10  0  1
11  2  1
12  1  1
13  3  1

I've made a few attempts to no avail and would really appreciate a dplyr or data.table solution.

One option using dplyr is to create a helper column that keeps track of the row number; then, for each group with fewer than 3 rows, compare the group's mean of v1 against the v1 values of the rows immediately above and below the group, and reassign the group based on whichever is closer. Here change is the final output.

library(dplyr)

# Step 1: assign the initial value-based groups and record each row's
# position (`row`) so a small group can later look up the rows
# immediately before and after itself.
dt1 <- dt %>%  
         mutate(group = case_when(v1 < 5 ~ 1,
                                  v1 >=5 & v1 <10 ~ 2,
                                  v1 >= 10 ~3), 
                row = row_number())

# Step 2: for any group with fewer than 3 rows, compare the group's
# mean(v1) against the v1 value of the row just above the group
# (index first(row) - 1L) and just below it (index last(row) + 1L),
# and adopt the `group` label of whichever neighbour is closer.
# Groups of 3+ rows keep their own label. Note that dt1$... indexes
# the full, ungrouped table on purpose -- within a grouped mutate,
# bare column names would only see the current group's rows.
# NOTE(review): the first(row) - 1L / last(row) + 1L lookups assume a
# small group never sits at the very top or bottom of the table;
# there they would index position 0 or n + 1 -- confirm for real data.
dt1 %>%
   group_by(group) %>%
   mutate(change = if (n() < 3) {
     c(dt1$group[first(row) - 1L], dt1$group[last(row) + 1L])[
        which.min(c(abs(mean(v1) - dt1$v1[first(row) - 1L]),
                    abs(mean(v1) - dt1$v1[last(row) + 1L])))]
      }   else group) 


#     v1 group   row change
#   <dbl> <dbl> <int>  <dbl>
# 1     3     1     1      1
# 2     1     1     2      1
# 3     1     1     3      1
# 4     5     2     4      1
# 5     6     2     5      1
# 6    12     3     6      3
# 7    13     3     7      3
# 8    11     3     8      3
# 9    10     3     9      3
#10     0     1    10      1
#11     2     1    11      1
#12     1     1    12      1
#13     3     1    13      1

First, compute the original grouping and aggregate:

# Summarise each contiguous run of rows that share a cut class:
#   ct -- value-based class (1: v1 < 5, 2: 5 <= v1 < 10, 3: v1 >= 10)
#   g  -- run id: rleid(ct) numbers contiguous stretches of equal ct
#   N  -- number of rows in the run, m -- mean of v1 within the run
# The `ct = ct <- cut(...)` inside by= both names the grouping column
# and binds `ct` locally so rleid(ct) can reuse it on the next line.
gDT <- dt[
  ,
  .(.N, m = mean(v1)),
  by = .(
    ct = ct <- cut(v1, c(-Inf, 5, 10, Inf), right = FALSE, labels = FALSE),
    g = rleid(ct)
  )
]

   ct g N         m
1:  1 1 3  1.666667
2:  2 2 2  5.500000
3:  3 3 4 11.500000
4:  1 4 4  1.500000

Flag groups to change and compare m with the nearest unchanging groups above and below:

# A run must be absorbed ("flagged") when it has fewer than 3 rows.
gDT[, flag := N < 3]

# Default: every run keeps its own class...
gDT[, res := ct]
# ...then flagged runs take the class of whichever nearest unflagged
# neighbour has a mean closest to their own mean.
gDT[flag == TRUE, res := {
  ffDT = gDT[flag == FALSE]

  # nearest eligible rows going up and down -- possibly NA if at top or bottom
  # roll=TRUE rolls backwards from g - 1 to the last unflagged run before
  # the flagged one; roll=-Inf rolls forwards from g + 1 to the first
  # unflagged run after it. which=TRUE returns row positions into ffDT.
  w_dn = ffDT[.(g = .SD$g - 1L), on=.(g), roll=TRUE, which=TRUE]
  w_up = ffDT[.(g = .SD$g + 1L), on=.(g), roll=-Inf, which=TRUE]

  # diffs of the mean against eligible rows up and down
  diffs = lapply(list(dn = w_dn, up = w_up), function(w) abs(ffDT$m[w] - m))

  # take whichever neighbour is strictly nearer; because the comparison is
  # strict (<), a tie keeps the w_dn (preceding) run's class
  replace(ffDT$ct[w_dn], diffs$up < diffs$dn, ffDT$ct[w_up])
}]

   ct g N         m  flag res
1:  1 1 3  1.666667 FALSE   1
2:  2 2 2  5.500000  TRUE   1
3:  3 3 4 11.500000 FALSE   3
4:  1 4 4  1.500000 FALSE   1

Creating a separate table like this makes it easy to check your work (look at flagged groups, check N and ct , compare m with nearest unflagged neighbors, etc).

To add back to the original table, one way is:

dt[, res := gDT$res[ rleid(cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE)) ] ]

    v1 ct res
 1:  3  1   1
 2:  1  1   1
 3:  1  1   1
 4:  5  2   1
 5:  6  2   1
 6: 12  3   3
 7: 13  3   3
 8: 11  3   3
 9: 10  3   3
10:  0  1   1
11:  2  1   1
12:  1  1   1
13:  3  1   1

Details: The steps above are a lot more complicated than those in @RonakShah's answer since I assume that "group" here applies to contiguous rows:

But I would also like to add a rule where if the total number of rows in a group is under 3, it takes the mean of those rows, and compares it to the rows (of v1) immediately before and after that group, and whichever value is closest to the mean absorbs that group.

Otherwise, the criterion is not well defined -- if there is a group of size 2 but the two rows are not contiguous, there is no "immediately before and after that group" to compare against.

Building on Frank's cut and rleid(ct) :

#from Frank's answer
# from Frank's answer: ct is the value-based class, g the contiguous-run id
dt[,
    c("ct", "g") := {
        ct <- cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE)
        .(ct, rleid(ct))
    }
]

# per-run size and mean of v1 (added as columns so every row carries them)
dt[, c("N", "m") := .(.N, m=mean(v1)), by=.(ct, g)]

# For each run, store the v1 value and class of the row immediately before
# its first row and immediately after its last row. shift(.(v1, g), c(1L, -1L))
# yields positional columns: V3/V5 = lagged v1/ct-class, V4/V6 = lead v1/... 
# NOTE(review): V5/V6 are the shifted *g* values here, later used as new_ct --
# this works for this data because runs alternate classes, but confirm the
# intended column (shifted ct vs shifted g) before reusing on other data.
ct_dt <- dt[, c(.(ct=ct, g=g), shift(.(v1, g), c(1L, -1L)))][,
    .(near_v1=c(V3[1L], V4[.N]), new_ct=c(V5[1L], V6[.N])), .(ct, g)]

# Update join restricted to runs of fewer than 3 rows: roll="nearest" on
# near_v1 = m picks the neighbour whose v1 is closest to the run's mean,
# and its class overwrites ct in place.
dt[N<3L, ct := ct_dt[.SD, on=.(ct, g, near_v1=m), roll="nearest", new_ct]]

# drop the helper columns, leaving only v1 and the final ct
dt[, c("g","N","m") := NULL]

output:

    v1 ct
 1:  3  1
 2:  1  1
 3:  1  1
 4:  5  1
 5:  6  1
 6: 12  3
 7: 13  3
 8: 11  3
 9: 10  3
10:  0  1
11:  2  1
12:  1  1
13:  3  1

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM