[Torch API] Detailed explanation of the torch.where() function in PyTorch

Foreword

This article describes two usages of torch.where(). The first is the conventional three-argument form, torch.where(condition, x, y), described in the official documentation. The second passes a single bool tensor, such as the result of combining bool tensors with &, and returns the indices of its True elements.

1. Conventional usage of torch.where()

Let's first look at the official documentation:

torch.where(condition, x, y)
Returns a tensor whose elements are selected from x or y depending on condition. A new tensor is created: wherever condition is True the element comes from x, otherwise it comes from y. condition, x and y do not need to have the same shape, but they must be broadcastable, and the result has their common broadcast shape.
Parameters are explained as follows:
1. condition (bool tensor): where it is True, the value from x is returned; otherwise the value from y is returned
2. x (tensor or scalar): values selected where condition is True
3. y (tensor or scalar): values selected where condition is False
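To make the selection rule concrete, here is a minimal sketch that spells it out with explicit loops, assuming condition, x and y already have the same 2-D shape (no broadcasting). The helper where_reference is hypothetical and only for illustration; it is checked against the real torch.where.

```python
import torch

def where_reference(condition, x, y):
    """Hypothetical helper: element-wise selection for same-shape 2-D tensors."""
    assert condition.shape == x.shape == y.shape
    out = torch.empty_like(x)
    rows, cols = condition.shape
    for i in range(rows):
        for j in range(cols):
            # True -> take the element from x, False -> take it from y
            out[i, j] = x[i, j] if condition[i, j] else y[i, j]
    return out

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
cond = x > 5
print(torch.equal(where_reference(cond, x, y), torch.where(cond, x, y)))  # True
```

The loop is far slower than torch.where and is only meant to show the semantics.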

1.1 Same shape

First, the case where x and y have the same shape:

import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
z = torch.where(x > 5, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 5 = {x > 5}')
print(f'=========================')
print(f'z = {z}')

Output:
x = tensor([[1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]])
=========================
y = tensor([[ 5,  6,  7],
        [ 7,  8,  9],
        [ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
        [False, False, False],
        [False,  True,  True]])
=========================
z = tensor([[5, 6, 7],
        [7, 8, 9],
        [9, 6, 7]])

Here x and y both have shape (3, 3), and the condition x > 5 checks whether each element of x is greater than 5. Rows 0 and 1 of the condition are all False; in row 2 only columns 1 and 2 are True. Since True selects from x and False selects from y, the first two rows of z and element (2, 0) come from y, while the remaining two elements, (2, 1) and (2, 2), come from x. z also has shape (3, 3).
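One way to double-check this reading of the result is to index z with the boolean mask itself: the True positions must match x and the False positions must match y. A minimal sketch reusing the tensors above:

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
cond = x > 5
z = torch.where(cond, x, y)

# Positions where cond is True must hold x's values ...
print(torch.equal(z[cond], x[cond]))    # True
# ... and positions where cond is False must hold y's values.
print(torch.equal(z[~cond], y[~cond]))  # True
print(z[cond])  # tensor([6, 7])
```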
 

1.2 Scalar case

x = 3
y = torch.tensor([[1, 5, 7]])
z = torch.where(y > 2, y, x)

print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')

Output:
y > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[3, 5, 7]])

Here x is a scalar and the condition is y > 2. Why not use x > 2 as the condition? Because x is a Python int, x > 2 evaluates to a plain Python bool, not a bool tensor. The scalar x is broadcast against y, just as scalars are broadcast in ordinary tensor arithmetic.
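A small sketch of both points: x > 2 is just a Python bool, and passing the scalar gives the same values as passing a tensor of the same shape filled with that scalar. Passing a Python scalar as x or y relies on scalar support in torch.where, which recent PyTorch versions provide; on an older version the explicit torch.full_like form may be the safer choice.

```python
import torch

x = 3
y = torch.tensor([[1, 5, 7]])

# x > 2 is evaluated by Python, not PyTorch, so it is a plain bool.
print(type(x > 2))  # <class 'bool'>

# Passing the scalar directly ...
z_scalar = torch.where(y > 2, y, x)
# ... gives the same values as passing an explicit tensor filled with x.
z_tensor = torch.where(y > 2, y, torch.full_like(y, x))
print(torch.equal(z_scalar, z_tensor))  # True
```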

For example, adding a scalar to a tensor:

a = torch.tensor([1, 5, 7])
b = 3
c = a + b
d = torch.tensor([3, 3, 3])
e = a + d

print(f'c = {c}')
print(f'e = {e}')

Output:
c = tensor([ 4,  8, 10])
e = tensor([ 4,  8, 10])

In effect, the scalar b = 3 is broadcast (expanded) to [3, 3, 3], which is exactly d, so c and e are equal.
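The implicit expansion can also be done explicitly with torch.broadcast_to, which makes the "3 becomes [3, 3, 3]" step visible. A minimal sketch:

```python
import torch

a = torch.tensor([1, 5, 7])
b = 3

# Expand the scalar explicitly to a's shape; broadcasting does this implicitly.
b_expanded = torch.broadcast_to(torch.tensor(b), a.shape)
print(b_expanded)                          # tensor([3, 3, 3])
print(torch.equal(a + b, a + b_expanded))  # True
```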

1.3 Different shapes

Strictly speaking, the scalar case above already involves operands of different shapes. Here is an example with two tensors of different shapes:

x = torch.tensor([[1, 3, 5]])
y = torch.tensor([[2], [4], [6]])
z = torch.where(x > 2, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 2 = {x > 2}')
print(f'=========================')
print(f'z = {z}')

Output:
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
        [4],
        [6]])
=========================
x > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[2, 3, 5],
        [4, 3, 5],
        [6, 3, 5]])

Here x.shape = (1, 3), y.shape = (3, 1), and the condition x > 2 has shape (1, 3). These shapes are broadcastable, so torch.where(x > 2, x, y) succeeds: condition, x and y are each broadcast to shape (3, 3) before the element-wise selection, and z has shape (3, 3).

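After broadcasting, every row of the condition is [False, True, True], so column 0 of z takes its values from y (2, 4 and 6 down the rows), and columns 1 and 2 take theirs from x (3 and 5).
Readers are encouraged to try other broadcasting combinations on their own.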
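To see the broadcast operands directly, torch.broadcast_tensors expands condition, x and y to their common shape. Running torch.where on these explicitly expanded tensors should give the same z, which is a quick way to confirm the explanation above. A minimal sketch:

```python
import torch

x = torch.tensor([[1, 3, 5]])      # shape (1, 3)
y = torch.tensor([[2], [4], [6]])  # shape (3, 1)
cond = x > 2                       # shape (1, 3)

# Broadcast all three operands to a common shape explicitly.
cond_b, x_b, y_b = torch.broadcast_tensors(cond, x, y)
print(cond_b.shape, x_b.shape, y_b.shape)
# torch.Size([3, 3]) torch.Size([3, 3]) torch.Size([3, 3])
print(x_b)  # every row is [1, 3, 5]
print(y_b)  # every column is [2, 4, 6]

# Selecting on the explicitly broadcast tensors gives the same z.
print(torch.equal(torch.where(cond_b, x_b, y_b), torch.where(cond, x, y)))  # True
```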

2. Special usage of torch.where()

torch.where(a & b)
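Here a and b are both bool tensors, and a & b is their element-wise logical AND. Called with a single argument, torch.where(condition) returns a tuple of index tensors, one per dimension, giving the positions where condition is True; this is the same as torch.nonzero(condition, as_tuple=True). For 2-D inputs, the first item holds the row indices and the second the column indices of the positions where both a and b are True.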
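A short sketch of the one-argument form on masks built from comparisons, reusing x from section 1.1. It checks the equivalence with torch.nonzero(..., as_tuple=True), shows that a 1-D mask yields a one-element tuple, and confirms that & on bool tensors matches torch.logical_and. The masks m1 and m2 are illustrative choices, not from the original example.

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])

# One-argument form on a 2-D mask: a tuple of (row indices, column indices).
rows, cols = torch.where(x > 5)
print(rows, cols)  # tensor([2, 2]) tensor([1, 2])

# Same result as torch.nonzero(..., as_tuple=True).
nz_rows, nz_cols = torch.nonzero(x > 5, as_tuple=True)
print(torch.equal(rows, nz_rows) and torch.equal(cols, nz_cols))  # True

# On a 1-D mask the tuple has a single element.
print(torch.where(torch.tensor([True, False, True])))  # (tensor([0, 2]),)

# For bool tensors, & is element-wise logical AND.
m1 = x > 2
m2 = x < 6
print(torch.equal(m1 & m2, torch.logical_and(m1, m2)))  # True
```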

Here is an example:

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)

print(f'a = {a}')
print(f'=========================')
print(f'b = {b}')
print(f'=========================')
print(f'c = {c}')

Output:
a = tensor([[False,  True,  True],
        [ True, False, False],
        [False, False,  True]])
=========================
b = tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))

c is a tuple: item 0 holds the row indices and item 1 holds the column indices of the positions where both a and b are True. Read pairwise, they give the coordinates (0, 1), (0, 2), (1, 0) and (2, 2).
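Two common ways to use such a tuple are sketched below: zipping it into (row, col) pairs, and passing it straight back as an index to pick out elements of another tensor of the same shape. The tensor values is a hypothetical example, not from the original post.

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)

# Pair row and column indices into (row, col) coordinates.
coords = list(zip(c[0].tolist(), c[1].tolist()))
print(coords)  # [(0, 1), (0, 2), (1, 0), (2, 2)]

# Hypothetical tensor to gather from; the tuple works as an advanced index.
values = torch.arange(9).reshape(3, 3)
print(values[c])  # tensor([1, 2, 3, 8])
```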

Summary

Those are the two usages of torch.where(). They can look confusing at first but become familiar with practice. The trickiest part is the broadcasting in the conventional usage: condition, x and y only need to be broadcastable, not identical in shape. Comments and corrections are welcome!

Tags: Python Deep Learning Pytorch

Posted by Crysma on Thu, 05 Jan 2023 10:48:01 +0530