본문 바로가기

Machine Learning Models/Pytorch

Pytorch - scatter

반응형

Pytorch 에는 새롭게 구성한 텐서에 원하는 인덱스에 맞게 값을 할당해주는 scatter 함수가 존재합니다. (Tensorflow 에도 존재합니다.) 이번 포스트에서는 scatter 함수의 동작원리에 대해 알아보도록 하겠습니다. 먼저 Pytorch 공식 문서에는 scatter 함수가 다음과 같이 in-place 함수로 정의되어 있습니다.

scatter_(dim, index, src, reduce=None) → Tensor

언더바가 붙었으므로 "tensor.scatter_()" 형태로 동작하며, 파라미터로 주어진 "index" 에 맞게 "src" 의 값을 새로운 "tensor" 로 할당합니다. 예를 들어 3차원 텐서라면 다음과 같이 업데이트 됩니다.

tensor[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
tensor[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
tensor[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

정의된 파라미터는 다음과 같습니다.

  • dim (int): 인덱싱의 기준이 되는 축
  • index (LongTensor): "src" 구성요소들이 흩어질 기준이 되는 인덱스 텐서로 정수형 텐서로 구성되어야 합니다. 
  • src (Tensor or float): 타겟 "tensor" 를 구성할 값들이 담겨있는 텐서입니다. 하나의 실수로 선언되면 그 값만으로 채워집니다.
  • reduce (str, optional): 기존의 값을 어떻게 업데이트할 것인지 정의합니다. 'multiply', 'add' 두 가지 방법이 존재하며, 정의되지 않을 경우 기존의 값을 없애고 새로운 값으로 치환합니다.

Example

먼저 간단하게 "src"와 "index"를 정의하고, 마지막 줄의 코드 결과를 살펴보겠습니다.

src = torch.arange(1, 11).reshape((2, 5))
src
# tensor([[ 1,  2,  3,  4,  5],
#         [ 6,  7,  8,  9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
#tensor([[1, 0, 0, 4, 0],
#        [0, 2, 0, 0, 0],
#        [0, 0, 3, 0, 0]])

(3,5) 의 차원을 가진 텐서에 scatter 함수가 적용되었으며 인덱싱 기준은 0 입니다. 위의 3차원 텐서에 대한 업데이트를 2차원에 대해 생각해보면 다음과 같습니다.

tensor[index[i][j]][j] = src[i][j]  # if dim == 0
tensor[i][index[i][j]] = src[i][j]  # if dim == 1

간단히 생각해보면 "src" 의 각 행, 열의 값을 해당 행/열 위치의 "index" 값에 따라 "tensor" 상에서 어떻게 위치시킬 것인지의 문제입니다.

  • index[0][0]의 값은 0 이므로 tensor[index[0][0]][0] = tensor[0][0] 에 src[0][0]의 값인 1이 대입됩니다.
  • index[0][1]의 값은 1 이므로 tensor[index[0][1]][1] = tensor[1][1] 에 src[0][1]의 값인 2가 대입됩니다.
  • index[0][2]의 값은 2 이므로 tensor[index[0][2]][2] = tensor[2][2] 에 src[0][2]의 값인 3이 대입됩니다.
  • index[0][3]의 값은 0 이므로 tensor[index[0][3]][3] = tensor[0][3] 에 src[0][3]의 값인 4가 대입됩니다.
  • "index" 텐서는 (1,4) 크기이므로 "src" 텐서의 첫 번째 줄 값만 "tensor" 에 할당됩니다.

다른 예를 하나 더 보겠습니다. 이번에는 인덱싱 기준 축이 1 입니다.

index = torch.tensor([[0, 1, 2], [0, 1, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
# tensor([[1, 2, 3, 0, 0],
#         [6, 7, 0, 0, 8],
#         [0, 0, 0, 0, 0]])
  • index[0][0]의 값은 0 이므로 tensor[0][index[0][0]] = tensor[0][0] 에 src[0][0]의 값인 1이 대입됩니다.
  • index[0][1]의 값은 1 이므로 tensor[0][index[0][1]] = tensor[0][1] 에 src[0][1]의 값인 2가 대입됩니다.
  • index[0][2]의 값은 2 이므로 tensor[0][index[0][2]] = tensor[0][2] 에 src[0][2]의 값인 3이 대입됩니다.
  • index[1][0]의 값은 0 이므로 tensor[1][index[1][0]] = tensor[1][0] 에 src[1][0]의 값인 6이 대입됩니다.
  • index[1][1]의 값은 1 이므로 tensor[1][index[1][1]] = tensor[1][1] 에 src[1][1]의 값인 7이 대입됩니다.
  • index[1][2]의 값은 4 이므로 tensor[1][index[1][2]] = tensor[1][4] 에 src[1][2]의 값인 8이 대입됩니다.

두 가지의 예로부터 알 수 있는 점은 "index" 텐서의 크기만큼 "src" 텐서의 값이 할당된다는 점입니다. 즉, 두 번째 예처럼 "index" 텐서가 (2,3) 크기이므로 "src" 텐서의 세번째 줄 이상, 네번째 열 이상의 값들은 "tensor" 에 할당될 수 없다는 점입니다. Pytorch 공식문서에도 적혀 있는 조건으로 "index" 와 "src" 는 같은 차원을 가져야 하며 어떠한 차원 $d$에 대해서 $index.size(d) \leq src.size(d)$를 만족해야 합니다. 또한, $d\neq dim$에 대해서 $index.size(d) \leq tensor.size(d)$ 를 만족해야 합니다. ("tensor" 의 인덱싱 가능한 범위를 넘어서기 때문입니다.)

reduce 

"reduce" 파라미터를 정의하지 않으면 위와 같이 대입을 수행하고 'multiply' / 'add' 를 정의하면 기존 값에 곱하기 / 더하기를 수행합니다. 3차원 "tensor" 에 "reduce='multiply'" 일 경우 다음과 같습니다.

tensor[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
tensor[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
tensor[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2

다음 예를 살펴보도록 하겠습니다. (2, 4) 크기에 2로 채워진 tensor 에 대해 "dim=1" 로 scatter 함수를 수행합니다.

torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply')
# tensor([[2.0000, 2.0000, 2.4600, 2.0000],
#         [2.0000, 2.0000, 2.0000, 2.4600]])
  • index[0][0]의 값은 2 이므로 tensor[0][index[0][0]] = tensor[0][2] 에 "src" 로 주어진 1.23 이 기존 값과 곱해집니다.
  • index[0][1]의 값은 3 이므로 tensor[0][index[0][1]] = tensor[0][3] 에 "src" 로 주어진 1.23 이 기존 값과 곱해집니다.

"reduce='add'" 인 경우도 마찬가지이며 'scatter_add_()' 함수를 대신 사용할 수 있습니다. 특히, 다음 예처럼 인덱싱의 결과에 따라 "tensor" 에 같은 위치에 여러 번 적용될 수 있습니다. 

src = torch.ones((2, 5))
index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
# tensor([[2., 0., 0., 1., 1.],
#         [0., 2., 0., 0., 0.],
#         [0., 0., 2., 1., 1.]])
  • index[0][0]의 값은 0 이므로 tensor[index[0][0]][0] = tensor[0][0] 에 1이 기존 값과 더해집니다.
  • index[0][1]의 값은 1 이므로 tensor[index[0][1]][1] = tensor[1][1] 에 1이 기존 값과 더해집니다.
  • index[0][2]의 값은 2 이므로 tensor[index[0][2]][2] = tensor[2][2] 에 1이 기존 값과 더해집니다.
  • index[0][3]의 값은 0 이므로 tensor[index[0][3]][3] = tensor[0][3] 에 1이 기존 값과 더해집니다.
  • index[0][4]의 값은 0 이므로 tensor[index[0][4]][4] = tensor[0][4] 에 1이 기존 값과 더해집니다.
  • index[1][0]의 값은 0 이므로 tensor[index[1][0]][0] = tensor[0][0] 에 1이 기존 값과 더해져 2가 됩니다. 나머지 인덱스 값에 대해서도 같은 작업을 반복합니다.

 

참조

 


Pytorch - gather

반응형

'Machine Learning Models > Pytorch' 카테고리의 다른 글

Pytorch - embedding  (0) 2021.06.23
Pytorch - autograd 정의  (1) 2021.06.02
Pytorch - gather  (0) 2021.06.01
Pytorch - backward(retain_graph=True) (1)  (4) 2021.05.09
Pytorch - hook  (5) 2021.05.09