[英]conditional grouping by value and number of rows in R
Original data: 原始数据:
> dt = data.table(v1 = c(3,1,1,5,6,12,13,11,10,0,2,1,3))
> dt
v1
1: 3
2: 1
3: 1
4: 5
5: 6
6: 12
7: 13
8: 11
9: 10
10: 0
11: 2
12: 1
13: 3
I would like to put v1
into 3 groups based on value as follows: 我想根据值将v1
分为3组,如下所示:
> dt %>% mutate(group = case_when(v1 <5 ~ 1,
+ v1 >=5 & v1 <10 ~ 2,
+ v1 >= 10 ~3))
v1 group
1 3 1
2 1 1
3 1 1
4 5 2
5 6 2
6 12 3
7 13 3
8 11 3
9 10 3
10 0 1
11 2 1
12 1 1
13 3 1
But I would also like to add a rule where if the total number of rows in a group is under 3, it takes the mean of those rows, and compares it to the rows (of v1) immediately before and after that group, and whichever value is closest to the mean absorbs that group. 但是我还想添加一条规则,如果组中的行总数小于3,则取这些行的平均值,并将其与该组前后的行(v1)进行比较,无论哪个值最接近均值吸收该组。
In the example above, group 2 only has 2 rows, so I take their mean (5.5) and compare to the value above (1) and below (12). 在上面的示例中,组2仅包含2行,因此我取其平均值(5.5)并与上方(1)和下方(12)的值进行比较。 Since the smaller value is closer to the mean, those rows become group 1, making the desired output look as follows: 由于较小的值更接近平均值,因此这些行将成为组1,从而使所需的输出如下所示:
v1 group
1 3 1
2 1 1
3 1 1
4 5 1
5 6 1
6 12 3
7 13 3
8 11 3
9 10 3
10 0 1
11 2 1
12 1 1
13 3 1
I've made a few attempts to no avail and would really appreciate a dplyr
or data.table
solution. 我做了几次尝试都没有用,并且非常感谢dplyr
或data.table
解决方案。
One option using dplyr
could be to create a new column which would keep an account of row_number
and compare the v1
value of one row above and below of those groups which have less than 3 rows and assign the new groups based on it. 使用dplyr
一个选项可能是创建一个新列,该列将保留一个row_number
帐户,并比较少于3行的那些组的上下一行的v1
值,并基于该行分配新的组。 Here change
is the final output. 此处的change
是最终输出。
library(dplyr)
dt1 <- dt %>%
mutate(group = case_when(v1 < 5 ~ 1,
v1 >=5 & v1 <10 ~ 2,
v1 >= 10 ~3),
row = row_number())
dt1 %>%
group_by(group) %>%
mutate(change = if (n() < 3) {
c(dt1$group[first(row) - 1L], dt1$group[last(row) + 1L])[
which.min(c(abs(mean(v1) - dt1$v1[first(row) - 1L]),
abs(mean(v1) - dt1$v1[last(row) + 1L])))]
} else group)
# v1 group row change
# <dbl> <dbl> <int> <dbl>
# 1 3 1 1 1
# 2 1 1 2 1
# 3 1 1 3 1
# 4 5 2 4 1
# 5 6 2 5 1
# 6 12 3 6 3
# 7 13 3 7 3
# 8 11 3 8 3
# 9 10 3 9 3
#10 0 1 10 1
#11 2 1 11 1
#12 1 1 12 1
#13 3 1 13 1
First, compute the original grouping and aggregate: 首先,计算原始分组和聚合:
gDT = dt[, .(.N, m = mean(v1)), by=.(
ct = ct <- cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE),
g = rleid(ct)
)]
ct g N m
1: 1 1 3 1.666667
2: 2 2 2 5.500000
3: 3 3 4 11.500000
4: 1 4 4 1.500000
Flag groups to change and compare m
with the nearest unchanging groups above and below: 标记要更改的组并将m
与上方和下方的最接近的不变组进行比较:
gDT[, flag := N < 3]
gDT[, res := ct]
gDT[flag == TRUE, res := {
ffDT = gDT[flag == FALSE]
# nearest eligible rows going up and down -- possibly NA if at top or bottom
w_dn = ffDT[.(g = .SD$g - 1L), on=.(g), roll=TRUE, which=TRUE]
w_up = ffDT[.(g = .SD$g + 1L), on=.(g), roll=-Inf, which=TRUE]
# diffs of the mean against eligible rows up and down
diffs = lapply(list(dn = w_dn, up = w_up), function(w) abs(ffDT$m[w] - m))
# if/else for whichever is nearer, ties broken in favor of up
replace(ffDT$ct[w_dn], diffs$up < diffs$dn, ffDT$ct[w_up])
}]
ct g N m flag res
1: 1 1 3 1.666667 FALSE 1
2: 2 2 2 5.500000 TRUE 1
3: 3 3 4 11.500000 FALSE 3
4: 1 4 4 1.500000 FALSE 1
Creating a separate table like this makes it easy to check your work (look at flagged groups, check N
and ct
, compare m
with nearest unflagged neighbors, etc). 像这样创建一个单独的表可以很容易地检查您的工作(查看标记的组,检查N
和ct
,将m
与最近的未标记邻居进行比较,等等)。
To add back to the original table, one way is: 要添加回原始表,一种方法是:
dt[, res := gDT$res[ rleid(cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE)) ] ]
v1 ct res
1: 3 1 1
2: 1 1 1
3: 1 1 1
4: 5 2 1
5: 6 2 1
6: 12 3 3
7: 13 3 3
8: 11 3 3
9: 10 3 3
10: 0 1 1
11: 2 1 1
12: 1 1 1
13: 3 1 1
Details: The steps above are a lot more complicated than those in @RonakShah's answer since I assume that "group" here applies to contiguous rows: 详细信息:上面的步骤比@RonakShah的答案要复杂得多,因为我认为这里的“组”适用于连续的行:
But I would also like to add a rule where if the total number of rows in a group is under 3, it takes the mean of those rows, and compares it to the rows (of v1) immediately before and after that group, and whichever value is closest to the mean absorbs that group. 但是我还想添加一条规则,如果组中的行总数小于3,则取这些行的平均值,并将其与该组前后的行(v1)进行比较,无论哪个值最接近均值吸收该组。
Otherwise, the criterion is not well defined -- if there is a group of size 2 but the two rows are not contiguous, there is no "immediately before and after that group" to compare against. 否则,标准定义不明确-如果存在一组大小为2的组,但两行不是连续的,则没有“紧接在该组之前和之后”的比较对象。
Building on Frank's cut
and rleid(ct)
: 以Frank的cut
和rleid(ct)
:
#from Frank's answer
dt[,
c("ct", "g") := {
ct <- cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE)
.(ct, rleid(ct))
}
]
#calculate mean
dt[, c("N", "m") := .(.N, m=mean(v1)), by=.(ct, g)]
#store last/first value from prev/next for rolling join later
ct_dt <- dt[, c(.(ct=ct, g=g), shift(.(v1, g), c(1L, -1L)))][,
.(near_v1=c(V3[1L], V4[.N]), new_ct=c(V5[1L], V6[.N])), .(ct, g)]
#update join for those with less than 3 rows
dt[N<3L, ct := ct_dt[.SD, on=.(ct, g, near_v1=m), roll="nearest", new_ct]]
#delete unwanted columns
dt[, c("g","N","m") := NULL]
output: 输出:
v1 ct
1: 3 1
2: 1 1
3: 1 1
4: 5 1
5: 6 1
6: 12 3
7: 13 3
8: 11 3
9: 10 3
10: 0 1
11: 2 1
12: 1 1
13: 3 1
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.