
Lua commands: what do they do?

I am unfamiliar with Lua, but the author of the article used Lua.

Can you help me understand what these two lines do?

What does replicate(x,batch_size) do?

What does x = x:resize(x:size(1), 1):expand(x:size(1), batch_size) do?

The original source code can be found here: https://github.com/wojzaremba/lstm/blob/master/data.lua

This basically boils down to simple maths and looking up a few functions in the Torch manual.

OK, I'm bored, so...

Here is replicate(x,batch_size) as defined in https://github.com/wojzaremba/lstm/blob/master/data.lua:

-- Stacks replicated, shifted versions of x_inp
-- into a single matrix of size x_inp:size(1) x batch_size.
local function replicate(x_inp, batch_size)
   local s = x_inp:size(1)
   local x = torch.zeros(torch.floor(s / batch_size), batch_size)
   for i = 1, batch_size do
     local start = torch.round((i - 1) * s / batch_size) + 1
     local finish = start + x:size(1) - 1
     x:sub(1, x:size(1), i, i):copy(x_inp:sub(start, finish))
   end
   return x
end

This code uses the Torch framework.

x_inp:size(1) returns the size of dimension 1 of the Torch tensor x_inp (a potentially multi-dimensional matrix).

See https://cornebise.com/torch-doc-template/tensor.html#toc_18

So x_inp:size(1) gives you the number of rows in x_inp (for a one-dimensional tensor, simply its number of elements), and x_inp:size(2) would give you the number of columns.
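If it helps to see this outside of Torch: size(dim) counts dimensions from 1, while Python indexes from 0. Here is a minimal plain-Python sketch, assuming a 2-D tensor is stored as a list of rows and a 1-D tensor as a flat list. The helper name `size` is hypothetical and not part of any library.

```python
def size(tensor, dim):
    """Mimic Torch's 1-based t:size(dim) for a 1-D list or a 2-D list of rows (hypothetical helper)."""
    if dim == 1:
        return len(tensor)       # rows of a 2-D list, or elements of a 1-D list
    if dim == 2:
        return len(tensor[0])    # columns; assumes tensor[0] is itself a list
    raise ValueError("this sketch only handles dimensions 1 and 2")

m = [[0, 0], [0, 0], [0, 0]]
print(size(m, 1), size(m, 2))    # prints: 3 2
print(size(list(range(11)), 1))  # prints: 11
```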

local x = torch.zeros(torch.floor(s / batch_size), batch_size)

creates a new two-dimensional tensor filled with zeros and a local reference to it named x. The number of rows is computed from s (x_inp's row count) and batch_size as floor(s / batch_size). For your example input (11 rows, batch_size = 2) that is floor(11/2) = floor(5.5) = 5.

The number of columns is batch_size, so it is 2 in your example.
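The shape arithmetic can be checked with a small plain-Python sketch. It is not Torch: a nested list stands in for the tensor, and `zeros` is a hypothetical helper that plays the role of torch.zeros(rows, cols).

```python
import math

def zeros(rows, cols):
    """Rough stand-in for torch.zeros(rows, cols): a list of row lists."""
    # Build every row separately; [[0] * cols] * rows would make all rows the same object.
    return [[0] * cols for _ in range(rows)]

s, batch_size = 11, 2
rows = math.floor(s / batch_size)   # floor(5.5) = 5
x = zeros(rows, batch_size)
print(rows, len(x), len(x[0]))      # prints: 5 5 2
```

For non-negative integers, `s // batch_size` gives the same row count as `math.floor(s / batch_size)`.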


Simply put, x is the 5x2 matrix

0 0
0 0
0 0
0 0
0 0

The following lines copy x_inp's contents into x:

for i = 1, batch_size do
  local start = torch.round((i - 1) * s / batch_size) + 1
  local finish = start + x:size(1) - 1
  x:sub(1, x:size(1), i, i):copy(x_inp:sub(start, finish))
end

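In the first iteration, start evaluates to round(0) + 1 = 1 and finish to 1 + 5 - 1 = 5, since x:size(1) is the number of rows of x, which is 5. In the second iteration, start evaluates to round(1 * 11 / 2) + 1 = round(5.5) + 1 = 7 (torch.round rounds 5.5 up to 6), and finish to 7 + 5 - 1 = 11.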
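The start/finish bookkeeping is easy to get wrong by hand because of the rounding step. Below is a plain-Python sketch of just that arithmetic. It assumes that torch.round on a plain number rounds halves away from zero, as C's round() does. Python's built-in round() rounds halves to even (round(4.5) == 4), so the sketch uses its own hypothetical `torch_round` helper instead.

```python
import math

def torch_round(v):
    """Round half away from zero (assumed to match torch.round on plain numbers)."""
    return math.floor(v + 0.5) if v >= 0 else -math.floor(-v + 0.5)

def chunk_bounds(s, batch_size):
    """1-based, inclusive (start, finish) pairs computed by replicate's loop."""
    rows = math.floor(s / batch_size)
    bounds = []
    for i in range(1, batch_size + 1):
        start = torch_round((i - 1) * s / batch_size) + 1
        finish = start + rows - 1
        bounds.append((start, finish))
    return bounds

print(chunk_bounds(11, 2))  # prints: [(1, 5), (7, 11)]
print(chunk_bounds(10, 2))  # prints: [(1, 5), (6, 10)]
```

With an input length that divides evenly by batch_size, the chunks are contiguous. With 11 elements, one row is left over and the second chunk starts one row later.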

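So rows 1 to 5 of x_inp (your first batch) are copied into the first column of x, and rows 7 to 11 (the second batch) into the second column. Row 6 is never copied: only floor(s / batch_size) * batch_size = 10 of the 11 rows fit into x, and the leftover row falls between the two chunks.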

x:sub(1, x:size(1), i, i) is the sub-tensor of x covering rows 1 to 5 and columns i to i. In your example that is column 1 in the first iteration and column 2 in the second. So it is nothing more than the first and then the second column of x.

See https://cornebise.com/torch-doc-template/tensor.html#toc_42

:copy(x_inp:sub(start, finish))

copies rows start to finish of x_inp into that column of x.
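When translating this to another language, watch the indexing: Torch's sub(first, last) is 1-based and inclusive at both ends, while a Python slice is 0-based and excludes its end. `copy_into_column` below is a hypothetical helper that mimics one iteration of the loop on nested lists. It raises an error if the chunk and the column differ in length.

```python
def copy_into_column(x, col, src, start, finish):
    """Hypothetical helper mimicking x:sub(1, x:size(1), col, col):copy(src:sub(start, finish)).

    All indices are 1-based and inclusive, as in Torch.
    """
    chunk = src[start - 1:finish]              # Torch sub(start, finish) -> Python slice
    if len(chunk) != len(x):
        raise ValueError("chunk length must equal the number of rows of x")
    for r, value in enumerate(chunk):
        x[r][col - 1] = value                  # overwrite column `col`, row by row
    return x

x_inp = list(range(11))
x = [[0, 0] for _ in range(5)]
copy_into_column(x, 1, x_inp, 1, 5)
copy_into_column(x, 2, x_inp, 7, 11)
print([row[0] for row in x])  # prints: [0, 1, 2, 3, 4]
print([row[1] for row in x])  # prints: [6, 7, 8, 9, 10]
```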

To summarize: you take an input tensor, cut it into batch_size chunks of floor(s / batch_size) rows each (leftover rows are dropped), and store the chunks in a tensor with one column per chunk.

So with x_inp

0
1
2
3
4
5
6
7
8
9
10

and batch_size = 2

x is

0 6
1 7
2 8
3 9
4 10
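Putting the pieces together, here is a plain-Python re-expression of replicate for the 11-element example. It is not the Torch code itself: nested lists replace tensors, and it again assumes that torch.round rounds halves away from zero.

```python
import math

def torch_round(v):
    """Round half away from zero (assumed to match torch.round on plain numbers)."""
    return math.floor(v + 0.5) if v >= 0 else -math.floor(-v + 0.5)

def replicate(x_inp, batch_size):
    """Plain-Python sketch of replicate: one column per chunk, floor(s / batch_size) rows."""
    s = len(x_inp)
    rows = math.floor(s / batch_size)
    x = [[0] * batch_size for _ in range(rows)]
    for i in range(1, batch_size + 1):
        start = torch_round((i - 1) * s / batch_size) + 1   # 1-based, as in the Lua code
        for r in range(rows):
            x[r][i - 1] = x_inp[start - 1 + r]              # copy the chunk into column i
    return x

for row in replicate(list(range(11)), 2):
    print(*row)
```

With a 10-element input (0 to 9), nothing is left over, and the second column would be 5 to 9 instead.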

Further:

local function testdataset(batch_size)
  local x = load_data(ptb_path .. "ptb.test.txt")
  x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)
  return x
end

This is another function, which loads some data from a file. Its x is not related to the x above, other than both being tensors.

Let's use a simple example:

x being

1
2
3
4

and batch_size = 4

x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)

First, x is resized to 4x1; see https://cornebise.com/torch-doc-template/tensor.html#toc_36

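Then it is expanded to 4x4 by repeating its single column, so every row holds four copies of its value. No new memory is allocated: expand returns a view on the same data with a stride of 0 along the expanded dimension.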

Resulting in x being the tensor

1 1 1 1
2 2 2 2
3 3 3 3
4 4 4 4

See https://cornebise.com/torch-doc-template/tensor.html#toc_49
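The stride-0 idea behind expand can be modeled in a few lines of Python. The sketch below is not Torch: `StridedView` is a hypothetical class that reads a flat list through (row, column) strides, the way a Torch tensor view reads its storage. Setting the column stride to 0 makes every column read the same element. Writing to the storage therefore shows up in all four columns at once.

```python
class StridedView:
    """Hypothetical 2-D view over a flat list using (row, column) strides."""

    def __init__(self, storage, rows, cols, row_stride, col_stride):
        self.storage = storage
        self.rows, self.cols = rows, cols
        self.strides = (row_stride, col_stride)

    def get(self, r, c):
        # 1-based indices, as in Torch
        return self.storage[(r - 1) * self.strides[0] + (c - 1) * self.strides[1]]

    def to_lists(self):
        return [[self.get(r, c) for c in range(1, self.cols + 1)]
                for r in range(1, self.rows + 1)]

storage = [1, 2, 3, 4]
column = StridedView(storage, 4, 1, 1, 1)     # like x:resize(4, 1)
expanded = StridedView(storage, 4, 4, 1, 0)   # like :expand(4, 4): column stride 0
print(column.to_lists())      # prints: [[1], [2], [3], [4]]
print(expanded.to_lists())    # prints: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]
storage[0] = 100              # one write to the shared storage...
print(expanded.to_lists()[0]) # prints: [100, 100, 100, 100]
```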
