pytorch.gather() #torch.gather() #torch.expand #torch.unsqueeze (1) 썸네일형 리스트형 [torch] 기본 함수 이해와 활용 예제 코드 def _gather_feat(feat, ind, mask=None): dim = feat.size(2) ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) feat = feat.gather(1, ind) return feat torch.size() numpy로 비유하면 shape과 같습니다. 차원을 확인할 때 사용하죠. 그럼 코드 1번째 출을 보면 dim이라는 변수에 들어온 변수 feat의 2번째 차원을 넣어준다는 거죠. 예를 들어 feat.size() 가 ( 1, 3, 4 ) 가 나왔다면, feat.size(2) 는 4 이기 때문에 dim은 4가 됩니다. 여기서 알 수 있는 점은 feat은 최소 2차원 이겠네요. torch.unsqu.. 이전 1 다음