
[英]query 1st, 2nd, 3rd place with tie break using 2nd and 3rd lowest with each time on one row
[英]A query to perform sampling, choosing every 3rd row based on condition using postgres
我有一个包含图像 ID 和分数的表。
正样本是得分接近 1 的图像 id,负样本是得分接近 0 的图像 id。
采样方法如下:按分数降序排列图像 ID,每 3 个选取 1 个,选出的图像 ID 构成正样本;按分数升序排列图像 ID,同样每 3 个选取 1 个,选出的图像 ID 构成负样本。
附上表结构以供参考:
-- Per-image prediction scores that the sampling queries read (50 sample rows).
-- Values are written as numeric literals rather than quoted strings, so no
-- implicit text-to-number conversion is involved.
-- NOTE(review): image_id has no unique key and the INSERT is unconditional,
-- so running this again after the table exists inserts every row a second time.
CREATE TABLE IF NOT EXISTS unlabeled_image_predictions (
    image_id int,
    score    float
);

INSERT INTO unlabeled_image_predictions (image_id, score) VALUES
    (828, 0.3149), (705, 0.9892), (46, 0.5616), (594, 0.7670), (232, 0.1598), (524, 0.9876), (306, 0.6487),
    (132, 0.8823), (906, 0.8394), (272, 0.9778), (616, 0.1003), (161, 0.7113), (715, 0.8921), (109, 0.1151),
    (424, 0.7790), (609, 0.5241), (63, 0.2552), (276, 0.2672), (701, 0.0758), (554, 0.4418), (998, 0.0379),
    (809, 0.1058), (219, 0.7143), (402, 0.7655), (363, 0.2661), (624, 0.8270), (640, 0.8790), (913, 0.2421),
    (439, 0.3387), (464, 0.3674), (405, 0.6929), (986, 0.8931), (344, 0.3761), (847, 0.4889), (482, 0.5023),
    (823, 0.3361), (617, 0.0218), (47, 0.0072), (867, 0.4050), (96, 0.4498), (126, 0.3564), (943, 0.0452),
    (115, 0.5309), (417, 0.7168), (706, 0.9649), (166, 0.2507), (991, 0.4191), (465, 0.0895), (53, 0.8169),
    (971, 0.9871);
预期的输出如下:
我尝试了以下查询
-- Attempted query (same as Query #1). It returns one row per image rather than
-- one row per label: the outer FROM reads every image, and the two scalar
-- subqueries are uncorrelated, so every row with the same weak_label gets the
-- same list. Each subquery also numbers ALL images, not only those carrying
-- its label, and row % 3 = 0 keeps the 3rd, 6th, 9th, ... image of that ordering.
SELECT weak_label,
CASE WHEN weak_label = 1 THEN (SELECT json_agg(a.image_id) FROM ( SELECT *, row_number() OVER(ORDER BY score DESC) AS row FROM unlabeled_image_predictions ) a WHERE a.row % 3 = 0 ) ELSE (SELECT json_agg(c.image_id) FROM ( SELECT *, row_number() OVER(ORDER BY score ASC) AS row FROM unlabeled_image_predictions ) c WHERE c.row % 3 = 0 ) END AS label
FROM (
SELECT image_id, CASE WHEN score < 0.5100 THEN 0 ELSE 1 END AS weak_label FROM unlabeled_image_predictions) mod;
它产生以下输出:
表结构(PostgreSQL v15)
-- Per-image prediction scores that the sampling queries read (50 sample rows).
-- Values are written as numeric literals rather than quoted strings, so no
-- implicit text-to-number conversion is involved.
-- NOTE(review): image_id has no unique key and the INSERT is unconditional,
-- so running this again after the table exists inserts every row a second time.
CREATE TABLE IF NOT EXISTS unlabeled_image_predictions (
    image_id int,
    score    float
);

INSERT INTO unlabeled_image_predictions (image_id, score) VALUES
    (828, 0.3149), (705, 0.9892), (46, 0.5616), (594, 0.7670), (232, 0.1598), (524, 0.9876), (306, 0.6487),
    (132, 0.8823), (906, 0.8394), (272, 0.9778), (616, 0.1003), (161, 0.7113), (715, 0.8921), (109, 0.1151),
    (424, 0.7790), (609, 0.5241), (63, 0.2552), (276, 0.2672), (701, 0.0758), (554, 0.4418), (998, 0.0379),
    (809, 0.1058), (219, 0.7143), (402, 0.7655), (363, 0.2661), (624, 0.8270), (640, 0.8790), (913, 0.2421),
    (439, 0.3387), (464, 0.3674), (405, 0.6929), (986, 0.8931), (344, 0.3761), (847, 0.4889), (482, 0.5023),
    (823, 0.3361), (617, 0.0218), (47, 0.0072), (867, 0.4050), (96, 0.4498), (126, 0.3564), (943, 0.0452),
    (115, 0.5309), (417, 0.7168), (706, 0.9649), (166, 0.2507), (991, 0.4191), (465, 0.0895), (53, 0.8169),
    (971, 0.9871);
查询#1
-- Query #1: the attempted solution. Why it yields 50 rows instead of one row per label:
--  * The outer FROM reads every image, so there is one result row per image.
--  * Both scalar subqueries are uncorrelated (they never reference the outer row),
--    so every row with the same weak_label receives the identical list.
--  * Each subquery numbers ALL images, not only those with its label, so the
--    weak_label = 1 list also contains low-scoring images (e.g. 998).
--  * row % 3 = 0 keeps the 3rd, 6th, 9th, ... image of each ordering.
SELECT weak_label,
CASE
-- weak_label = 1: every 3rd image by descending score, taken over the whole table.
WHEN weak_label = 1 THEN (SELECT json_agg(a.image_id)
-- NOTE(review): json_agg has no ORDER BY, so the list order is whatever order the
-- rows arrive in, which PostgreSQL does not guarantee; json_agg(a.image_id ORDER BY a.row)
-- would pin it.
FROM (
SELECT *, row_number() OVER(ORDER BY score DESC) AS row
FROM unlabeled_image_predictions
) a
WHERE a.row % 3 = 0
)
-- Any other label: every 3rd image by ascending score, taken over the whole table.
ELSE (SELECT json_agg(c.image_id)
FROM (
SELECT *, row_number() OVER(ORDER BY score ASC) AS row
FROM unlabeled_image_predictions
) c
WHERE c.row % 3 = 0
)
END AS label
-- One row per image: label 0 when score < 0.51, otherwise 1 (a NULL score also falls to 1).
FROM (
SELECT image_id,
CASE
WHEN score < 0.5100 THEN 0 ELSE 1
END AS weak_label
FROM unlabeled_image_predictions) mod;
弱标签 | label |
---|---|
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
0 | 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
1个 | 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998 |
然而,这并不是预期的输出。
附上预期的输出以供参考:
从描述中看不太清楚该结果是如何得到的。根据你的描述,可以这样写:
-- Weak-label sampling: images scoring below 0.51 are negatives (weak_label = 0),
-- the rest are positives (weak_label = 1). Within each class the images are ranked
-- from most to least confident, and every 3rd one is kept, starting with the first.
WITH labeled AS (
    -- Classify each scored image once, so the 0.51 threshold appears in one place.
    -- Rows with a NULL score cannot be compared to the threshold and are skipped.
    SELECT
        image_id,
        score,
        CASE WHEN score < 0.5100 THEN 0 ELSE 1 END AS weak_label
    FROM unlabeled_image_predictions
    WHERE score IS NOT NULL
),
ranked AS (
    -- Negatives are ranked by ascending score and positives by descending score.
    -- Inside each partition only one of the two CASE sort keys is non-NULL.
    -- image_id breaks score ties so the sample is deterministic.
    SELECT
        image_id,
        weak_label,
        ROW_NUMBER() OVER (
            PARTITION BY weak_label
            ORDER BY
                CASE WHEN weak_label = 0 THEN score END ASC,
                CASE WHEN weak_label = 1 THEN score END DESC,
                image_id
        ) AS row_no
    FROM labeled
)
SELECT
    image_id,
    weak_label
FROM ranked
WHERE row_no % 3 = 1  -- keeps rows 1, 4, 7, ...; use = 0 for rows 3, 6, 9, ...
ORDER BY image_id;
会产生:
image_id | weak_label |
---|---|
47 | 0 |
63 | 0 |
96 | 0 |
115 | 1 |
126 | 0 |
232 | 0 |
272 | 1 |
405 | 1 |
417 | 1 |
424 | 1 |
616 | 0 |
705 | 1 |
715 | 1 |
828 | 0 |
867 | 0 |
906 | 1 |
943 | 0 |
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.