Note: this repository was archived by the owner on Nov 2, 2018 and is now read-only.
We provide three different APIs for constructing a network with recurrent connections. The first two also expose the torch.cudnn RNN bindings, which makes it easy to run quick cudnn baselines.
1. The nn.{RNN, LSTM, GRU} interface constructs recurrent networks with the same number of hidden units in every layer (a sketch of this constructor appears right after this list).
2. The rnnlib.recurrentnetwork interface constructs recurrent networks of any shape. Both this interface and the previous one take care of hidden-state saving for you (see the same sketch after this list).
3. The nn.SequenceTable interface chains computations the way a 'scan' would, and the nn.RecurrentTable constructor is a lightweight wrapper that clones the recurrent module over time for you. Note that this is the lowest-level interface: you must call rnnlib.setupRecurrent(model, initializationfunctions) yourself to set up the recurrent hidden-state behaviour, as the nn.SequenceTable example below shows.
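For reference, here is a minimal sketch of the first two interfaces. The constructor arguments shown for nn.LSTM (inputsize, hidsize, nlayer) and the rnnlib.makeRecurrent helper with its cellfn and hids fields are assumptions written from memory, not a verified listing of the API; check the library source for the exact names.

    -- Interface 1: nn.{RNN, LSTM, GRU}; every layer has the same hidden size.
    -- The argument names below are assumptions, not a verified API listing.
    local rnnlib = require 'rnnlib'
    local lstm1  = nn.LSTM{
        inputsize = 256,  -- dimensionality of the input embeddings
        hidsize   = 512,  -- hidden units, shared by all layers
        nlayer    = 3,    -- number of stacked layers
    }

    -- Interface 2: rnnlib.recurrentnetwork; layers may have different sizes.
    -- The helper name (rnnlib.makeRecurrent) and its fields are likewise assumptions.
    local lstm2 = rnnlib.makeRecurrent{
        cellfn    = rnnlib.cell.LSTM,
        inputsize = 256,
        hids      = { 512, 512, 512 },
    }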
    local rnnlib = require 'rnnlib'

    -- The table of cells is fed to each level of the recurrent network to construct each layer.
    -- The table of initialization functions helps with the construction of the hidden inputs.
    local cells, initfunctions = {}, {}
    cells[1], initfunctions[1] = rnnlib.cell.LSTM(256, 512)
    cells[2], initfunctions[2] = rnnlib.cell.LSTM(512, 512)
    cells[3], initfunctions[3] = rnnlib.cell.LSTM(512, 512)

    local lstm = nn.SequenceTable{
        dim     = 1,
        modules = {
            nn.RecurrentTable{ dim = 2, module = rnnlib.cell.gModule(cells[1]) },
            nn.RecurrentTable{ dim = 2, module = rnnlib.cell.gModule(cells[2]) },
            nn.RecurrentTable{ dim = 2, module = rnnlib.cell.gModule(cells[3]) },
        },
    }

    rnnlib.setupRecurrent(lstm, initfunctions)
Train the model
All the modules in this library adhere to the nn.Container or nn.Module API.
Given a recurrent network constructed in one of the above ways, you can use a lookup table and linear layer to train it as follows:
    local mutils = require 'rnnlib.mutils'

    -- The vocabulary size.
    local vocabsize = 10000
    -- The dimensionality of the last hidden layer.
    local lasthid   = 512
    -- The dimensionality of the input embeddings.
    local insize    = 256
    -- The sequence length.
    local seqlen    = 32
    -- The batch size.
    local bsz       = 32

    -- The lookup table.
    local lut     = nn.LookupTable(vocabsize, insize)
    -- The decoder.
    local decoder = nn.Linear(lasthid, vocabsize)

    -- The full model.
    local model = nn.Sequential()
        :add(mutils.batchedinmodule(lstm, lut))
        :add(nn.SelectTable(2))
        :add(nn.SelectTable(-1))
        :add(nn.JoinTable(1))
        :add(decoder)
    model:cuda()

    -- This returns a flattened view of the output tensor.
    -- If you want this to be of a different shape, you can add an nn.View at the end.

    -- Generate the input to the model.
    local input = torch.range(1, seqlen * bsz)
        :resize(seqlen, bsz)
        :cudaLong()

    -- Create and initialize hiddens to zero.
    lstm:initializeHidden(bsz)

    -- Perform the forward pass.
    local output = model:forward{ lstm.hiddenbuffer, input }

    -- This is just an example, normally you would not use the output as the gradOutput.
    -- But the gradOutput should have the same shape as the model output.
    model:backward({ lstm.hiddenbuffer, input }, output)
    model:updateParameters(0.1)
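To make the backward pass more concrete, here is a hedged sketch of a full training step that uses a criterion instead of feeding the output back as its own gradOutput. The choice of nn.CrossEntropyCriterion and the randomly generated targets are illustrative assumptions on top of standard torch/nn; they are not part of rnnlib itself.

    -- Illustrative training step with a proper loss (sketch, not part of rnnlib).
    local crit = nn.CrossEntropyCriterion():cuda()

    -- Dummy targets: one class index in [1, vocabsize] per output row.
    -- After nn.JoinTable(1) and the decoder, the model output is (seqlen * bsz) x vocabsize.
    local target = torch.Tensor(seqlen * bsz):random(1, vocabsize):cuda()

    lstm:initializeHidden(bsz)
    model:zeroGradParameters()

    local output     = model:forward{ lstm.hiddenbuffer, input }
    local loss       = crit:forward(output, target)
    local gradOutput = crit:backward(output, target)
    print(('loss: %.4f'):format(loss))

    model:backward({ lstm.hiddenbuffer, input }, gradOutput)
    model:updateParameters(0.1)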
Benchmarks
We benchmark against the good work done in similar RNN libraries, rnn and torch-rnn, on the Penn Tree Bank dataset, using an LSTM language model with hidden dimensions of 256, 512, 1024, 2048, and 4096.
All models use a sequence length of 20, a batch size of 64, and 2 layers; timings are averaged over 1000 iterations.
Join the community
See the CONTRIBUTING file for how to help out.
License
torch-rnnlib is BSD-licensed. We also provide an additional patent grant.
About
This library provides utilities for creating and manipulating RNNs to model sequential data.