
luajit/lua5.1/lua5.2/lua5.3 memory issue for RNN code

I have been running a piece of code, train.lua, found here: https://github.com/karpathy/char-rnn/blob/master/train.lua

This is a character-level language model based on SRNNs/LSTMs. It had been working perfectly well on OSX with a CPU until I tried implementing a word-level prediction model instead; that is, the network predicts the next word rather than the next character. The vocabulary size (the number of possible outcomes) went up to 13320, and the number of parameters increased to 39963. With LuaJIT I got the error message "not enough memory", so I looked around for a solution. I found the LuaJIT memory limit discussed here: https://github.com/karpathy/char-rnn/issues/80
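
As a sanity check on the 39963 figure, the count can be reproduced under an assumption that is not stated above: that the model is a single-layer vanilla RNN with rnn_size = 1, made of an input-to-hidden linear layer, a hidden-to-hidden linear layer, and a decoder back to the vocabulary, each with a bias. The helper name count_params is made up for this sketch.

```lua
-- Hypothetical helper: parameter count of a single-layer vanilla RNN,
-- assuming three linear layers (weights plus biases) and a one-hot input.
local function count_params(vocab_size, rnn_size)
  local i2h = vocab_size * rnn_size + rnn_size        -- input -> hidden
  local h2h = rnn_size * rnn_size + rnn_size          -- hidden -> hidden
  local decoder = rnn_size * vocab_size + vocab_size  -- hidden -> vocabulary
  return i2h + h2h + decoder
end

local n = count_params(13320, 1)
print(n)             -- 39963
print(n * 8 / 1024)  -- size in KiB, assuming 8-byte doubles
```

If that layout is right, the weights come to roughly 312 KiB as doubles, which suggests the parameters alone are unlikely to be what exhausts memory.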

So I removed torch and installed plain Lua. However, none of LUA51, LUA52, or LUA53 worked; I ran into the same memory issue. It just says "Kill: 9" every time I run the training code. In particular, the issue arises when the code creates T hidden layers (T being the sequence length, i.e. the number of time steps) that share the same weights, using the "model_utils.clone_many_times" function in the util/model_utils.lua file.

In my case, the function gets as far as cloning 7 hidden layers, and the process is killed there. I set both rnn_size and batch_size to 1. Of course I want to run much bigger networks, but the code fails even at this small size.

Update: Here is the workaround I am working on.

The cloning process seems somewhat redundant, as it stores T hidden layers. Maybe we can change the function so that it carries only the unit activations, rather than the entire layers, through the T time steps. I feel the only issue is backprop. Activation levels of the hidden units are carried over from batch to batch by the table init_state_global, so we somehow need to establish back-propagation over multiple batches.

Here is a workaround I found. Everything else being equal, the results I got were almost the same as the original ones, apart from some floating-point precision differences for some reason. It saves memory (seq_length does not even affect the memory size). I set the number of clones in the "model_utils.clone_many_times" function to 1 (so we probably don't even need this memory-consuming function anymore), and just store the hidden unit activations for backprop.

function feval(x)
    if x ~= params then
        params:copy(x)
    end
    grad_params:zero()

    ------------------ get minibatch -------------------
    local x, y = loader:next_batch(1)
    x, y = prepro(x, y) -- seq_length by batch_size tensor
    ------------------- forward pass -------------------
    local rnn_state = {[0] = init_state_global}
    local predictions = {}           -- softmax outputs
    local loss = 0
    local hidden_units = {}          -- per-time-step copies of every module's output

    for t = 1, opt.seq_length do
        clones.rnn[1]:training() -- make sure we are in correct mode (this is cheap, sets flag)
        local lst = clones.rnn[1]:forward{x[t], unpack(rnn_state[t-1])}
        rnn_state[t] = {}
        for i = 1, #init_state do table.insert(rnn_state[t], lst[i]) end -- extract the state, without output
        -- save a copy of each module's output so it can be restored for backprop
        hidden_units[t] = {}
        local j = 1
        for k = 1, #clones.rnn[1].modules do
            if clones.rnn[1].modules[k].output then
                if type(clones.rnn[1].modules[k].output) ~= 'table' then
                    hidden_units[t][j] = clones.rnn[1].modules[k].output:clone()
                else
                    hidden_units[t][j] = {}
                    for l = 1, #clones.rnn[1].modules[k].output do
                        hidden_units[t][j][l] = clones.rnn[1].modules[k].output[l]:clone()
                    end
                end
                j = j + 1
            end
        end

        predictions[t] = lst[#lst] -- last element is the prediction
        loss = loss + clones.criterion[1]:forward(predictions[t], y[t])
    end
    loss = loss / opt.seq_length

    ------------------ backward pass -------------------
    -- initialize gradient at time t to be zeros (there's no influence from future)
    local drnn_state = {[opt.seq_length] = clone_list(init_state, true)} -- true also zeros the clones
    for t = opt.seq_length, 1, -1 do
        -- restore the module outputs saved at time step t
        local j = 1
        for k = 1, #clones.rnn[1].modules do
            if clones.rnn[1].modules[k].output then
                clones.rnn[1].modules[k].output = hidden_units[t][j]
                j = j + 1
            end
        end

        -- backprop through loss, and softmax/linear
        local doutput_t = clones.criterion[1]:backward(predictions[t], y[t])
        table.insert(drnn_state[t], doutput_t)
        local dlst = clones.rnn[1]:backward({x[t], unpack(rnn_state[t-1])}, drnn_state[t])
        drnn_state[t-1] = {}
        for k, v in pairs(dlst) do
            if k > 1 then -- k == 1 is gradient on x, which we dont need
                -- note we do k-1 because first item is dembeddings, and then follow the
                -- derivatives of the state, starting at index 2. I know...
                drnn_state[t-1][k-1] = v
            end
        end
    end
    ------------------------ misc ----------------------
    -- transfer final state to initial state (BPTT)
    init_state_global = rnn_state[#rnn_state] -- NOTE: I don't think this needs to be a clone, right?
    -- grad_params:div(opt.seq_length) -- this line should be here but since we use rmsprop it would have no effect. Removing for efficiency
    -- gradient clipping is disabled for this experiment:
    -- grad_params:clamp(-opt.grad_clip, opt.grad_clip)
    return loss, grad_params
end
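
The save-and-restore idea can also be shown without Torch. The following is only a minimal sketch in plain Lua, not the char-rnn code: the scalar cell h_t = tanh(w * h_prev + x_t), the squared-error loss, and the names Cell and loss_and_grad are all assumptions made for illustration. A single cell is reused for every time step, its output is saved per step during the forward pass, and it is put back before each backward step. The result is checked against a finite-difference gradient.

```lua
-- tanh written out by hand so the sketch does not rely on math.tanh
local function tanh(z)
  local e = math.exp(2 * z)
  return (e - 1) / (e + 1)
end

-- Hypothetical toy cell. backward() reads the cached self.output, so the
-- value from the matching forward step must be in place when it is called.
local Cell = {}
Cell.__index = Cell

function Cell.new(w)
  return setmetatable({ w = w, gradW = 0, output = 0 }, Cell)
end

function Cell:forward(x, h_prev)
  self.output = tanh(self.w * h_prev + x)
  return self.output
end

function Cell:backward(h_prev, dh)
  local dpre = dh * (1 - self.output * self.output) -- derivative of tanh
  self.gradW = self.gradW + dpre * h_prev           -- accumulate dL/dw
  return dpre * self.w                              -- gradient w.r.t. h_prev
end

-- Loss is the sum over t of 0.5 * (h_t - y_t)^2.
local function loss_and_grad(w, xs, ys, h0)
  local cell = Cell.new(w)  -- one cell reused for every time step
  local h = { [0] = h0 }
  local saved = {}          -- per-step copies of cell.output
  local loss = 0
  for t = 1, #xs do
    h[t] = cell:forward(xs[t], h[t - 1])
    saved[t] = cell.output  -- numbers copy by value; tensors would need :clone()
    loss = loss + 0.5 * (h[t] - ys[t]) ^ 2
  end
  local dh_next = 0         -- no gradient arrives from beyond the last step
  for t = #xs, 1, -1 do
    cell.output = saved[t]  -- restore the activation of step t
    local dh = dh_next + (h[t] - ys[t])
    dh_next = cell:backward(h[t - 1], dh)
  end
  return loss, cell.gradW
end

local xs, ys = { 0.5, -0.3, 0.8, 0.1 }, { 0.2, 0.1, -0.4, 0.3 }
local w, h0, eps = 0.7, 0.1, 1e-6
local loss, grad = loss_and_grad(w, xs, ys, h0)
local loss_plus = loss_and_grad(w + eps, xs, ys, h0)
local loss_minus = loss_and_grad(w - eps, xs, ys, h0)
local numeric = (loss_plus - loss_minus) / (2 * eps)
print(string.format("analytic %.8f  numeric %.8f", grad, numeric))
```

In this toy version the saved table still grows with the number of time steps, but it holds only activations, not a full copy of the layer per step.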
