Foreword
This article describes two usages of torch.where(). The first is the conventional one, described in the official documentation. The second is used together with operations on bool tensors.
1. Conventional usage of torch.where()
First, let's look at the explanation in the official documentation:
torch.where(condition, x, y)
Depending on condition, this returns a new tensor whose elements are selected from x or y. The shapes of condition, x and y must be broadcastable to a common shape, which is the shape of the result.
Parameters are explained as follows:
1. condition (bool tensor): where it is True, the value from x is taken; otherwise the value from y is taken
2. x (tensor or scalar): values selected where condition is True
3. y (tensor or scalar): values selected where condition is False
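To make the element-wise rule concrete, here is a deliberately naive loop that spells out what torch.where(condition, x, y) computes for two same-shape 2-D tensors. The helper name where_reference is hypothetical and exists only for illustration. It ignores broadcasting and is not how PyTorch actually implements the operation.

```python
import torch

def where_reference(condition, x, y):
    """Hypothetical loop version of torch.where for same-shape 2-D tensors.

    Illustration only: no broadcasting, and not PyTorch's real implementation.
    """
    out = torch.empty_like(x)
    rows, cols = condition.shape
    for i in range(rows):
        for j in range(cols):
            # Take the element from x where the condition holds, otherwise from y
            out[i, j] = x[i, j] if condition[i, j] else y[i, j]
    return out

cond = torch.tensor([[True, False], [False, True]])
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[10, 20], [30, 40]])

result = where_reference(cond, x, y)
expected = torch.where(cond, x, y)
print(result.tolist())              # [[1, 20], [30, 4]]
print(torch.equal(result, expected))  # True
```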
1.1 Same shape
First, the case where x and y have the same shape:
```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
z = torch.where(x > 5, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 5 = {x > 5}')
print(f'=========================')
print(f'z = {z}')
```
Print result:
```
x = tensor([[1, 2, 3],
        [3, 4, 5],
        [5, 6, 7]])
=========================
y = tensor([[ 5,  6,  7],
        [ 7,  8,  9],
        [ 9, 10, 11]])
=========================
x > 5 = tensor([[False, False, False],
        [False, False, False],
        [False,  True,  True]])
=========================
z = tensor([[5, 6, 7],
        [7, 8, 9],
        [9, 6, 7]])
```
The code above defines x and y, both with shape (3, 3). The condition x > 5 checks each element of x for being greater than 5. Rows 0 and 1 are entirely False, and in row 2 only columns 1 and 2 are True. As stated earlier, where the condition is True the value from x is used, and where it is False the value from y is used. So in the newly created z, rows 0 and 1 and the element at row 2, column 0 come from y, and the remaining two elements come from x. The shape of z is also (3, 3).
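One way to check this reading of the result is to use the condition as a boolean mask. Positions where the mask is True must match x, and the other positions must match y. The sketch below reuses the tensors from the example. The name mask is introduced here only for readability.

```python
import torch

x = torch.tensor([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
y = torch.tensor([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
mask = x > 5
z = torch.where(mask, x, y)

# Where the mask is True, z agrees with x ...
print(torch.equal(z[mask], x[mask]))    # True
# ... and where it is False, z agrees with y
print(torch.equal(z[~mask], y[~mask]))  # True
print(z.shape)                          # torch.Size([3, 3])
```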
1.2 Scalar case
```python
import torch

x = 3                          # a plain Python scalar
y = torch.tensor([[1, 5, 7]])
z = torch.where(y > 2, y, x)

print(f'y > 2 = {y > 2}')
print(f'=========================')
print(f'z = {z}')
```
Print result:
```
y > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[3, 5, 7]])
```
Here x is a scalar and the condition is y > 2. Why not write the condition as x > 2? Because with x = 3, x > 2 is just the Python bool True, not a bool tensor. The call works because a scalar can be broadcast against a tensor.
Here is a simpler example of scalar broadcasting, using addition:
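Broadcasting the scalar means it behaves as if it were a tensor of y's shape filled with 3. The sketch below compares the scalar call with an explicit torch.full_like tensor. It assumes a PyTorch version that accepts a Python scalar for x or y, as the example above already does.

```python
import torch

y = torch.tensor([[1, 5, 7]])

# Scalar as the "else" value, as in the example above
z_scalar = torch.where(y > 2, y, 3)
# The same selection with the scalar spelled out as a tensor shaped like y
z_tensor = torch.where(y > 2, y, torch.full_like(y, 3))

print(z_scalar)                         # tensor([[3, 5, 7]])
print(torch.equal(z_scalar, z_tensor))  # True
```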
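```python
import torch

a = torch.tensor([1, 5, 7])
b = 3
c = a + b                      # scalar broadcast against a tensor
d = torch.tensor([3, 3, 3])
e = a + d                      # explicit tensor of the same shape
print(f'c = {c}')
print(f'e = {e}')
```
Print result:
```
c = tensor([ 4,  8, 10])
e = tensor([ 4,  8, 10])
```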
In effect, b = 3 is expanded to [3, 3, 3], which is exactly d.
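PyTorch can show this expansion directly. torch.broadcast_tensors returns expanded views of its inputs with a common shape, and torch.broadcast_shapes computes that shape. In this sketch the scalar 3 is wrapped as a 0-d tensor so it can be passed to torch.broadcast_tensors.

```python
import torch

a = torch.tensor([1, 5, 7])
b = torch.tensor(3)                  # 0-d tensor standing in for the scalar 3

# Expanded views with a common shape; no data is copied
a_b, b_b = torch.broadcast_tensors(a, b)
print(b_b)                           # tensor([3, 3, 3])
print(torch.broadcast_shapes(a.shape, b.shape))  # torch.Size([3])
print(torch.equal(a + 3, a + b_b))   # True
```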
1.3 Different shapes
Strictly speaking, the scalar case is already a case of different shapes. Here is a more general example:
```python
import torch

x = torch.tensor([[1, 3, 5]])
y = torch.tensor([[2], [4], [6]])
z = torch.where(x > 2, x, y)

print(f'x = {x}')
print(f'=========================')
print(f'y = {y}')
print(f'=========================')
print(f'x > 2 = {x > 2}')
print(f'=========================')
print(f'z = {z}')
```
Print result:
```
x = tensor([[1, 3, 5]])
=========================
y = tensor([[2],
        [4],
        [6]])
=========================
x > 2 = tensor([[False,  True,  True]])
=========================
z = tensor([[2, 3, 5],
        [4, 3, 5],
        [6, 3, 5]])
```
Here x.shape = (1, 3), y.shape = (3, 1), and the condition x > 2 has shape (1, 3). These shapes are broadcastable, so the call succeeds. When computing torch.where(x > 2, x, y), condition, x and y are each broadcast to shape (3, 3).
Column 0 of the condition is False, so column 0 of z takes its values from y. Columns 1 and 2 are True, so they take the values of x.
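The implicit broadcasting can be made explicit. Expanding all three operands to their common shape first with torch.broadcast_tensors, and then selecting, should give the same z. The names cond_b, x_b and y_b are used only in this sketch.

```python
import torch

x = torch.tensor([[1, 3, 5]])       # shape (1, 3)
y = torch.tensor([[2], [4], [6]])   # shape (3, 1)
cond = x > 2                        # shape (1, 3)

# Expand condition, x and y to their common shape (3, 3)
cond_b, x_b, y_b = torch.broadcast_tensors(cond, x, y)
print(cond_b.shape, x_b.shape, y_b.shape)  # each is torch.Size([3, 3])

# Selecting from the expanded tensors matches the implicit version
z_explicit = torch.where(cond_b, x_b, y_b)
z_implicit = torch.where(cond, x, y)
print(torch.equal(z_explicit, z_implicit))  # True
print(z_implicit.tolist())                  # [[2, 3, 5], [4, 3, 5], [6, 3, 5]]
```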
Readers are encouraged to experiment with other broadcastable shapes.
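When experimenting, it helps to know what happens when the shapes do not broadcast. As far as I know, PyTorch raises a RuntimeError for a size mismatch. This sketch catches it, then shows a shape that does broadcast. The flag name failed is only for illustration.

```python
import torch

cond = torch.tensor([True, False])      # shape (2,)
x = torch.tensor([1, 2])                # shape (2,)
y_bad = torch.tensor([10, 20, 30])      # shape (3,): cannot broadcast with (2,)

try:
    torch.where(cond, x, y_bad)
    failed = False
except RuntimeError as err:             # the size mismatch is reported here
    failed = True
    print("broadcast failed:", err)

# A column of shape (3, 1) does broadcast with (2,), giving shape (3, 2)
y_col = torch.tensor([[10], [20], [30]])
z = torch.where(cond, x, y_col)
print(z.shape)      # torch.Size([3, 2])
print(z.tolist())   # [[1, 10], [1, 20], [1, 30]]
```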
2. Special usage of torch.where()
torch.where(a & b)
Here a and b are both bool tensors, and a & b is their element-wise AND. Called with this single argument, torch.where returns a tuple. For a 2-D input, the first item is a tensor of the row indices where both a and b are True, and the second item is a tensor of the matching column indices.
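The official documentation states that torch.where(condition) is identical to torch.nonzero(condition, as_tuple=True). So the a & b part is just one way to build the condition, and any bool tensor works. The sketch below checks that equivalence. It also shows that the tuple holds one index tensor per dimension, so a 3-D mask gives three.

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
mask = a & b                 # element-wise AND of two bool tensors

rows, cols = torch.where(mask)
rows2, cols2 = torch.nonzero(mask, as_tuple=True)
print(torch.equal(rows, rows2) and torch.equal(cols, cols2))  # True

# One index tensor per dimension: a 3-D mask yields a 3-tuple
mask3d = torch.zeros((2, 2, 2), dtype=torch.bool)
mask3d[1, 0, 1] = True
print(torch.where(mask3d))   # (tensor([1]), tensor([0]), tensor([1]))
```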
Please see the example:
```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)

print(f'a = {a}')
print(f'=========================')
print(f'b = {b}')
print(f'=========================')
print(f'c = {c}')
```
Print result:
```
a = tensor([[False,  True,  True],
        [ True, False, False],
        [False, False,  True]])
=========================
b = tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])
=========================
c = (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))
```
c is a tuple. Item 0 holds the row indices and item 1 the column indices of the positions where both a and b are True. Read pairwise, the positions are (0, 1), (0, 2), (1, 0) and (2, 2).
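A common use of this tuple is to index another tensor of the same shape, or to turn it into coordinate pairs. The sketch below uses a hypothetical values tensor, torch.arange(9) reshaped to (3, 3), to show both. It also shows torch.nonzero without as_tuple, which returns the same coordinates as an (N, 2) tensor.

```python
import torch

a = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 1]], dtype=torch.bool)
b = torch.ones((3, 3), dtype=torch.bool)
c = torch.where(a & b)

# Index another tensor of the same shape with the tuple
values = torch.arange(9).reshape(3, 3)   # hypothetical data: [[0,1,2],[3,4,5],[6,7,8]]
print(values[c])                         # tensor([1, 2, 3, 8])

# Pair row and column indices into (row, col) coordinates
coords = list(zip(c[0].tolist(), c[1].tolist()))
print(coords)                            # [(0, 1), (0, 2), (1, 0), (2, 2)]

# Same coordinates as an (N, 2) tensor
print(torch.nonzero(a & b).tolist())     # [[0, 1], [0, 2], [1, 0], [2, 2]]
```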
Summary
These are the two usages of torch.where(). They may seem cumbersome at first, but they become familiar with practice. The point that needs the most attention is the broadcasting mechanism. Comments and corrections are welcome!