Advanced masking, filling, selecting, and indexing PyTorch tensors
import torch
Masking
x = torch.arange(6).reshape(3, 2)
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
x
tensor([[0, 1],
        [2, 3],
        [4, 5]])
Normal masking
x * mask
tensor([[0, 0],
        [0, 3],
        [4, 0]])
Filling
x = torch.arange(6).reshape(3, 2)
Masking and filling with a single value
This is often used for padding tensors, with FILL_VALUE as the padding value.
FILL_VALUE = -1
x.masked_fill(mask, FILL_VALUE)
tensor([[ 0,  1],
        [ 2, -1],
        [-1,  5]])
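As a concrete sketch of the padding use case (the lengths and seqs below are made-up toy values): build a boolean padding mask from per-sequence lengths, then fill the padded positions.
lengths = torch.tensor([3, 1])                      # hypothetical per-sequence lengths
seqs = torch.arange(8).reshape(2, 4)                # toy batch: 2 sequences, max length 4
pad_mask = torch.arange(4) >= lengths.unsqueeze(1)  # True exactly at padded positions
seqs.masked_fill(pad_mask, FILL_VALUE)
tensor([[ 0,  1,  2, -1],
        [ 4, -1, -1, -1]])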
Fill with a single value along a dimension
idx = torch.tensor([0, 2])
dim = 0
x.index_fill(dim=dim, index=idx, value=FILL_VALUE)
tensor([[-1, -1],
        [ 2,  3],
        [-1, -1]])
Put multiple values into entries of the original tensor
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
values = torch.tensor([-10, -20])
x[mask] = values
x
tensor([[  0,   1],
        [  2, -10],
        [-20,   5]])
Note that x[mask] = values is in-place! Behind the scenes, this converts the mask into a tuple of indexing tensors, then calls index_put_.
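A minimal sketch of that equivalence (the conversion via nonzero is illustrative; the actual internals may differ):
x2 = torch.arange(6).reshape(3, 2)
idx_tuple = tuple(mask.nonzero().T)  # mask -> tuple of indexing tensors
x2.index_put_(idx_tuple, values)     # same effect as x2[mask] = values
tensor([[  0,   1],
        [  2, -10],
        [-20,   5]])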
Filling a mask with a fixed value using scatter
mask = torch.zeros(2, 3)
idxs = torch.tensor([
    [0, 2],
    [1, 1],
])
dim = 1
value = 1.0
mask.scatter(dim, idxs, value)
tensor([[1., 0., 1.],
        [0., 1., 0.]])
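Note that scatter is out-of-place and leaves mask as zeros; the in-place variant scatter_ (trailing underscore) writes into the tensor directly, which is what you want when actually building the mask:
mask2 = torch.zeros(2, 3)
mask2.scatter_(dim, idxs, value)  # in-place; returns mask2
tensor([[1., 0., 1.],
        [0., 1., 0.]])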
Filling a tensor with different values using scatter
idxs and src should have the same shape.
idxs and target should have the same size in every dimension except along dim.
target = torch.zeros(2, 3)
idxs = torch.tensor([
    [0, 2],
    [1, 1],
])
dim = 1
src = torch.tensor([
    [10., 30.],
    [20., 20.],
])
target.scatter(dim, idxs, src)
tensor([[10.,  0., 30.],
        [ 0., 20.,  0.]])
Filling a tensor with masked scatter (scatter into places where mask is true)
Note that source should be just a 1d tensor, but it can have more values than the number of true entries in the mask. If source is not a 1d tensor, masked_scatter will flatten source first; a quick check of this follows the example below.
x = torch.arange(6).reshape(3, 2)
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
source = torch.tensor([
    10, 20, 30, 40, 50
])
x.masked_scatter(mask, source)
tensor([[ 0,  1],
        [ 2, 10],
        [20,  5]])
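A quick sketch checking the flattening behaviour described above, with a made-up 2d source_2d: both calls give the same result.
source_2d = torch.tensor([
    [10, 20],
    [30, 40],
    [50, 60],
])
assert torch.equal(
    x.masked_scatter(mask, source_2d),
    x.masked_scatter(mask, source_2d.flatten()),
)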
Selecting and indexing
Some terminology:
By "entry" we mean a position in the tensor, like position (i, j)
By "value" we mean the data in a particular entry, like tsr[i, j]
x = torch.arange(6).reshape(3, 2)
x
tensor([[0, 1],
        [2, 3],
        [4, 5]])
Select along a particular dimension
Each entry of idxs corresponds to an entry along dim.
dim = 0
idxs = torch.tensor([0, 2, 0, 0])
x.index_select(dim=dim, index=idxs)
tensor([[0, 1],
        [4, 5],
        [0, 1],
        [0, 1]])
Select along a particular dimension (equivalent, via fancy indexing)
Each entry of idxs corresponds to an entry along dim.
idxs = torch.tensor([0, 2, 0, 0])
x[idxs]
tensor([[0, 1],
        [4, 5],
        [0, 1],
        [0, 1]])
Select along a particular dimension (equivalent, via fancy indexing, but now along columns)
Each entry of idxs corresponds to an entry along dim.
idxs = torch.tensor([1, 0, 0])
x[:, idxs]
tensor([[1, 0, 0],
        [3, 2, 2],
        [5, 4, 4]])
Selecting according to entries of a mask
torch.masked_select has the following signature type:
torch.masked_select(
    input: TensorType[D1, D2, D3, ..., DN],
    mask: TensorType[D1, D2, D3, ..., DN]
) -> out: TensorType[D]
out will contain the entries of input where mask is true.
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
x.masked_select(mask)
tensor([3, 4])
Selecting according to a mask but using fancy indexing
Use masked select if you just want a flat list of entries. If you want to keep the shape of the original tensor x, you should use normal masking.
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
x[mask]
tensor([3, 4])
Selecting so that out[i, j] = input[rows[i, j], cols[i, j]]
out.shape will be the same as rows.shape. Also rows.shape must be equal to cols.shape (or at least broadcastable, see below).
rows = torch.tensor([
    [2, 2, 2],
    [0, 0, 0],
])
cols = torch.tensor([
    [0, 1, 0],
    [0, 1, 0],
])
x[rows, cols]
tensor([[4, 5, 4],
        [0, 1, 0]])
Selecting so that out[i, j] = input[rows[i], cols[j]] (recommended way using broadcasting)
rows = torch.tensor([2, 0])
cols = torch.tensor([0, 1, 0])
x[rows.unsqueeze(1), cols.unsqueeze(0)]
tensor([[4, 5, 4],
        [0, 1, 0]])
Note that x[rows.unsqueeze(1), cols.unsqueeze(0)] broadcasts rows and cols to be the same shape as each other.
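You can materialize the broadcasted index tensors with torch.broadcast_tensors to see exactly which entries get selected:
r, c = torch.broadcast_tensors(rows.unsqueeze(1), cols.unsqueeze(0))
print(r)
print(c)
x[r, c]
tensor([[2, 2, 2],
        [0, 0, 0]])
tensor([[0, 1, 0],
        [0, 1, 0]])
tensor([[4, 5, 4],
        [0, 1, 0]])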
Convert a mask into indices (list of nonzero entries)
mask = torch.tensor([
    [0, 0],
    [0, 1],
    [1, 0],
]).bool()
nonzeros = mask.nonzero()
print(nonzeros)
# then index using `nonzeros` by turning it into a pair of lists of indices
pair = tuple(nonzeros.T)
print(pair)
x[pair]
tensor([[1, 1],
        [2, 0]])
(tensor([1, 2]), tensor([1, 0]))
tensor([3, 4])
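nonzero can also produce the tuple form directly with as_tuple=True, skipping the manual transpose:
pair2 = mask.nonzero(as_tuple=True)
x[pair2]
tensor([3, 4])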
Selecting with a single Long tensor
Each value in tsr will be read as an entry along dim 0 of x, so that the resulting shape is (*tsr.shape, *x.shape[1:]).
tsr = torch.tensor([
    [1, 1],
    [2, 0],
])
x[tsr]
tensor([[[2, 3],
         [2, 3]],

        [[4, 5],
         [0, 1]]])
Selecting with a single Long tensor (equivalent, but don't do this!)
x[tsr.flatten()].reshape(*tsr.shape, x.shape[1])
tensor([[[2, 3],
         [2, 3]],

        [[4, 5],
         [0, 1]]])
Indexing with slices : and ellipses ...
Note that PyTorch may be slower with this kind of list-based indexing (see the timings at the end).
Indexing with slicing
idxs = [1, 0, 1, 0, 0]
x[:, idxs]
tensor([[1, 0, 1, 0, 0],
        [3, 2, 3, 2, 2],
        [5, 4, 5, 4, 4]])
This is equivalent to indexing with a tuple of indexing arrays that get broadcasted. Think about slicing as equivalent to indexing with torch.arange.
x[torch.arange(x.size(0)).unsqueeze(1), torch.tensor(idxs).unsqueeze(0)]
tensor([[1, 0, 1, 0, 0],
        [3, 2, 3, 2, 2],
        [5, 4, 5, 4, 4]])
Indexing using gather
Think about gather as a multidimensional version of selecting with a broadcasted tuple of indexing arrays. Consider:
x.shape == (N, D)
x[torch.arange(N).unsqueeze(1), torch.tensor([1, 2]).unsqueeze(0)]
Here we gather along dim == 1, resulting in an output shape of N x 2.
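A runnable sketch of this equivalence, with made-up sizes N = 3, D = 4:
N, D = 3, 4
x2 = torch.arange(N * D).reshape(N, D)
idx2 = torch.tensor([1, 2]).unsqueeze(0).expand(N, 2)  # same index row for every batch element
assert torch.equal(
    x2.gather(dim=1, index=idx2),
    x2[torch.arange(N).unsqueeze(1), torch.tensor([1, 2]).unsqueeze(0)],
)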
Now suppose we had a B x T x D tensor representing a batch of RNN outputs, each with hidden dimension D. Let's say we wanted to select only the last timestep of each sequence in the batch, and that we knew how long each sequence was. We can do the following:
tsr = torch.tensor([
    [[1., 2., 3., 4.],
     [5., 6., 7., 8.],
     [0., 0., 0., 0.]],
    [[9., 8., 7., 6.],
     [0., 0., 0., 0.],
     [0., 0., 0., 0.]],
])  # use 0 to indicate padding
N, T, D = tsr.shape
lengths = torch.tensor([2, 1])  # length of each of the N sequences
last_idxs = lengths - 1
index = last_idxs.unsqueeze(-1).expand(N, D).unsqueeze(-2)  # N x 1 x D
print(index)
tsr.gather(index=index, dim=1)
tensor([[[1, 1, 1, 1]],

        [[0, 0, 0, 0]]])
tensor([[[5., 6., 7., 8.]],

        [[9., 8., 7., 6.]]])
The output will have the same size as index. index must have the same number of dimensions as tsr, and the same size as tsr in every dimension except dim (strictly, index.size(d) <= tsr.size(d) is enough for d != dim). dim specifies the dimension along which the values in index do the indexing.
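Concretely, for this 3d case with dim=1 the rule is out[b][j][d] = tsr[b][index[b][j][d]][d]. A brute-force check on the example above, using only the tensors already defined:
out = tsr.gather(index=index, dim=1)
for b in range(N):
    for j in range(index.size(1)):
        for d in range(D):
            assert out[b, j, d] == tsr[b, index[b, j, d], d]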
Indexing timings
Let's see which of these equivalent gather methods are faster!
def batch_gather(tensor, indices):
    # naive version: loop over the batch dimension in Python
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)

def batch_gather_vec(tensor, indices):
    # vectorized version: flatten the first two dims, then shift each
    # batch's indices by a per-batch offset into the flattened tensor
    shape = list(tensor.shape)
    flat_first = torch.reshape(
        tensor, [shape[0] * shape[1]] + shape[2:])
    offset = torch.reshape(
        torch.arange(shape[0], device=tensor.device) * shape[1],
        [shape[0]] + [1] * (len(indices.shape) - 1))
    output = flat_first[indices + offset]
    return output
= "cuda"
device = torch.randn(1000, 200, device=device)
x = torch.randint(200, (1000, ), device=device)
idx = idx.tolist() idx_lst
%%timeit
batch_gather(x, idx).sum()
10.6 ms ± 36.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
batch_gather_vec(x, idx).sum()
51.4 µs ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
x.gather(index=idx.unsqueeze(-1), dim=-1).squeeze(-1).sum()
19.6 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Clearly the gather function built into pytorch is the fastest.
How about the timing for summing along an axis/dimension?
%%timeit
res = 0
for i in range(len(x)):
    res = res + x[i]
6.95 ms ± 50.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
res = 0
for row in x:
    res = res + row
5 ms ± 39.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
res = 0
for row in x.unbind():
    res = res + row
4.99 ms ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
x.sum(0)
15 µs ± 17.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
What about indexing?
%%timeit
x[idx_lst]
54.3 µs ± 568 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
x.index_select(index=idx, dim=0)
7.03 µs ± 9.33 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
x[idx]
9.39 µs ± 62.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)