AI

[torch] 기본 함수 이해와 활용

개발자_WH 2023. 1. 19. 23:46
728x90
반응형

예제 코드

def _gather_feat(feat, ind, mask=None): 
	dim  = feat.size(2)
        ind  = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
        feat = feat.gather(1, ind)
    
        return feat

torch.size()

  numpy로 비유하면 shape과 같습니다. 차원을 확인할 때 사용하죠. 그럼 코드 1번째 출을 보면 dim이라는 변수에 들어온 변수 feat의 2번째 차원을 넣어준다는 거죠.

 

예를 들어 feat.size() 가 ( 1, 3, 4 ) 가 나왔다면, feat.size(2) 는 4 이기 때문에 dim은 4가 됩니다. 여기서 알 수 있는 점은 feat은 최소 2차원 이겠네요.


torch.unsqueeze()

  차원을 추가할 때 사용합니다. unsqueeze(2) 라는 말은 2 차원에 새로운 차원을 추가합니다. 만약 ind.size() 가 (3,2, 5) 였다면, ind.unsqueeze(2) 는 (3, 2, 1, 5 ) 가 됩니다.


torch.expand()

 차원을 늘릴 떄 사용합니다. ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)  이 코드가 있죠? 이 코드는 ind의 본래 2차원에 새로운 차원을 추가하고 해당 차원을 feat의 2차원으로 늘려준다는 의미가 되겠죠. 왜 그러냐하면, 앞의 두 차원은 본인의 차원을 사용하고, 나머지 차원만 dim으로 늘려줬기 때문이죠. 그럼 여기서 알 수 있는 점은 ind는 원래는 2차원이었지만, 3차원 임을 알 수 있겠네요,


torch.gather()

gather가 조금 복잡하다 생각할 수 있는데, 그냥 원하는 차원을 기준으로 인덱스 하겠다는 함수 입니다. 말이 굉장히 어렵지만, 간단합니다. feat.gather(1, ind) 라는 것을 보았을 때, gather은 long() type일 겁니다. index 역할을 하고 기준은 1차원 입니다. 누구의 1 차원 이냐면, ' feat ' tensor의 1차원을 기준으로 ind 에 해당하는 요소들을 뽑겠다는 말이죠. 그럼 3번 째 코드를 기준으로 보았을 때 알 수 있는 점은,  feat과 ind의 1차원은 같으며, feat은 3차원임을 알 수 있겠죠.


총 분석

  해당 코드는, feat과 ind를 받습니다. feat은 3차원이고 ind는 2차원 이죠. 해당 함수가 하는 역할은 3차원 tensor를 받는데 1차원을 기준으로 ind 에 해당하는 요소들만 뽑아서 차원을 변형합니다. 만약 feat.size() 이 (1, 2000, 3) 이고 ind.size() 가 (1, 200) 이라면, 함수를 통과하고 난 뒤 결과는 ( 1, 200, 3) 이 되겠죠. 

 

그럼 다음에 뵈요~

728x90
반응형