
A query to perform sampling, choosing every 3rd row based on a condition, using Postgres

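I have a table containing image IDs and scores.

[Image: Table]

Positive samples are image IDs with a score close to 1, and negative samples are image IDs with a score close to 0.

To sample, we order the image IDs by score in descending order and pick every 3rd image ID; these make up the positive samples. To get the negative samples, we order the image IDs by score in ascending order and pick every 3rd image ID; these make up the negative samples.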
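To make the rule concrete, the ranking step for the positive side could be sketched roughly as below. This is only an illustration of "order by score, keep every 3rd row", run against the unlabeled_image_predictions table from the schema in this post; the alias names ranked and rn are made up for the example. Switching ORDER BY score DESC to ORDER BY score ASC would give the negative side.

```sql
-- Sketch: rank all images by descending score and keep every 3rd one.
-- "ranked" and "rn" are illustrative names, not part of the schema.
SELECT image_id, score
FROM (
    SELECT image_id,
           score,
           row_number() OVER (ORDER BY score DESC) AS rn
    FROM unlabeled_image_predictions
) ranked
WHERE rn % 3 = 0   -- keeps rows 3, 6, 9, ... of the ranking
ORDER BY rn;
```

On the sample data in this post, the 3rd and 6th highest scores belong to image IDs 971 and 986, so those would be the first two picks.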

Posting the schema for reference:

CREATE TABLE IF NOT EXISTS unlabeled_image_predictions (image_id int, score float);

INSERT INTO unlabeled_image_predictions (image_id, score) VALUES

('828','0.3149'), ('705','0.9892'), ('46', '0.5616'), ('594', '0.7670'), ('232', '0.1598'), ('524','0.9876'), ('306','0.6487'), ('132','0.8823'), ('906','0.8394'), ('272', '0.9778'), ('616', '0.1003'), ('161', '0.7113'), ('715', '0.8921'), ('109', '0.1151'), ('424', '0.7790'), ('609', '0.5241'), ('63', '0.2552'), ('276', '0.2672'), ('701', '0.0758'), ('554', '0.4418'), ('998', '0.0379'), ('809', '0.1058'), ('219', '0.7143'), ('402', '0.7655'), ('363', '0.2661'), ('624', '0.8270'), ('640', '0.8790'), ('913', '0.2421'), ('439', '0.3387'), ('464', '0.3674'), ('405', '0.6929'), ('986', '0.8931'), ('344', '0.3761'), ('847', '0.4889'), ('482', '0.5023'), ('823','0.3361'), ('617','0.0218'), ('47', '0.0072'), ('867','0.4050'), ('96','0.4498'), ('126','0.3564'), ('943', '0.0452'), ('115','0.5309'), ('417', '0.7168'), ('706', '0.9649'), ('166', '0.2507'), ('991', '0.4191'), ('465', '0.0895'), ('53', '0.8169'), ('971', '0.9871');

This is what the expected output looks like:

[Image: Expected Output]

I tried the following query:

SELECT weak_label,
CASE WHEN weak_label = 1 THEN (SELECT json_agg(a.image_id) FROM ( SELECT *, row_number() OVER(ORDER BY score DESC) AS row FROM unlabeled_image_predictions ) a WHERE a.row % 3 = 0 ) ELSE (SELECT json_agg(c.image_id) FROM ( SELECT *, row_number() OVER(ORDER BY score ASC) AS row FROM unlabeled_image_predictions ) c WHERE c.row % 3 = 0 ) END AS label

FROM (
SELECT image_id, CASE WHEN score < 0.5100 THEN 0 ELSE 1 END AS weak_label FROM unlabeled_image_predictions) mod

It produces the following output:

Schema (PostgreSQL v15)

CREATE TABLE IF NOT EXISTS unlabeled_image_predictions (
  image_id int,
  score float);

INSERT INTO unlabeled_image_predictions (image_id, score) VALUES

('828','0.3149'), ('705','0.9892'), ('46', '0.5616'), ('594', '0.7670'), ('232','0.1598'), ('524','0.9876'), ('306','0.6487'),
('132','0.8823'), ('906','0.8394'), ('272', '0.9778'), ('616', '0.1003'), ('161', '0.7113'), ('715', '0.8921'), ('109', '0.1151'),
('424','0.7790'), ('609', '0.5241'), ('63', '0.2552'), ('276','0.2672'), ('701','0.0758'), ('554','0.4418'), ('998', '0.0379'),
('809','0.1058'), ('219','0.7143'), ('402', '0.7655'), ('363', '0.2661'), ('624', '0.8270'), ('640','0.8790'), ('913','0.2421'),
('439','0.3387'), ('464', '0.3674'), ('405', '0.6929'), ('986', '0.8931'), ('344', '0.3761'), ('847', '0.4889'), ('482', '0.5023'),
('823','0.3361'), ('617','0.0218'), ('47', '0.0072'), ('867','0.4050'), ('96','0.4498'), ('126','0.3564'), ('943', '0.0452'),
('115','0.5309'), ('417', '0.7168'), ('706','0.9649'), ('166', '0.2507'), ('991', '0.4191'), ('465', '0.0895'), ('53', '0.8169'),
('971','0.9871');

Query #1

SELECT weak_label,
       CASE
           WHEN weak_label = 1 THEN (
               SELECT json_agg(a.image_id)
               FROM (
                   SELECT *, row_number() OVER (ORDER BY score DESC) AS row
                   FROM unlabeled_image_predictions
               ) a
               WHERE a.row % 3 = 0
           )
           ELSE (
               SELECT json_agg(c.image_id)
               FROM (
                   SELECT *, row_number() OVER (ORDER BY score ASC) AS row
                   FROM unlabeled_image_predictions
               ) c
               WHERE c.row % 3 = 0
           )
       END AS label
FROM (
    SELECT image_id,
           CASE
               WHEN score < 0.5100 THEN 0 ELSE 1
           END AS weak_label
    FROM unlabeled_image_predictions
) mod;

弱标签 label
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
0 998,465,109,166,276,439,344,554,482,46,161,402,53,640,986,971
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998
1个 971,986,640,53,402,161,46,482,554,344,439,276,166,109,465,998

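View on DB Fiddle

However, this is not what the expected output should look like.

Posting the expected output for reference:

[Image: Expected Output]

It is a bit unclear how that result is supposed to be obtained. Based on your description: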

with negatives (image_id, score, weak_label, rowNo) as
         (SELECT image_id,
                 score,
                 0 AS weak_label,
                 row_number() over (order by score asc)
          FROM unlabeled_image_predictions
          where score < 0.5100),
     positives (image_id, score, weak_label, rowNo) as
         (SELECT image_id,
                 score,
                 1 AS weak_label,
                 row_number() over (order by score desc)
          FROM unlabeled_image_predictions
          where score >= 0.5100),
     combined as
         (select * from negatives
          union
          select * from positives
          )
select image_id, weak_label
from combined
where rowNo % 3 = 1
order by image_id;

would produce:

image_id weak_label
47 0
63 0
96 0
115 1
126 0
232 0
272 1
405 1
417 1
424 1
616 0
705 1
715 1
828 0
867 0
906 1
943 0
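
Note that the query above keeps rows 1, 4, 7, ... of each ranking (rowNo % 3 = 1). If "every 3rd" is meant literally (rows 3, 6, 9, ...) and a single row per label is wanted, a variant could look roughly like the sketch below. It assumes the 0.51 threshold from the question's query; the names labeled, ranked, rn and image_ids are illustrative only, and whether this matches the expected output depends on how that output was actually derived.

```sql
-- Variant sketch: one row per weak_label, keeping rows 3, 6, 9, ... of each ranking.
-- Assumptions: the 0.51 threshold comes from the question's query;
-- labeled, ranked, rn and image_ids are illustrative names.
WITH labeled AS (
    SELECT image_id,
           score,
           CASE WHEN score < 0.5100 THEN 0 ELSE 1 END AS weak_label
    FROM unlabeled_image_predictions
),
ranked AS (
    SELECT image_id,
           weak_label,
           row_number() OVER (
               PARTITION BY weak_label
               -- negatives are ranked by ascending score, positives by descending score
               ORDER BY CASE WHEN weak_label = 0 THEN score ELSE -score END
           ) AS rn
    FROM labeled
)
SELECT weak_label,
       json_agg(image_id ORDER BY rn) AS image_ids
FROM ranked
WHERE rn % 3 = 0
GROUP BY weak_label
ORDER BY weak_label;
```

Partitioning by weak_label restarts the numbering for each group, so the modulo filter is applied to the negatives and the positives independently rather than to one ranking over the whole table.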

