How to train a neural network with a variable output size

I have a working CNN-LSTM model that tries to predict keypoints of human body parts in videos.

Currently, I have four keypoints as labels: right hand, left hand, head and pelvis. The problem is that in some frames I can't see all four body parts I want to label, so by default I set those values to (0,0), which I use as a null coordinate.

The problem I faced was that the model took those points into account and tried to regress on them across the sequence.

Thus, I excluded the (0,0) points from the loss calculation and from gradient backpropagation, and it works much better.
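A minimal sketch of that exclusion step, assuming predictions and targets are shaped (batch, seq_len, num_keypoints, 2) and that a target of exactly (0, 0) means "not visible". Boolean indexing keeps only the visible (x, y) pairs, so hidden keypoints contribute nothing to the loss and receive zero gradient. The function name `masked_keypoint_mse` is hypothetical. One caveat of the (0, 0) convention: a genuinely visible keypoint that happens to sit at (0, 0) would be ignored as well.

```python
import torch
import torch.nn.functional as F


def masked_keypoint_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """MSE computed over visible keypoints only (hypothetical helper).

    pred, target: (batch, seq_len, num_keypoints, 2).
    Assumption: a target of exactly (0, 0) marks a keypoint that is not visible.
    """
    # True where at least one coordinate is non-zero, i.e. the keypoint is labelled
    visible = (target != 0).any(dim=-1)  # (batch, seq_len, num_keypoints)
    if not visible.any():
        # Nothing visible in this batch: return a zero that still depends on pred
        return pred.sum() * 0.0
    # Boolean indexing keeps only the visible (x, y) pairs -> shape (n_visible, 2)
    return F.mse_loss(pred[visible], target[visible])
```

Because the hidden entries never enter the indexed tensors, `loss.backward()` leaves their gradients at exactly zero.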

The problem is that all four points are still predicted, so I am looking for any way to make the model predict a variable number of keypoints.

I thought of adding a third parameter per keypoint (is it visible?), but it would probably add some complexity and could make the model perform worse.
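One way the "is it visible?" parameter could be wired up, sketched here as an assumption rather than something tested in this thread: the output head predicts (x, y, visibility logit) for every keypoint, the visibility logit is trained with binary cross-entropy against a target derived from the (0, 0) convention, and the coordinate loss is masked to visible keypoints. At inference, keypoints whose predicted visibility falls below a threshold are dropped, which yields a variable number of keypoints per frame. The names `KeypointHead`, `keypoint_loss`, `visible_keypoints`, `vis_weight` and the 0.5 threshold are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class KeypointHead(nn.Module):
    """Hypothetical output head: predicts (x, y, visibility logit) per keypoint."""

    def __init__(self, in_features: int, num_keypoints: int = 4):
        super().__init__()
        self.num_keypoints = num_keypoints
        self.fc = nn.Linear(in_features, num_keypoints * 3)

    def forward(self, features: torch.Tensor):
        # features: (batch, seq_len, in_features), e.g. the LSTM outputs
        out = self.fc(features)
        out = out.reshape(*out.shape[:-1], self.num_keypoints, 3)
        coords, vis_logit = out[..., :2], out[..., 2]
        return coords, vis_logit


def keypoint_loss(coords, vis_logit, target, vis_weight: float = 1.0):
    # Visibility target from the (0, 0) convention: 1 = visible, 0 = hidden
    visible = (target != 0).any(dim=-1)
    vis_loss = F.binary_cross_entropy_with_logits(vis_logit, visible.float())
    if visible.any():
        # Coordinates are only supervised where the keypoint is visible
        coord_loss = F.mse_loss(coords[visible], target[visible])
    else:
        coord_loss = coords.sum() * 0.0
    return coord_loss + vis_weight * vis_loss


@torch.no_grad()
def visible_keypoints(coords, vis_logit, threshold: float = 0.5):
    """For one frame (coords: (K, 2), vis_logit: (K,)), keep only visible keypoints.

    Returns (keypoint indices, their coordinates); the count varies per frame.
    """
    keep = torch.sigmoid(vis_logit) > threshold
    return keep.nonzero(as_tuple=True)[0], coords[keep]
```

The extra output costs one value per keypoint, and `vis_weight` controls how much the visibility classification competes with the coordinate regression; its value would need tuning.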

I think you'll have to write a custom loss function that computes the loss between points only when the target coordinates are not null.

See the question "PyTorch custom loss function" for how to write custom losses.

Something like:

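import torch
import torch.nn.functional as F

def loss(outputs, labels):
    # outputs, labels: tensors of (x, y) keypoints with shape (..., 2)
    outputs = outputs.reshape(-1, 2)
    labels = labels.reshape(-1, 2)
    err = outputs.new_zeros(())
    n = 0
    for xo, xt in zip(outputs, labels):
        if torch.all(xt == 0):  # null coord (0, 0): keypoint not visible, skip it
            continue
        err = err + F.mse_loss(xo, xt)
        n += 1
    return err / max(n, 1)  # guard against batches where every keypoint is missing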

This is only a sketch. An alternative that avoids the loop is to keep an explicit binary visibility vector (as suggested by @leleogere), which you multiply by the per-coordinate loss before reducing.
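A loop-free version of that idea, sketched under the assumption that the binary vector `visible` has one entry per keypoint (1 = present, 0 = missing) and matches the leading dimensions of the coordinate tensors. It computes the squared error with `reduction="none"`, zeroes out missing keypoints, and averages over the coordinates that were kept, which gives the same value as the loop above. The name `masked_mse` is hypothetical.

```python
import torch
import torch.nn.functional as F


def masked_mse(outputs: torch.Tensor, labels: torch.Tensor, visible: torch.Tensor) -> torch.Tensor:
    """Loop-free masked MSE (hypothetical helper).

    outputs, labels: (..., num_keypoints, 2) coordinates.
    visible: (..., num_keypoints) binary vector, 1 = keypoint present, 0 = missing.
    """
    # Per-coordinate squared error, not reduced yet: (..., num_keypoints, 2)
    per_coord = F.mse_loss(outputs, labels, reduction="none")
    # Broadcast the per-keypoint mask over x and y, then zero out missing keypoints
    mask = visible.unsqueeze(-1).to(per_coord.dtype)
    masked = per_coord * mask
    # Number of kept coordinates = visible keypoints * 2; clamp avoids dividing by 0
    n_kept = (mask.sum() * per_coord.shape[-1]).clamp(min=1.0)
    return masked.sum() / n_kept
```

The mask can be derived from the (0, 0) labels or stored explicitly alongside the dataset; storing it explicitly avoids the ambiguity of a real keypoint at (0, 0).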
