Here's some code I've been using to extract the last hidden states from an RNN with variable-length input. In the code example below:

- `lengths` is a list of length `batch_size` with the sequence lengths for each element in the batch. It's a list because `pack_padded_sequence` also takes a list, so you probably already have it lying around.
- `batch_first` is a boolean indicating whether the RNN is in `batch_first` mode or not.
- `output` is the output of a PyTorch RNN as a `Variable`. If your output isn't a `Variable` for some reason, just remove the `Variable` call around `idx` in the last line.
```python
# One index per example (its last valid timestep), repeated across the hidden dimension.
idx = (torch.LongTensor(lengths) - 1).view(-1, 1).expand(
    len(lengths), output.size(2))
time_dimension = 1 if batch_first else 0
# Add the time dimension so idx has as many dims as output.
idx = idx.unsqueeze(time_dimension)
if output.is_cuda:
    idx = idx.cuda(output.data.get_device())
# Shape: (batch_size, rnn_hidden_dim)
last_output = output.gather(
    time_dimension, Variable(idx)).squeeze(time_dimension)
```
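Since PyTorch 0.4, `Variable` and `Tensor` have been merged, so on current versions the `Variable` wrapper isn't needed. A minimal sketch of the same `gather` trick as a reusable function, assuming a padded `output` like the one `pad_packed_sequence` returns; the function name `last_timestep` is hypothetical, and `.to(output.device)` stands in for the `is_cuda` check:

```python
import torch


def last_timestep(output, lengths, batch_first=True):
    """Return output[b, lengths[b] - 1] for every example b (hypothetical helper).

    output:  padded RNN output, (batch, seq_len, hidden) if batch_first,
             else (seq_len, batch, hidden).
    lengths: list or 1-D tensor with the true length of each example.
    """
    time_dim = 1 if batch_first else 0
    # One index per example, repeated across the hidden dimension.
    idx = (torch.as_tensor(lengths, dtype=torch.long) - 1).view(-1, 1)
    idx = idx.expand(len(lengths), output.size(2))
    # Insert the time dimension and move the index to output's device.
    idx = idx.unsqueeze(time_dim).to(output.device)
    return output.gather(time_dim, idx).squeeze(time_dim)


# Usage on a toy (batch=2, seq_len=3, hidden=4) tensor with lengths 3 and 1.
out = torch.arange(24, dtype=torch.float).view(2, 3, 4)
print(last_timestep(out, [3, 1]))  # rows [8., 9., 10., 11.] and [12., 13., 14., 15.]
```

The same call works for time-major output by passing `batch_first=False`.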
Here's a full code example with an RNN and variable-length input, adapted from an example on the PyTorch forums:
```python
import torch
from torch.autograd import Variable
import torch.nn as nn

batch_size = 4
max_length = 3
hidden_size = 2
n_layers = 1
input_dim = 1
batch_first = True

# Data
vec_1 = torch.FloatTensor([[1, 2, 3]])
vec_2 = torch.FloatTensor([[1, 2, 0]])
vec_3 = torch.FloatTensor([[1, 0, 0]])
vec_4 = torch.FloatTensor([[2, 0, 0]])

# Put the data into a tensor (reshape each row to (max_length, input_dim)).
batch_in = torch.zeros((batch_size, max_length, input_dim))
batch_in[0] = vec_1.view(max_length, input_dim)
batch_in[1] = vec_2.view(max_length, input_dim)
batch_in[2] = vec_3.view(max_length, input_dim)
batch_in[3] = vec_4.view(max_length, input_dim)

# Wrap RNN input in a Variable. Shape: (batch_size, max_length, input_dim)
batch_in = Variable(batch_in)

# The lengths of each example in the batch. Padding is 0.
lengths = [3, 2, 1, 1]

# Wrap input in packed sequence, with batch_first=True
packed_input = torch.nn.utils.rnn.pack_padded_sequence(
    batch_in, lengths, batch_first=True)

# Create an RNN object, set batch_first=True
rnn = nn.RNN(input_dim, hidden_size, n_layers, batch_first=True)

# Run input through RNN
packed_output, _ = rnn(packed_input)

# Unpack, with batch_first=True.
output, _ = torch.nn.utils.rnn.pad_packed_sequence(
    packed_output, batch_first=True)
print("Unpacked, padded output: ")
print(output)

# Extract the outputs for the last timestep of each example
idx = (torch.LongTensor(lengths) - 1).view(-1, 1).expand(
    len(lengths), output.size(2))
time_dimension = 1 if batch_first else 0
idx = idx.unsqueeze(time_dimension)
if output.is_cuda:
    idx = idx.cuda(output.data.get_device())
# Shape: (batch_size, rnn_hidden_dim)
last_output = output.gather(
    time_dimension, Variable(idx)).squeeze(time_dimension)
print("Last output: ")
print(last_output)
```
And the output:
```
Unpacked, padded output:
Variable containing:
(0 ,.,.) =
 -0.0279  0.8709
  0.7806  0.7903
  0.5799  0.9227

(1 ,.,.) =
  0.7244  0.7105
  0.5795  0.8988
  0.0000  0.0000

(2 ,.,.) =
 -0.7699  0.9169
  0.0000  0.0000
  0.0000  0.0000

(3 ,.,.) =
  0.4918  0.4545
  0.0000  0.0000
  0.0000  0.0000
[torch.FloatTensor of size 4x3x2]

Last output:
Variable containing:
 0.5799  0.9227
 0.5795  0.8988
-0.7699  0.9169
 0.4918  0.4545
[torch.FloatTensor of size 4x2]
```
As you can see, the code successfully extracted the last-timestep outputs for each example in the batch.
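One way to double-check the result: when a unidirectional RNN is fed a `PackedSequence`, the `h_n` it returns already holds each sequence's hidden state at that sequence's true last timestep, so the gathered outputs should match `h_n[-1]` (the top layer). This doesn't carry over to bidirectional RNNs, where the backward direction's final state sits at timestep 0. A sketch on random data (seed and sizes are arbitrary), written for PyTorch 0.4+ without `Variable`:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
batch_first = True
# Sorted in decreasing order, which pack_padded_sequence expects by default.
lengths = [3, 2, 1, 1]
# (batch, max_length, input_dim); values past each length are ignored once packed.
batch_in = torch.randn(4, 3, 1)

rnn = nn.RNN(input_size=1, hidden_size=2, num_layers=1, batch_first=batch_first)
packed_output, h_n = rnn(pack_padded_sequence(batch_in, lengths, batch_first=batch_first))
output, _ = pad_packed_sequence(packed_output, batch_first=batch_first)

# gather-based extraction, as in the answer
time_dimension = 1 if batch_first else 0
idx = (torch.LongTensor(lengths) - 1).view(-1, 1).expand(len(lengths), output.size(2))
idx = idx.unsqueeze(time_dimension)
last_output = output.gather(time_dimension, idx).squeeze(time_dimension)

# h_n is (num_layers * num_directions, batch, hidden); h_n[-1] is the top layer.
print(torch.allclose(last_output, h_n[-1]))  # True
```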
Some more context for those who might not be super familiar with PyTorch
PyTorch RNNs return a tuple of `(output, h_n)`:

- `output` contains the hidden state of the last RNN layer at every timestep. The slice at the last timestep is usually what you want to pass downstream for sequence prediction tasks.
- `h_n` is the hidden state for `t = seq_len` (for all RNN layers and directions).
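To make the difference concrete, here is a quick sketch with a 2-layer RNN on equal-length input (no padding); the sizes are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch_size, input_dim, hidden_size, num_layers = 5, 3, 4, 6, 2

rnn = nn.RNN(input_dim, hidden_size, num_layers)  # batch_first=False by default
x = torch.randn(seq_len, batch_size, input_dim)
output, h_n = rnn(x)

print(output.shape)  # torch.Size([5, 3, 6]): top layer, every timestep
print(h_n.shape)     # torch.Size([2, 3, 6]): every layer, last timestep only
# Without padding, the top layer's final hidden state is simply the last row of output.
print(torch.allclose(output[-1], h_n[-1]))  # True
```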
`output` is a tensor of shape `(seq_len, batch_size, hidden_size * num_directions)` if `batch_first=False` in the RNN, and it's a tensor of shape `(batch_size, seq_len, hidden_size * num_directions)` if `batch_first=True`.
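A small shape check for both layouts, using a bidirectional RNN so that `num_directions = 2` (sizes are arbitrary). Note that `batch_first` only changes the layout of the input and `output`; `h_n` keeps its `(num_layers * num_directions, batch, hidden_size)` shape either way:

```python
import torch
import torch.nn as nn

seq_len, batch_size, input_dim, hidden_size = 5, 3, 4, 6
x_time_major = torch.randn(seq_len, batch_size, input_dim)

# batch_first=False (default): output is (seq_len, batch, hidden_size * 2)
rnn = nn.RNN(input_dim, hidden_size, bidirectional=True)
output, _ = rnn(x_time_major)
print(output.shape)  # torch.Size([5, 3, 12])

# batch_first=True: input and output swap their first two dimensions
rnn_bf = nn.RNN(input_dim, hidden_size, bidirectional=True, batch_first=True)
output_bf, h_n_bf = rnn_bf(x_time_major.transpose(0, 1))
print(output_bf.shape)  # torch.Size([3, 5, 12])
print(h_n_bf.shape)     # torch.Size([2, 3, 6]), unaffected by batch_first
```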
If you're using an RNN with variable-length input (made possible with a `PackedSequence`), `seq_len` refers to the longest sequence in the `PackedSequence`. In this case, you often need to extract the output features for each example in the batch at that example's own last timestep (index `length - 1`, where `length` is the length of that particular sequence). Note that it doesn't work to simply use `output[-1]` (or `output[:, -1]` if `batch_first=True`), since outputs beyond an example's length are padded with 0.
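A toy illustration of why the naive slice fails, on a hand-written padded `batch_first` output (the values are made up). It also shows advanced indexing with one batch index and one time index per example, an alternative to `gather` that is not part of the approach above:

```python
import torch

# Padded batch_first output, as pad_packed_sequence would return it:
# example 0 has length 3, example 1 has length 1, so output[1, 1:] is padding.
output = torch.tensor([[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                       [[0.7, 0.8], [0.0, 0.0], [0.0, 0.0]]])
lengths = [3, 1]

naive = output[:, -1]
print(naive)  # example 1 comes out as the padding [0., 0.]

# Alternative to gather: pair each batch index with its last valid timestep.
# (For time-major output, index as output[time_idx, batch_idx] instead.)
batch_idx = torch.arange(output.size(0))
time_idx = torch.tensor(lengths) - 1
last = output[batch_idx, time_idx]
print(last)  # rows [0.5, 0.6] and [0.7, 0.8]
```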