

Reducing batch size in PyTorch

I am new to programming in PyTorch. I am getting an error saying that CUDA is out of memory, so I have to reduce the batch size. Can someone tell me how to do that in Python code? I also don't know my current batch size.

PS: I am trying to run the Deep Image Prior super-resolution example. Here's the code.

The error occurs when the optimization runs. It says:

RuntimeError: CUDA out of memory.

from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
from models import *

import torch
import torch.optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import warnings

from skimage.measure import compare_psnr
from models.downsampler import Downsampler

from utils.sr_utils import *

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
dtype = torch.cuda.FloatTensor

imsize = -1
factor = 16 # 8
enforse_div32 = 'CROP' # we usually need the dimensions to be divisible by a power of two (32 in this case)
PLOT = True
path_to_image = '/home/smitha/deep-image-prior/tnew.tif'
imgs = load_LR_HR_imgs_sr(path_to_image , imsize, factor, enforse_div32)
imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np'] = get_baselines(imgs['LR_pil'], imgs['HR_pil'])
if PLOT:
    plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], imgs['sharp_np'], imgs['nearest_np']], 4,12);
    print ('PSNR bicubic: %.4f   PSNR nearest: %.4f' %  (
                                    compare_psnr(imgs['HR_np'], imgs['bicubic_np']), 
                                    compare_psnr(imgs['HR_np'], imgs['nearest_np'])))
input_depth = 8
INPUT =     'noise'
pad   =     'reflection'
OPT_OVER =  'net'
LR = 5
tv_weight = 0.0
OPTIMIZER = 'adam'
if factor == 16: 
    num_iter = 10
    reg_noise_std = 0.01
elif factor == 8:
    num_iter = 40
    reg_noise_std = 0.05
else:
    # Only factors 16 and 8 are configured above
    assert False, 'We did not experiment with other factors'
net_input = get_noise(input_depth, INPUT, (imgs['HR_pil'].size[1],  imgs['HR_pil'].size[0])).type(dtype).detach()
NET_TYPE = 'skip' # UNet, ResNet
net = get_net(input_depth, 'skip', pad,
mse = torch.nn.MSELoss().type(dtype)

img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)

downsampler = Downsampler(n_planes=3, factor=factor,  kernel_type=KERNEL_TYPE, phase=0.5, preserve_size=True).type(dtype) 
def closure():
    global i, net_input
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out_HR = net(net_input)
    out_LR = downsampler(out_HR)

    total_loss = mse(out_LR, img_LR_var) 

    if tv_weight > 0:
        total_loss += tv_weight * tv_loss(out_HR)
    # Backpropagate so that optimizer.step() has gradients to apply
    total_loss.backward()

    # Log
    psnr_LR = compare_psnr(imgs['LR_np'], torch_to_np(out_LR))
    psnr_HR = compare_psnr(imgs['HR_np'], torch_to_np(out_HR))
    print ('Iteration %05d    PSNR_LR %.3f   PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\r', end='')

    # History
    psnr_history.append([psnr_LR, psnr_HR])

    if PLOT and i % 100 == 0:
        out_HR_np = torch_to_np(out_HR)
        plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], np.clip(out_HR_np, 0, 1)], factor=13, nrow=3)

    i += 1

    return total_loss   

psnr_history = [] 
net_input_saved = net_input.detach().clone()
noise = net_input.clone()
i = 0
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)
out_HR_np = np.clip(torch_to_np(net(net_input)), 0, 1)
result_deep_prior = put_in_center(out_HR_np, imgs['orig_np'].shape[1:])
plot_image_grid([imgs['HR_np'], imgs['bicubic_np'], out_HR_np], factor=4, nrow=1);

The batch size depends on how the model is fed. Typically, it's the first dimension of your input tensors. Your model uses different names than I'm used to, some of which are general terms, so I'm not sure of your model's topology or usage.
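A minimal PyTorch sketch of the point above, for illustration only. The tensor shapes, the made-up dataset, and the helper names `batch_size_of` and `tensor_mib` are hypothetical. It reads the batch size as `shape[0]` of an (N, C, H, W) tensor, shows that an input holding a single image already has a batch size of 1 (so there is no batch left to shrink), and contrasts that with a `DataLoader`, where `batch_size` is an explicit argument you can lower.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def batch_size_of(x: torch.Tensor) -> int:
    """Hypothetical helper: the batch size of an (N, C, H, W) tensor is its first dimension."""
    return x.shape[0]


def tensor_mib(x: torch.Tensor) -> float:
    """Hypothetical helper: memory held by a tensor's elements, in MiB."""
    return x.element_size() * x.nelement() / 2**20


# Hypothetical single-image input: 1 image, 8 channels (input_depth = 8),
# 512x512 pixels, float32. Its batch dimension is already 1.
net_input = torch.randn(1, 8, 512, 512)
print(batch_size_of(net_input))  # 1
print(tensor_mib(net_input))     # 8.0

# Halving height and width quarters this tensor's memory (and, roughly,
# that of the convolutional activations computed from it).
smaller_input = torch.randn(1, 8, 256, 256)
print(tensor_mib(smaller_input))  # 2.0

# With a DataLoader, the batch size is set explicitly; lowering it
# (e.g. from 16 to 4) reduces the memory used per step.
dataset = TensorDataset(torch.randn(100, 3, 64, 64))  # made-up data
loader = DataLoader(dataset, batch_size=4)
(first_batch,) = next(iter(loader))
print(batch_size_of(first_batch))  # 4
```

If `shape[0]` of the network input is already 1, as it would be for a single-image setup, lowering the batch size is not an option; the remaining levers are the spatial size of the image or the width of the network.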

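Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.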

© 2020-2024 STACKOOM.COM