
PyTorch autograd: getting a pixel grid tensor from a coordinates tensor in a differentiable way

My model outputs the coordinates of rectangles within a canvas, and I am trying to obtain a pixel-wise representation of this output from the coordinate representation, before applying the loss to the pixel-wise representation:

    # get prediction
    ypred=forward(x,w)

    # rasterize pred + test
    ytrain=rasterize(ytrain,300,600)
    ypred=rasterize(ypred,300,600)
    
    # update loss
    loss = get_loss(ytrain, ypred)

    # get gradient
    loss.backward()
    
    # update weights
    with torch.no_grad():
        w -= lr * w.grad
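
For reference, forward and get_loss are not shown here; a pixel-wise mean-squared error is one plausible stand-in for get_loss (an assumption, not the original code):

    import torch

    def get_loss(ytrain, ypred):
        # assumed implementation: plain mean-squared error over all pixels/channels
        return torch.mean((ytrain - ypred) ** 2)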

I've built these two alternative toy rasterization functions:

    # rasterize tensor: for loop
    def rasterize_toy(tn, w, h):
        nsamples = tn.size()[0]
        # the buffer is created as a leaf with requires_grad=True, which is
        # what later makes backward() fail
        vtn = torch.zeros(nsamples, h, w, 3, dtype=torch.float, requires_grad=True)
        for i in range(nsamples):  # for each sample
            top = (tn[i] * 5).long()
            vtn[i] = add_tn_bg(vtn[i], h, w)  # helper defined elsewhere
            # in-place writes into the grad-requiring leaf:
            vtn[i, top:, :, 0] = 255 / 255
            vtn[i, top:, :, 1] = 255 / 255
            vtn[i, top:, :, 2] = 255 / 255
        return vtn

    # rasterize tensor: index_put
    def rasterize_toy2(tn, w, h):
        nsamples = tn.size()[0]
        top = (tn[0] * 10).long()  # this toy only rasterizes sample 0
        v = 100
        # again a leaf buffer with requires_grad=True
        vtn = torch.zeros(nsamples, h, w, dtype=torch.float, requires_grad=True)
        indices = [(torch.ones(w) * top).long(),
                   torch.arange(0, w).long()]
        values = torch.ones(w) * v
        # in-place assignment of the index_put result back into the leaf:
        vtn[0] = vtn[0].index_put(indices, values)
        return vtn

but both of them raise this error when loss.backward() is called after the rasterization step:

RuntimeError: leaf variable has been moved into the graph interior
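
This error comes from the buffer itself: vtn is created as a leaf tensor with requires_grad=True, and the in-place writes then pull that leaf into the interior of the graph. A minimal reproduction (depending on the PyTorch version, the failure surfaces either at the in-place write itself or only at backward()):

    import torch

    vtn = torch.zeros(4, requires_grad=True)  # leaf tensor
    src = torch.ones(4, requires_grad=True)
    vtn[0] = src[0] * 2       # in-place write moves the leaf into the graph
    vtn.sum().backward()      # RuntimeError: leaf variable has been moved
                              # into the graph interior

The usual fix is to create the buffer without requires_grad: gradients then flow through the values written into it rather than through the buffer itself.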

My issue seems very similar to this one, which has not been solved yet:

source 1

link: PyTorch: Differentiable operations to go from coordinate tensor to grid tensor

I've also checked the following sources outside of Stack Overflow:

source 2

link: GitHub - ksheng-/fast-differentiable-rasterizer: differentiable bezier curve rasterizer with PyTorch

problem: while this repository proposes ways of rasterizing data structures, it does not seem to allow using differentiable variables as indices into the final rasterized image.
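
One more observation on why hard indexing cannot work even without the error: casting the coordinate with .long() detaches it from the graph, so no gradient can reach the coordinates through an integer index. Differentiable rasterizers like the one above therefore replace hard indexing with a soft coverage function. A minimal sketch of that idea for the single horizontal edge used in the toy functions above (soft_rasterize and its sharpness parameter are illustrative, not taken from that repository):

    import torch

    def soft_rasterize(top, h, w, sharpness=10.0):
        # top: scalar tensor holding the edge's row coordinate, requires_grad=True.
        # Each row gets a soft coverage value from a sigmoid instead of a hard
        # img[top:, :] = 1 assignment, so gradients can flow back to `top`.
        rows = torch.arange(h, dtype=torch.float32)          # pixel-row coordinates
        coverage = torch.sigmoid((rows - top) * sharpness)   # ~0 above the edge, ~1 below
        return coverage.unsqueeze(1).expand(h, w)            # broadcast across columns

    top = torch.tensor(12.0, requires_grad=True)
    img = soft_rasterize(top, 60, 30)
    img.sum().backward()
    print(top.grad)  # non-zero: the coordinate receives a gradient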

source 3

link: https://discuss.pytorch.org/t/leaf-variable-moved-into-graph-interior/17489

problem:

  • masked_scatter, gather, and grid_sample do not seem to match what I am trying to do
  • index_put seems to match my needs, but my second rasterization function is based on it and raises the same error as the for-loop version (see the sketch after this list)
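
For what it's worth, index_put itself is not the culprit: it only fails above because its result is assigned in place into a leaf that requires grad. Used out-of-place on a plain buffer, gradients flow through the written values (a minimal sketch; note the indices themselves remain non-differentiable, which is the deeper obstacle here):

    import torch

    v = torch.tensor([3.0], requires_grad=True)       # differentiable values
    buf = torch.zeros(4, 4)                           # plain buffer, no requires_grad
    indices = (torch.tensor([1]), torch.tensor([2]))
    out = buf.index_put(indices, v)                   # out-of-place: returns a new tensor
    out.sum().backward()
    print(v.grad)  # tensor([1.])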

Thanks in advance for your help.

I managed to rasterize the model output before feeding it to the loss.backward() operation as follows.

The solution is kind of nasty: it requires initializing the coordinate representation x in the shape of the pixel-wise representation, with the coordinate values occupying only a fraction of that shape. The rasterization function then writes new values into the complete pixel-wise representation, starting from the few initialization values it contains, coupled with the updated weights.

Here is the training protocol:

    # Training
    lr = 0.01
    for iepoch in range(nepochs):

        # get prediction
        ypred=mforward(x,w)

        # rasterize pred + test
        y=mrast(y,300,600)
        ypred=mrast(ypred,300,600)
        
        # update loss
        loss = get_loss(y, ypred)

        # get gradient
        loss.backward()
        
        # update weights
        with torch.no_grad():
            w -= lr * w.grad

        # reset gradient
        w.grad.zero_()

Here is the rasterization function:

    def rasterize_toy3(tn, w, h):
        # writes into the input tensor (the model output) in place;
        # nsamples is taken from the enclosing scope (see set_data below)
        for i in range(nsamples):  # for each sample
            top = (tn[i, 0, 0, 0] * 5).long()
            tn[i, top:, :, 0] = 100 / 255
            tn[i, top:, :, 1] = 100 / 255
            tn[i, top:, :, 2] = 100 / 255
        return tn
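
This presumably sidesteps the error because the tensor written in place is now the model output, an interior node of the graph, rather than a freshly created leaf with requires_grad=True. Note that top still goes through .long(), so the coordinate receives no gradient through the overwritten region; the gradient reaches the weights only through the cells left untouched.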

And here is the data initialization function:

    def set_data():

        ## data: original scalar version, kept for reference
        # x = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        # y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

        ## data: coordinates embedded in the pixel-wise shape
        nsamples, h, w, nc = 4, 60, 30, 3
        x = torch.zeros(nsamples, h, w, nc, dtype=torch.float)
        x[:, 0, 0, 0] = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
        y = torch.zeros(nsamples, h, w, nc, dtype=torch.float)
        y[:, 0, 0, 0] = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

        return x, y
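
Note that nsamples, h, w, and nc are assigned inside set_data but read by rasterize_toy3, so they presumably live at module level in the original notebook; hoist them out of the function if you want to run these snippets together.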

If anyone has a more elegant solution, I'm interested.
