Here's some code I've been using to extract the last hidden states from an RNN with variable length input. In the code example below:

  • lengths is a list of length batch_size with the sequence lengths for each element in the batch. It's a list because pack_padded_sequence also takes a list, so you probably already have it lying around.
  • batch_first is a boolean indicating whether the RNN is in batch_first mode or not.
  • output is the output of a PyTorch RNN as a Variable. If your output isn't a Variable for some reason, just remove the Variable() wrapper around idx in the last line.
idx = (torch.LongTensor(lengths) - 1).view(-1, 1).expand(
    len(lengths), output.size(2))
time_dimension = 1 if batch_first else 0
idx = idx.unsqueeze(time_dimension)
if output.is_cuda:
    idx = idx.cuda(output.data.get_device())
# Shape: (batch_size, rnn_hidden_dim)
last_output = output.gather(
    time_dimension, Variable(idx)).squeeze(time_dimension)
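
To make the indexing concrete: idx ends up holding lengths[b] - 1 in every hidden-dimension slot for each example b, so gather along the time dimension picks out output[b, lengths[b] - 1, :]. A minimal sketch, using the lengths from the full example below and assuming hidden_size = 2:

import torch

lengths = [3, 2, 1, 1]
hidden_size = 2

idx = (torch.LongTensor(lengths) - 1).view(-1, 1).expand(
    len(lengths), hidden_size)
print(idx)
# 2  2
# 1  1
# 0  0
# 0  0
# [torch.LongTensor of size 4x2]
# After unsqueeze, idx has shape (4, 1, 2) for batch_first=True, and
# gather along dim 1 selects output[b, lengths[b] - 1, h] for each
# example b and hidden unit h; squeeze then drops the time dimension.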

Here's a full code example with an RNN and variable-length input, adapted from an example on the PyTorch forums:

import torch
from torch.autograd import Variable
import torch.nn as nn

batch_size = 4
max_length = 3
hidden_size = 2
n_layers = 1
input_dim = 1
batch_first = True

# Data
vec_1 = torch.FloatTensor([[1, 2, 3]])
vec_2 = torch.FloatTensor([[1, 2, 0]])
vec_3 = torch.FloatTensor([[1, 0, 0]])
vec_4 = torch.FloatTensor([[2, 0, 0]])

# Put the data into a tensor.
batch_in = torch.zeros((batch_size, max_length, input_dim))
batch_in[0] = vec_1
batch_in[1] = vec_2
batch_in[2] = vec_3
batch_in[3] = vec_4

# Wrap RNN input in a Variable. Shape: (batch_size, max_length, input_dim)
batch_in = Variable(batch_in)
# The lengths of each example in the batch. Padding is 0.
lengths = [3, 2, 1, 1]

# Wrap input in packed sequence, with batch_first=True
packed_input = torch.nn.utils.rnn.pack_padded_sequence(
    batch_in, lengths, batch_first=True)

# Create an RNN object, set batch_first=True
rnn = nn.RNN(input_dim, hidden_size, n_layers, batch_first=True) 

# Run input through RNN 
packed_output, _ = rnn(packed_input)

# Unpack, with batch_first=True.
output, _ = torch.nn.utils.rnn.pad_packed_sequence(
    packed_output, batch_first=True)
print("Unpacked, padded output: ")
print(output)

# Extract the outputs for the last timestep of each example
idx = (torch.LongTensor(lengths) - 1).view(-1, 1).expand(
    len(lengths), output.size(2))
time_dimension = 1 if batch_first else 0
idx = idx.unsqueeze(time_dimension)
if output.is_cuda:
    idx = idx.cuda(output.data.get_device())
# Shape: (batch_size, rnn_hidden_dim)
last_output = output.gather(
    time_dimension, Variable(idx)).squeeze(time_dimension)
print("Last output: ")
print(last_output)

and the output:

Unpacked, padded output:
Variable containing:
(0 ,.,.) =
 -0.0279  0.8709
  0.7806  0.7903
  0.5799  0.9227

(1 ,.,.) =
  0.7244  0.7105
  0.5795  0.8988
  0.0000  0.0000

(2 ,.,.) =
 -0.7699  0.9169
  0.0000  0.0000
  0.0000  0.0000

(3 ,.,.) =
  0.4918  0.4545
  0.0000  0.0000
  0.0000  0.0000
[torch.FloatTensor of size 4x3x2]

Last output:
Variable containing:
 0.5799  0.9227
 0.5795  0.8988
-0.7699  0.9169
 0.4918  0.4545
[torch.FloatTensor of size 4x2]

As you can see, the code successfully extracted the last-timestep outputs for each example in the batch.

Some more context for those who might not be super familiar with PyTorch:

PyTorch RNNs return a tuple of (output, h_n):

  • output contains the hidden states of the last RNN layer at every timestep; this is usually what you want to pass downstream for sequence prediction tasks.
  • h_n is the hidden state for t=seq_len, for all RNN layers and directions (see the sketch after this list).
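
A minimal sketch of taking that tuple apart, reusing rnn and packed_input from the example above (the h_n layout and the view trick are per the PyTorch docs; h_n_by_layer is just an illustrative name):

# packed_output holds the last layer's hidden state at every timestep,
# still packed; unpack it with pad_packed_sequence as shown earlier.
packed_output, h_n = rnn(packed_input)

# h_n has shape (num_layers * num_directions, batch_size, hidden_size).
# Layers and directions can be separated with a view; num_directions
# is 1 for the unidirectional RNN above.
h_n_by_layer = h_n.view(n_layers, 1, batch_size, hidden_size)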

output is a tensor of shape (seq_len, batch_size, hidden_size * num_directions) if batch_first=False in the RNN, and of shape (batch_size, seq_len, hidden_size * num_directions) if batch_first=True.
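
In the example above, batch_first=True, so the unpacked output has shape (batch_size, max_length, hidden_size):

# batch_size=4, max_length=3, hidden_size=2, num_directions=1
print(output.size())  # torch.Size([4, 3, 2])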

If you're using an RNN with variable-length input (made possible with a PackedSequence), seq_len refers to the longest sequence in the PackedSequence. In this case, you often need to extract the output features for each example in the batch at its last timestep (where the "last timestep" is the length of the sequence for that particular example). Note that it doesn't work to simply use output[-1] (or output[:, -1] if batch_first=True), since outputs beyond an example's length are padded with 0, as the check below shows.
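
A quick check against the example output above (output and lengths are from the full example; this is a sketch, not part of the original script):

# Naively slicing the last timestep:
naive_last = output[:, -1]  # batch_first=True
print(naive_last)
# Only the first row is a real output (that example has length 3);
# the rows for the length-2 and length-1 examples are padding zeros:
#  0.5799  0.9227
#  0.0000  0.0000
#  0.0000  0.0000
#  0.0000  0.0000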