# Foreword

This article describes two usages of `torch.where()`. The first is the conventional one, documented officially: selecting elements from two tensors according to a condition. The second passes a single bool tensor (here, the result of combining bool tensors with `&`) and returns the indices of its `True` elements.

# 1. Conventional usage of torch.where()

Let's first look at the explanation in the official documentation:

`torch.where(condition, x, y)`

Returns a tensor of elements selected from x or y depending on `condition`. A new tensor is created: where `condition` is True its element comes from x, otherwise from y. `condition`, x and y must be broadcastable to a common shape, which is the shape of the result.

The parameters are:

1. condition (bool tensor): where it is True, the element of x is taken; otherwise the element of y is taken

2. x (tensor or scalar): the values selected where condition is True

3. y (tensor or scalar): the values selected where condition is False
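To make the parameter rules concrete, here is a minimal element-by-element sketch of what `torch.where(condition, x, y)` computes. The helper `where_reference` is hypothetical and only for illustration; it assumes `condition`, x and y already share one shape (no broadcasting) and that x and y share one dtype. It is not how PyTorch implements the function.

```python
import torch

def where_reference(condition, x, y):
    # Hypothetical helper: assumes condition, x, y have identical shapes
    # (no broadcasting) and that x and y have the same dtype.
    flat = [xi if ci else yi  # take from x when the condition is True, else from y
            for ci, xi, yi in zip(condition.flatten().tolist(),
                                  x.flatten().tolist(),
                                  y.flatten().tolist())]
    return torch.tensor(flat, dtype=x.dtype).reshape(x.shape)

torch.manual_seed(0)
x = torch.randint(0, 10, (2, 4))
y = torch.randint(0, 10, (2, 4))
cond = torch.rand(2, 4) > 0.5

# The reference loop and torch.where agree element by element
print(torch.equal(where_reference(cond, x, y), torch.where(cond, x, y)))  # True
```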

## 1.1 Same shape

First, the case where x and y have the same shape:

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
z = torch.where(x > 5, x, y)  # take x where x > 5, otherwise y

print(f'x = {x}')
print('=========================')
print(f'y = {y}')
print('=========================')
print(f'x > 5 = {x > 5}')
print('=========================')
print(f'z = {z}')
```

Output:

```
x = tensor([[1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]])
=========================
y = tensor([[ 5,  6,  7],
        [ 7,  8,  9],
        [ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
        [False, False, False],
        [False,  True,  True]])
=========================
z = tensor([[5, 6, 7],
        [7, 8, 9],
        [9, 6, 7]])
```

Here x and y have the same shape, (3, 3), and the condition `x > 5` tests every element of x. Rows 0 and 1 of `x > 5` are entirely False, and in row 2 only columns 1 and 2 are True. As described above, where the condition is True the value is taken from x, and where it is False the value is taken from y. So rows 0 and 1 of z, plus row 2 column 0, come from y, while row 2 columns 1 and 2 come from x. z also has shape (3, 3).
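The same result can be reached another way, which may help build intuition: start from a copy of y and overwrite the positions where the condition is True with the corresponding values of x. This is only an illustrative equivalence for same-shape inputs, not a claim about how `torch.where` is implemented; note that neither x nor y is modified.

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
mask = x > 5

# Copy y, then replace the True positions with the values from x
z_manual = y.clone()
z_manual[mask] = x[mask]

z = torch.where(mask, x, y)
print(torch.equal(z, z_manual))  # True
```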

## 1.2 Scalar case

```python
import torch

x = 3                          # a plain Python scalar
y = torch.tensor([[1, 5, 7]])
z = torch.where(y > 2, y, x)   # x is broadcast against y

print(f'y > 2 = {y > 2}')
print('=========================')
print(f'z = {z}')
```

Output:

```
y > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[3, 5, 7]])
```

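Here x is a scalar and the condition is `y > 2`. Why not use `x > 2` as the condition? Because x is a plain Python int, `x > 2` is evaluated by Python and yields a Python bool, not a bool tensor. Scalars and tensors can be broadcast together, so x behaves like a tensor of y's shape filled with 3.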
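A small check of the point above, as a sketch: comparing a Python int yields a Python `bool`, while comparing a tensor yields a tensor. The name `x_t` is introduced here only for illustration; wrapping the scalar in a 0-d tensor gives the same selection result, because a 0-d tensor broadcasts against any shape.

```python
import torch

x = 3
y = torch.tensor([[1, 5, 7]])

# The comparison on a Python int is done by Python, not PyTorch
print(type(x > 2))  # <class 'bool'>
print(type(y > 2))  # <class 'torch.Tensor'>

# A 0-d tensor holding the same value works as the "other" argument too
x_t = torch.tensor(3)
z = torch.where(y > 2, y, x_t)
print(z)  # tensor([[3, 5, 7]])
```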

Scalar-to-tensor broadcasting works the same way in ordinary arithmetic, for example:

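```python
import torch

a = torch.tensor([1, 5, 7])
b = 3
c = a + b                    # the scalar b is broadcast to a's shape
d = torch.tensor([3, 3, 3])
e = a + d                    # explicit tensor of the same shape

print(f'c = {c}')
print(f'e = {e}')
```

Output:

```
c = tensor([ 4,  8, 10])
e = tensor([ 4,  8, 10])
```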

In effect, b = 3 is broadcast (expanded) to [3, 3, 3], exactly like d.
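The implicit expansion can be made explicit with `torch.broadcast_to`, which, like `Tensor.expand`, returns a view of the requested shape rather than copying data. A minimal sketch reusing a and b from the example:

```python
import torch

a = torch.tensor([1, 5, 7])
b = 3

# Expand the scalar (as a 0-d tensor) to a's shape explicitly
b_expanded = torch.broadcast_to(torch.tensor(b), a.shape)
print(b_expanded)                           # tensor([3, 3, 3])
print(torch.equal(a + b, a + b_expanded))   # True
```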

## 1.3 Different shapes

Strictly speaking, the scalar case above is already a case of different shapes. Here is a more general example with two tensors of different shapes:

```python
import torch

x = torch.tensor([[1, 3, 5]])      # shape (1, 3)
y = torch.tensor([[2], [4], [6]])  # shape (3, 1)
z = torch.where(x > 2, x, y)

print(f'x = {x}')
print('=========================')
print(f'y = {y}')
print('=========================')
print(f'x > 2 = {x > 2}')
print('=========================')
print(f'z = {z}')
```

Output:

```
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
        [4],
        [6]])
=========================
x > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[2, 3, 5],
        [4, 3, 5],
        [6, 3, 5]])
```

Here x.shape = (1, 3), y.shape = (3, 1), and the condition `x > 2` has shape (1, 3). These shapes are broadcastable, so the operation succeeds: when computing `torch.where(x > 2, x, y)`, the condition, x and y are each broadcast to shape (3, 3).

Column 0 of the broadcast condition is False, so column 0 of z comes from y (2, 4, 6); columns 1 and 2 are True, so they come from x (3 and 5).
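The broadcasting step described above can be reproduced by hand, as a sketch: `torch.broadcast_shapes` reports the common shape, and `torch.broadcast_tensors` expands the condition, x and y to it. Calling `torch.where` on the expanded tensors gives the same z as calling it on the originals.

```python
import torch

x = torch.tensor([[1, 3, 5]])      # shape (1, 3)
y = torch.tensor([[2], [4], [6]])  # shape (3, 1)
cond = x > 2                        # shape (1, 3)

# The common shape all three operands are broadcast to
print(torch.broadcast_shapes(cond.shape, x.shape, y.shape))  # torch.Size([3, 3])

# Expand explicitly, then apply torch.where to equal-shaped tensors
cond_b, x_b, y_b = torch.broadcast_tensors(cond, x, y)
print(x_b)  # every row is [1, 3, 5]
print(y_b)  # rows are [2, 2, 2], [4, 4, 4], [6, 6, 6]
print(torch.equal(torch.where(cond_b, x_b, y_b), torch.where(cond, x, y)))  # True
```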

Readers are welcome to experiment with other broadcasting combinations.
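As a starting point for such experiments, here is a sketch with one combination that broadcasts and one that does not. A 1-D condition of shape (3,), an x of shape (2, 1) and a 0-d y broadcast to (2, 3). Shapes (3,) and (2,) cannot be broadcast, because their trailing sizes differ and neither is 1, so PyTorch raises a `RuntimeError`.

```python
import torch

cond = torch.tensor([True, False, True])  # shape (3,)
x = torch.tensor([[10], [20]])            # shape (2, 1)
y = torch.tensor(0)                       # 0-d: broadcasts with anything

z = torch.where(cond, x, y)               # result shape (2, 3)
print(z)  # rows are [10, 0, 10] and [20, 0, 20]

# Incompatible shapes: (3,) vs (2,)
try:
    torch.where(cond, torch.tensor([1, 2]), y)
except RuntimeError as err:
    print("RuntimeError:", err)
```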

# 2. Special usage of torch.where()

`torch.where(a & b)`

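Here a and b are both bool tensors, and `a & b` is their element-wise logical AND. Called with this single argument, `torch.where` returns a tuple of index tensors, one per dimension. For 2-D inputs, the first item holds the row indices and the second the column indices of the positions where both a and b are True. According to the official documentation, `torch.where(condition)` is identical to `torch.nonzero(condition, as_tuple=True)`.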
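The single-argument form is not limited to two dimensions: the returned tuple has one index tensor per dimension of the condition. A minimal sketch with a 1-D and a 3-D condition (the tensors `v` and `t` are made up for illustration), also checking the documented equivalence with `torch.nonzero(..., as_tuple=True)`:

```python
import torch

# 1-D condition: a 1-tuple of indices
v = torch.tensor([0, 3, 0, 5])
print(torch.where(v > 0))  # (tensor([1, 3]),)

# 3-D condition with a single True element: a 3-tuple
t = torch.zeros(2, 2, 2, dtype=torch.bool)
t[1, 0, 1] = True
idx = torch.where(t)
print(len(idx))  # 3
print(idx)       # (tensor([1]), tensor([0]), tensor([1]))

# Same result as torch.nonzero with as_tuple=True
print(all(torch.equal(p, q) for p, q in zip(idx, torch.nonzero(t, as_tuple=True))))  # True
```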

Here is a 2-D example:

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)  # single-argument form: indices where a & b is True

print(f'a = {a}')
print('=========================')
print(f'b = {b}')
print('=========================')
print(f'c = {c}')
```

Output:

```
a = tensor([[False,  True,  True],
        [ True, False, False],
        [False, False,  True]])
=========================
b = tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))
```

c is a tuple: item 0 holds the row indices and item 1 the column indices of the positions where both a and b are True, namely (0, 1), (0, 2), (1, 0) and (2, 2). Since b is all True here, `a & b` is simply equal to a.
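A common next step is to use the returned tuple directly. As a sketch: zipping the two index tensors gives (row, column) pairs, and the tuple can index another tensor of the same shape to pick out the matching elements. The tensor `values` is made up for illustration. `torch.nonzero` without `as_tuple` returns the same coordinates stacked as an N x 2 tensor.

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
rows, cols = torch.where(a & b)

# Pair row and column indices into coordinates
coords = list(zip(rows.tolist(), cols.tolist()))
print(coords)  # [(0, 1), (0, 2), (1, 0), (2, 2)]

# Index another tensor with the tuple to gather the matching elements
values = torch.arange(9).reshape(3, 3)
print(values[rows, cols])  # tensor([1, 2, 3, 8])

# Same coordinates as a (4, 2) tensor
print(torch.nonzero(a & b))
```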

# Summary

These are the two usages of `torch.where()`. The API may seem fiddly at first, but it becomes familiar with practice. The main subtlety in the first usage is broadcasting: the condition, x and y only need to be broadcastable, not identical in shape. Comments and corrections are welcome!