Advanced masking, filling, selecting, and indexing PyTorch tensors

Ke Alexander Wang



import torch

Masking

x = torch.arange(6).reshape(3, 2)
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
x
tensor([[0, 1],
        [2, 3],
        [4, 5]])

Normal masking

x * mask
tensor([[0, 0],
        [0, 3],
        [4, 0]])
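
In practice the mask often comes from an elementwise comparison rather than being written out by hand. A minimal sketch (the threshold below is an illustrative choice, not from the original example):

auto_mask = x > 2   # boolean mask with the same shape as x; hypothetical threshold
x * auto_mask       # zeros out entries where the comparison is False
# expected: tensor([[0, 0],
#                   [0, 3],
#                   [4, 5]])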

Filling

x = torch.arange(6).reshape(3, 2)

Masking and filling in with a single value

This is often used for padding tensors, with FILL_VALUE as the padding value (a sketch of that use follows the example below).

FILL_VALUE = -1
x.masked_fill(mask, FILL_VALUE)
tensor([[ 0,  1],
        [ 2, -1],
        [-1,  5]])
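
Here is a rough sketch of that padding use case: build a padding mask from per-sequence lengths, then fill the padded positions. The names seqs, lengths, and pad_mask are hypothetical, not part of the original example.

seqs = torch.arange(8).reshape(2, 4)                # pretend batch of 2 sequences, max length 4
lengths = torch.tensor([3, 1])                      # hypothetical valid length of each sequence
pad_mask = torch.arange(4) >= lengths.unsqueeze(1)  # True at positions past each sequence's end
seqs.masked_fill(pad_mask, FILL_VALUE)
# expected: tensor([[ 0,  1,  2, -1],
#                   [ 4, -1, -1, -1]])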

Fill with a single value along a dimension

idx = torch.tensor([0, 2])
dim = 0
x.index_fill(dim=dim, index=idx, value=FILL_VALUE)
tensor([[-1, -1],
        [ 2,  3],
        [-1, -1]])
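
For comparison, a sketch of the same fill via fancy indexing on a copy (assuming we want to leave x untouched, since assigning through an index is in-place):

y = x.clone()
y[idx] = FILL_VALUE   # fills rows 0 and 2, since dim 0 is the indexed dimension
y
# expected: tensor([[-1, -1],
#                   [ 2,  3],
#                   [-1, -1]])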

Putting multiple values into the original tensor x at masked entries

mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
values = torch.tensor([-10, -20])

x[mask] = values
x
tensor([[  0,   1],
        [  2, -10],
        [-20,   5]])

Note that x[mask] = values is in-place! Behind the scenes, this converts the mask into a tuple of indexing tensors and then calls index_put_.
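
A sketch of that equivalence (not the actual implementation): build the tuple of index tensors from the mask, then call index_put_ with the values.

y = torch.arange(6).reshape(3, 2)     # fresh copy, since x was modified in-place above
idx_tuple = tuple(mask.nonzero().T)   # one index tensor per dimension
y.index_put_(idx_tuple, values)
y
# expected: tensor([[  0,   1],
#                   [  2, -10],
#                   [-20,   5]])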

Filling a mask with a fixed value using scatter

mask = torch.zeros(2, 3)
idxs = torch.tensor([
    [0, 2],
    [1, 1]
])
dim = 1
value = 1.0
mask.scatter(dim, idxs, value)
tensor([[1., 0., 1.],
        [0., 1., 0.]])
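
A common use of this pattern is building one-hot vectors: scatter the value 1.0 into the column given by each label. A minimal sketch with made-up class labels:

labels = torch.tensor([2, 0, 1])   # hypothetical class labels
torch.zeros(3, 4).scatter(1, labels.unsqueeze(1), 1.0)
# expected: tensor([[0., 0., 1., 0.],
#                   [1., 0., 0., 0.],
#                   [0., 1., 0., 0.]])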

Filling a tensor with different values using scatter

idxs and src should have the same shape. idxs and target should have the same size in every dimension except dim.

target = torch.zeros(2, 3)
idxs = torch.tensor([
    [0, 2],
    [1, 1]
])
dim = 1
src = torch.tensor([
    [10., 30.],
    [20., 20.],
])
target.scatter(dim, idxs, src)
tensor([[10.,  0., 30.],
        [ 0., 20.,  0.]])

Filling a tensor with masked scatter (scatter into places where mask is true)

Note that source is usually a 1d tensor, but it can have more values than the number of true entries in mask. If source is not a 1d tensor, masked_scatter will flatten source first.

x = torch.arange(6).reshape(3, 2)
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()

source = torch.tensor([
    10, 20, 30, 40, 50
])

x.masked_scatter(mask, source)
tensor([[ 0,  1],
        [ 2, 10],
        [20,  5]])
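
Roughly speaking (a sketch of the behavior, not the implementation), the first mask.sum() entries of the flattened source get assigned to the True positions in row-major order:

y = x.clone()
n_true = int(mask.sum())      # number of True entries in the mask
y[mask] = source[:n_true]     # only the first n_true source values are used
y
# expected: tensor([[ 0,  1],
#                   [ 2, 10],
#                   [20,  5]])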

Selecting and indexing

We will keep working with the same example tensor:

x = torch.arange(6).reshape(3, 2)
x
tensor([[0, 1],
        [2, 3],
        [4, 5]])

Select along a particular dimension

Each entry of idxs corresponds to an entry along dim

dim = 0
idxs = torch.tensor([0, 2, 0, 0])
x.index_select(dim=dim, index=idxs)
tensor([[0, 1],
        [4, 5],
        [0, 1],
        [0, 1]])

Select along a particular dimension (equivalent, via fancy indexing)

Each entry of idxs corresponds to an entry along dim

idxs = torch.tensor([0, 2, 0, 0])
x[idxs]
tensor([[0, 1],
        [4, 5],
        [0, 1],
        [0, 1]])

Select along a particular dimension (equivalent, via fancy indexing, but now along columns)

Each entry of idxs corresponds to an entry along dim

idxs = torch.tensor([1, 0, 0])
x[:, idxs]
tensor([[1, 0, 0],
        [3, 2, 2],
        [5, 4, 4]])

Selecting according to entries of a mask

torch.masked_select has the following signature type:

torch.masked_select(
    input: TensorType[D1, D2, D3, ..., DN],
    mask: TensorType[D1, D2, D3, ..., DN]
) -> out: TensorType[D]

out will contain the entries of input where mask is true.

mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
x.masked_select(mask)
tensor([3, 4])

Selecting according to a mask but using fancy indexing

Use masked_select if you just want a flat list of entries. If you want to keep the shape of the original tensor x, use normal masking instead (see the sketch below).

mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
x[mask]
tensor([3, 4])
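
For contrast, a small sketch of the shape-preserving alternatives mentioned above; masked-off entries become 0 here, but torch.where lets you substitute any placeholder:

x * mask                                    # keeps the 3 x 2 shape, zeros where mask is False
torch.where(mask, x, torch.zeros_like(x))   # same result, with an explicit placeholder tensor
# both expected: tensor([[0, 0],
#                        [0, 3],
#                        [4, 0]])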

Selecting so that out[i, j] = input[rows[i, j], cols[i, j]]

out.shape will be the same as rows.shape. Also, rows.shape must be equal to cols.shape (or at least broadcastable with it; see below).

rows = torch.tensor([
    [2, 2, 2],
    [0, 0, 0],
])
cols = torch.tensor([
    [0, 1, 0],
    [0, 1, 0],
])
x[rows, cols]
tensor([[4, 5, 4],
        [0, 1, 0]])
rows = torch.tensor([2, 0])
cols = torch.tensor([0, 1, 0])
x[rows.unsqueeze(1), cols.unsqueeze(0)]
tensor([[4, 5, 4],
        [0, 1, 0]])

Note that x[rows.unsqueeze(1), cols.unsqueeze(0)] broadcasts rows and cols to be the same shape as each other
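
If it helps to see the broadcasting explicitly, torch.broadcast_tensors materializes the shapes the two index tensors take before indexing (a sketch):

r, c = torch.broadcast_tensors(rows.unsqueeze(1), cols.unsqueeze(0))
r.shape, c.shape   # both become (2, 3)
x[r, c]            # same result as x[rows.unsqueeze(1), cols.unsqueeze(0)]
# expected: tensor([[4, 5, 4],
#                   [0, 1, 0]])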

Convert a mask into indices (list of nonzero entries)

mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
nonzeros = mask.nonzero()
print(nonzeros)

# then index using `nonzeros` by turning it into a tuple of index tensors
pair = tuple(nonzeros.T)
print(pair)
x[pair]
tensor([[1, 1],
        [2, 0]])
(tensor([1, 2]), tensor([1, 0]))
tensor([3, 4])
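
The manual tuple(nonzeros.T) step can be skipped: nonzero(as_tuple=True) returns the tuple of index tensors directly (a small sketch):

pair = mask.nonzero(as_tuple=True)   # (tensor([1, 2]), tensor([1, 0]))
x[pair]
# expected: tensor([3, 4])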

Selecting with a single Long tensor

Each value in tsr will be read as an index along dim 0 of x, so that the resulting shape is (*tsr.shape, *x.shape[1:]).

tsr = torch.tensor([
    [1, 1],
    [2, 0],
])
x[tsr]
tensor([[[2, 3],
         [2, 3]],

        [[4, 5],
         [0, 1]]])

Selecting with a single Long tensor (equivalent, but don't do this!)

x[tsr.flatten()].reshape(*tsr.shape, x.shape[1])
tensor([[[2, 3],
         [2, 3]],

        [[4, 5],
         [0, 1]]])

Indexing with slices : and ellipses ...

Note that PyTorch can be slower when indexing with Python lists than with index tensors (see the timings at the end).

Indexing with slicing

idxs = [1, 0, 1, 0, 0]
x[:, idxs]
tensor([[1, 0, 1, 0, 0],
        [3, 2, 3, 2, 2],
        [5, 4, 5, 4, 4]])

This is equivalent to indexing with a tuple of indexing arrays that get broadcasted. Think of slicing as equivalent to indexing with torch.arange:

x[torch.arange(x.size(0)).unsqueeze(1), torch.tensor(idxs).unsqueeze(0)]
tensor([[1, 0, 1, 0, 0],
        [3, 2, 3, 2, 2],
        [5, 4, 5, 4, 4]])

Indexing using gather

Think about gather as a multidimensional version of selecting with broadcasting a tuple of indexing arrays. Consider:

x.shape == (N, D)
x[torch.arange(N).unsqueeze(1), torch.tensor([1, 2]).unsqueeze(0)]

Here we gather along dim==1, resulting in an output shape of N x 2

Now suppose we had a B x T x D tensor representing a batch of RNN outputs, each with dimension D. Say we wanted to select only the last timestep of each sequence in the batch, and that we knew how long each sequence was. We can do the following:

tsr = torch.tensor([
    [[1., 2., 3., 4.],
     [5., 6., 7., 8.],
     [0., 0., 0., 0.],
    ],
    [[9., 8., 7., 6.],
     [0., 0., 0., 0.],
     [0., 0., 0., 0.],
    ],
]) # use 0 to indicate padding
N, T, D = tsr.shape

lengths = torch.tensor([2, 1])  # length of each of the N sequences

last_idxs = lengths - 1
index = last_idxs.unsqueeze(-1).expand(N, D).unsqueeze(-2)  # N x 1 x D
print(index)
tsr.gather(index=index, dim=1)
tensor([[[1, 1, 1, 1]],

        [[0, 0, 0, 0]]])
tensor([[[5., 6., 7., 8.]],

        [[9., 8., 7., 6.]]])

The output will have the same shape as index. index must have the same number of dimensions as tsr, and its size in every dimension other than dim cannot exceed tsr's size in that dimension. dim specifies the dimension along which the values in index select entries.
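
A quick sanity check of those shape rules on the example above (a sketch; the assertions just restate the constraints):

print(tsr.shape)     # torch.Size([2, 3, 4])
print(index.shape)   # torch.Size([2, 1, 4])
out = tsr.gather(index=index, dim=1)
assert out.shape == index.shape                                            # output matches index exactly
assert index.shape[0] <= tsr.shape[0] and index.shape[2] <= tsr.shape[2]   # sizes off dim cannot exceed tsr's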

Indexing timings

Let's see which of these equivalent gather methods is fastest!

def batch_gather(tensor, indices):
    # Loop over the batch and index each row separately, then re-stack.
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)
def batch_gather_vec(tensor, indices):
    # Flatten the first two dims, then offset each batch's indices into the flattened tensor.
    shape = list(tensor.shape)
    flat_first = torch.reshape(
        tensor, [shape[0] * shape[1]] + shape[2:])
    offset = torch.reshape(
        torch.arange(shape[0], device=tensor.device) * shape[1],
        [shape[0]] + [1] * (len(indices.shape) - 1))
    output = flat_first[indices + offset]
    return output
device = "cuda"
x = torch.randn(1000, 200, device=device)
idx = torch.randint(200, (1000, ), device=device)
idx_lst = idx.tolist()
%%timeit
batch_gather(x, idx).sum()
10.6 ms ± 36.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
batch_gather_vec(x, idx).sum()
51.4 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
x.gather(index=idx.unsqueeze(-1), dim=-1).squeeze(-1).sum()
19.6 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Clearly, the gather function built into PyTorch is the fastest.

How about the timing for summing along an axis/dimension?

%%timeit
res = 0
for i in range(len(x)):
    res = res + x[i]
6.95 ms ± 50.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
res = 0
for row in x:
    res = res + row
5 ms ± 39.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
res = 0
for row in x.unbind():
    res = res + row
4.99 ms ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
x.sum(0)
15 µs ± 17.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

What about indexing?

%%timeit
x[idx_lst]
54.3 µs ± 568 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
x.index_select(index=idx, dim=0)
7.03 µs ± 9.33 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
x[idx]
9.39 µs ± 62.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)