[torch] 기본 함수 이해와 활용
예제 코드
def _gather_feat(feat, ind, mask=None):
dim = feat.size(2)
ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
feat = feat.gather(1, ind)
return feat
torch.size()
numpy로 비유하면 shape과 같습니다. 차원을 확인할 때 사용하죠. 그럼 코드 1번째 출을 보면 dim이라는 변수에 들어온 변수 feat의 2번째 차원을 넣어준다는 거죠.
예를 들어 feat.size() 가 ( 1, 3, 4 ) 가 나왔다면, feat.size(2) 는 4 이기 때문에 dim은 4가 됩니다. 여기서 알 수 있는 점은 feat은 최소 2차원 이겠네요.
torch.unsqueeze()
차원을 추가할 때 사용합니다. unsqueeze(2) 라는 말은 2 차원에 새로운 차원을 추가합니다. 만약 ind.size() 가 (3,2, 5) 였다면, ind.unsqueeze(2) 는 (3, 2, 1, 5 ) 가 됩니다.
torch.expand()
차원을 늘릴 떄 사용합니다. ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) 이 코드가 있죠? 이 코드는 ind의 본래 2차원에 새로운 차원을 추가하고 해당 차원을 feat의 2차원으로 늘려준다는 의미가 되겠죠. 왜 그러냐하면, 앞의 두 차원은 본인의 차원을 사용하고, 나머지 차원만 dim으로 늘려줬기 때문이죠. 그럼 여기서 알 수 있는 점은 ind는 원래는 2차원이었지만, 3차원 임을 알 수 있겠네요,
torch.gather()
gather가 조금 복잡하다 생각할 수 있는데, 그냥 원하는 차원을 기준으로 인덱스 하겠다는 함수 입니다. 말이 굉장히 어렵지만, 간단합니다. feat.gather(1, ind) 라는 것을 보았을 때, gather은 long() type일 겁니다. index 역할을 하고 기준은 1차원 입니다. 누구의 1 차원 이냐면, ' feat ' tensor의 1차원을 기준으로 ind 에 해당하는 요소들을 뽑겠다는 말이죠. 그럼 3번 째 코드를 기준으로 보았을 때 알 수 있는 점은, feat과 ind의 1차원은 같으며, feat은 3차원임을 알 수 있겠죠.
총 분석
해당 코드는, feat과 ind를 받습니다. feat은 3차원이고 ind는 2차원 이죠. 해당 함수가 하는 역할은 3차원 tensor를 받는데 1차원을 기준으로 ind 에 해당하는 요소들만 뽑아서 차원을 변형합니다. 만약 feat.size() 이 (1, 2000, 3) 이고 ind.size() 가 (1, 200) 이라면, 함수를 통과하고 난 뒤 결과는 ( 1, 200, 3) 이 되겠죠.
그럼 다음에 뵈요~