# Foreword

This article describes two usages of torch.where(). The first is the conventional three-argument form described in the official documentation. The second is the one-argument form, used together with operations on bool tensors (such as `a & b`) to obtain the indices of the elements that are True.

# 1. Conventional usage of torch.where()

First, let's look at how the official documentation explains it:

`torch.where(condition, x, y)`

Returns a tensor of elements selected from either x or y, depending on condition. A new tensor is created; each of its elements is taken from x or y, and condition, x and y must be broadcastable to a common shape, which becomes the shape of the result. (Recent PyTorch versions name the last two parameters `input` and `other`.)
The parameters are:
1. condition (bool tensor): where it is True, the value of x is used; otherwise the value of y is used
2. x (tensor or scalar): values selected where condition is True
3. y (tensor or scalar): values selected where condition is False
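To make the element-wise rule concrete, here is a naive reference loop for the same-shape case. `where_reference` is a hypothetical helper written only for illustration (it is not part of PyTorch), and it assumes 2-D tensors of identical shape, so no broadcasting is handled. It checks itself against torch.where using the tensors from section 1.1.

```python
import torch

def where_reference(condition, x, y):
    """Hypothetical element-by-element reference for torch.where(condition, x, y).

    Assumes condition, x and y are 2-D tensors of the same shape
    (no broadcasting), which keeps the loop easy to read.
    """
    rows, cols = condition.shape
    out = torch.empty_like(x)
    for i in range(rows):
        for j in range(cols):
            # True -> take the element from x, False -> take it from y
            out[i, j] = x[i, j] if condition[i, j] else y[i, j]
    return out

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
ref = where_reference(x > 5, x, y)
print(torch.equal(ref, torch.where(x > 5, x, y)))  # True
```

The real torch.where is vectorized and also handles broadcasting; the loop only spells out which source each output element comes from.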

## 1.1 Same shape

First, consider the case where x and y have the same shape:

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
z = torch.where(x > 5, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 5 = {x > 5}')
print(f'=========================')
print(f'z = {z}')

>print result:
x = tensor([[1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]])
=========================
y = tensor([[ 5,  6,  7],
        [ 7,  8,  9],
        [ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
        [False, False, False],
        [False,  True,  True]])
=========================
z = tensor([[5, 6, 7],
        [7, 8, 9],
        [9, 6, 7]])
```

The code above defines x and y with the same shape, (3, 3). The condition x > 5 is checked element by element: every element in rows 0 and 1 of x gives False, and in row 2 only columns 1 and 2 give True. As stated earlier, where the condition is True the value from x is used, and where it is False the value from y is used. So the first two rows of the new tensor z, plus row 2, column 0, come from y, while the remaining two elements come from x. The shape of z is also (3, 3).
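For integer tensors like these, the same "True takes x, False takes y" rule can be written as mask arithmetic. This is an alternative formulation for illustration only, not how PyTorch implements torch.where, and it is only safe for finite values (with floats, `0 * inf` gives `nan`, whereas torch.where simply ignores the unselected element).

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
mask = x > 5  # bool tensor, True only at row 2, columns 1 and 2

# Convert the mask to 0/1 integers: m keeps x, (1 - m) keeps y
m = mask.long()
z_arith = m * x + (1 - m) * y

print(z_arith)
print(torch.equal(z_arith, torch.where(mask, x, y)))  # True
```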

## 1.2 Scalar case

```python
import torch

x = 3
y = torch.tensor([[1, 5, 7]])
z = torch.where(y > 2, y, x)

print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')

>print result:
y > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[3, 5, 7]])
```

Here x is a scalar and the condition is y > 2. Why not use x > 2 as the condition? Because x is the Python integer 3, so x > 2 evaluates to the plain Python bool True, not a bool tensor. The example also shows that a scalar can be broadcast against a tensor!
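A quick check of both points, assuming a PyTorch version that accepts Python scalars in torch.where, as the example above already does. `torch.full_like` is used here only to build the explicit "tensor of 3s" for comparison.

```python
import torch

x = 3
y = torch.tensor([[1, 5, 7]])

# x > 2 compares two Python numbers, so it is a plain Python bool,
# not a bool tensor, and cannot act as an element-wise condition.
print(type(x > 2))  # <class 'bool'>

# Passing the scalar 3 behaves like passing a tensor filled with 3
# that has the same shape as y.
z_scalar = torch.where(y > 2, y, x)
z_full = torch.where(y > 2, y, torch.full_like(y, x))
print(torch.equal(z_scalar, z_full))  # True
```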

For example:

```python
import torch

a = torch.tensor([1, 5, 7])
b = 3
c = a + b
d = torch.tensor([3, 3, 3])
e = a + d

print(f'c = {c}')
print(f'e = {e}')

>print result:
c = tensor([ 4,  8, 10])
e = tensor([ 4,  8, 10])
```

In effect, the scalar b = 3 is broadcast (expanded) to [3, 3, 3], which makes a + b the same as a + d.
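The expansion can be made explicit with `torch.broadcast_to`, which produces the tensor that broadcasting implicitly treats b as. This is a sketch for illustration; ordinary code just writes `a + b`.

```python
import torch

a = torch.tensor([1, 5, 7])
b = 3

# Expand the scalar explicitly to a's shape; broadcast_to returns an
# expanded view rather than a copy.
b_expanded = torch.broadcast_to(torch.tensor(b), a.shape)
print(b_expanded)  # tensor([3, 3, 3])
print(torch.equal(a + b, a + b_expanded))  # True
```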

## 1.3 Different shapes

Strictly speaking, the scalar case above is already a case of different shapes. Here is an example with two tensors of different shapes:

```python
import torch

x = torch.tensor([[1, 3, 5]])
y = torch.tensor([[2], [4], [6]])
z = torch.where(x > 2, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 2 = {x > 2}')
print(f'=========================')
print(f'z = {z}')

>print result:
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
        [4],
        [6]])
=========================
x > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[2, 3, 5],
        [4, 3, 5],
        [6, 3, 5]])
```

Above, x.shape = (1, 3), y.shape = (3, 1), and the condition x > 2 has shape (1, 3). These three shapes are broadcastable, so the call succeeds: when computing torch.where(x > 2, x, y), PyTorch broadcasts the condition, x and y to the common shape (3, 3).

After broadcasting, column 0 of the condition is all False and columns 1 and 2 are all True, so column 0 of z takes the values of y, and columns 1 and 2 take the values of x.
Readers are encouraged to try other broadcasting combinations on their own.
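One way to inspect what happens is to broadcast the three arguments explicitly. The sketch below uses `torch.broadcast_shapes` to compute the common shape and `torch.broadcast_tensors` to expand the condition, x and y, then confirms that torch.where on the expanded, same-shape tensors matches the implicit call.

```python
import torch

x = torch.tensor([[1, 3, 5]])      # shape (1, 3)
y = torch.tensor([[2], [4], [6]])  # shape (3, 1)
cond = x > 2                       # shape (1, 3)

# The common shape that condition, x and y are broadcast to
print(torch.broadcast_shapes(cond.shape, x.shape, y.shape))  # torch.Size([3, 3])

# Expanding all three explicitly and calling torch.where on the
# same-shape tensors gives the same result as the implicit version.
cond_b, x_b, y_b = torch.broadcast_tensors(cond, x, y)
print(cond_b)
print(torch.equal(torch.where(cond_b, x_b, y_b), torch.where(cond, x, y)))  # True
```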

# 2. Special usage of torch.where()

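`torch.where(a & b)`

When torch.where() is called with only a condition, it returns a tuple of index tensors describing where the condition is True. Here the condition is a & b, where a and b are both bool tensors, so it is True exactly where both a and b are True. For a 2-D condition, the first item of the tuple holds the row indices of those elements and the second item holds the corresponding column indices.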
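The condition does not have to be built with `&`; any bool tensor works, and the returned tuple contains one index tensor per dimension of the condition. The PyTorch documentation states that torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True); the sketch below checks this on small tensors made up for illustration.

```python
import torch

# Any bool tensor works as the condition, not only a & b.
x = torch.tensor([[1, 3, 5]])
rows, cols = torch.where(x > 2)
print(rows, cols)  # tensor([0, 0]) tensor([1, 2])

# The tuple has one index tensor per dimension of the condition.
t = torch.tensor([[[True, False], [False, True]]])  # shape (1, 2, 2)
idx = torch.where(t)
print(len(idx))  # 3

# One-argument torch.where versus nonzero(..., as_tuple=True)
nz = torch.nonzero(t, as_tuple=True)
print(all(torch.equal(i, j) for i, j in zip(idx, nz)))  # True
```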

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)

print(f'a = {a}')
print(f'=========================')
print(f'b = {b}')
print(f'=========================')
print(f'c = {c}')

>print result:
a = tensor([[False,  True,  True],
        [ True, False, False],
        [False, False,  True]])
=========================
b = tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))
```

c is a tuple: item 0 holds the row indices and item 1 the column indices of the elements where both a and b are True, so the True positions are (0, 1), (0, 2), (1, 0) and (2, 2).
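Two common ways to use the returned tuple are sketched below: pairing the indices into (row, col) coordinates, and using the tuple directly as an index into another tensor of the same shape. The `values` tensor is a hypothetical example chosen only to make the gathered elements easy to read.

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
rows, cols = torch.where(a & b)

# Pair up the row and column indices: one (row, col) per True element
coords = torch.stack((rows, cols), dim=1)
print(coords.tolist())  # [[0, 1], [0, 2], [1, 0], [2, 2]]

# The tuple can be used directly as an index to gather values from
# another tensor of the same shape
values = torch.arange(9).reshape(3, 3)
print(values[rows, cols])  # tensor([1, 2, 3, 8])
```

The stacked `coords` tensor has the same layout as the result of `torch.nonzero(a & b)` without `as_tuple`.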

# Summary

These are the two usages of torch.where(). They may look tricky at first, but with a little practice they become familiar. The point that needs the most care in the three-argument form is the broadcasting mechanism. Comments and corrections are welcome!

Posted by Crysma on Thu, 05 Jan 2023 10:48:01 +0530