AI

CSPNet 파헤치기 1

개발자_WH 2023. 4. 19. 09:31
728x90
반응형

먼저 우리가 끝장을 봐야할 부분을 가지고 오고 내부에 있는 모든것을 뜯어보겠습니다. 어떤게 목표냐? 아래 코드에요

설명을 보자면 CSP base model이고 논문 링크를 첨부해놨네요. 기존 논문과는 다른 부분이 있는데 1x1 expansion conv를 다룬다네요. 목적은 간단함을 위해서고요. 뭐 보면서 시작해봅시다.

class CSPNet(Backbone):
    """Cross Stage Partial base model.

    Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929

    NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
    darknet impl. I did it this way for simplicity and less special cases.
    """

    def __init__(self, cfg, in_chans=3, output_stride=32, global_pool='avg', drop_rate=0.,
                 act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0.,
                 zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck, out_features=None):
        super().__init__()
        self.drop_rate = drop_rate
        assert output_stride in (8, 16, 32)
        layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)

        # Construct the stem
        self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args)
        self.feature_info = [stem_feat_info]
        prev_chs = stem_feat_info['num_chs']
        curr_stride = stem_feat_info['reduction']  # reduction does not include pool
        if cfg['stem']['pool']:
            curr_stride *= 2

        # Construct the stages
        per_stage_args = _cfg_to_stage_args(
            cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate)
        self.stages = nn.Sequential()
        out_channels = []
        out_strides = []
        for i, sa in enumerate(per_stage_args):
            self.stages.add_module(
                str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn))
            prev_chs = sa['out_chs']
            curr_stride *= sa['stride']
            self.feature_info += [dict(num_chs=prev_chs,
                                       reduction=curr_stride, module=f'stages.{i}')]
            out_channels.append(prev_chs)
            out_strides.append(curr_stride)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)
        if zero_init_last_bn:
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

        # cspdarknet: csp1, csp2, csp3, csp4
        # cspresnet: csp0, csp1, csp2, csp3
        out_features_names = ["csp{}".format(i) for i in range(len(per_stage_args))]
        self._out_feature_strides = dict(zip(out_features_names, out_strides))
        self._out_feature_channels = dict(zip(out_features_names, out_channels))
        if out_features is None:
            self._out_features = out_features_names
        else:
            self._out_features = out_features

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    def size_divisibility(self):
        return 32

    def forward(self, x):
        x = self.stem(x)
        outputs = {}
        for i, stage in enumerate(self.stages):
            name = f"csp{i}"
            x = stage(x)
            if name in self._out_features:
                outputs[name] = x
        return outputs

이제 부터 봐야할 CSP에요 참 재미있어요. 이 코드는 SparseInst에 구현되어 있는 cspnet.py의 코듭니다. 끝장을 보자구요

 


1. init

초기화는 어떻게 되는지 한번 볼게요 부분만 따로 가져와 보겠습니다.

def __init__(self, cfg, in_chans=3, output_stride=32, global_pool='avg', drop_rate=0.,
                 act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0.,
                 zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck, out_features=None):
        super().__init__()
        self.drop_rate = drop_rate
        assert output_stride in (8, 16, 32)
        layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)

        # Construct the stem
        self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args)
        self.feature_info = [stem_feat_info]
        prev_chs = stem_feat_info['num_chs']
        curr_stride = stem_feat_info['reduction']  # reduction does not include pool
        if cfg['stem']['pool']:
            curr_stride *= 2

        # Construct the stages
        per_stage_args = _cfg_to_stage_args(
            cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate)
        self.stages = nn.Sequential()
        out_channels = []
        out_strides = []
        for i, sa in enumerate(per_stage_args):
            self.stages.add_module(
                str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn))
            prev_chs = sa['out_chs']
            curr_stride *= sa['stride']
            self.feature_info += [dict(num_chs=prev_chs,
                                       reduction=curr_stride, module=f'stages.{i}')]
            out_channels.append(prev_chs)
            out_strides.append(curr_stride)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)
        if zero_init_last_bn:
            for m in self.modules():
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

        # cspdarknet: csp1, csp2, csp3, csp4
        # cspresnet: csp0, csp1, csp2, csp3
        out_features_names = ["csp{}".format(i) for i in range(len(per_stage_args))]
        self._out_feature_strides = dict(zip(out_features_names, out_strides))
        self._out_feature_channels = dict(zip(out_features_names, out_channels))
        if out_features is None:
            self._out_features = out_features_names
        else:
            self._out_features = out_features

class instance를 생성하게 되면 위처럼 초기화 된다는 내용인데 차근히 보죠. 먼저 backbone을 상속받아요.

가지고 있는 변수의 종류를 볼게요.

  • self.drop_rate
  • layer_args
  • self.stem, stem_feat_info
  • self.feature_info
  • prev_chs
  • curr_stride
  • per_stage_args
  • self.stages
  • out_channels
  • out_strides
  • out_features_names
  • self._out_feature_strides
  • self._out_feature_chnnels

이렇게 정의를 하네요.  drop_rate부터 보면 default = 0 이고 instance를 생성할 떄 받는값이네요.

그 아래 보면 단언문이 있는 데 output_stride가 (8,16,32)네요. 이 말은 resolution이 1/32까지 줄어든다는 것을 의미하겠고요.

 

layer_args 는 dictionary를 사용하고 components로는 act_layer, norm_larer, aa_layer가 있네요.

각각은 nn.LeakyReLU, nn.BatchNorm2d, None이 default네요.

 

3번째 변수를 볼까요? self.stem, stem_feat_info 인데 이 친구는 create_stem 함수에 의한 output이네요.

def create_stem(
        in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='',
        act_layer=None, norm_layer=None, aa_layer=None):
    stem = nn.Sequential()
    if not isinstance(out_chs, (tuple, list)):
        out_chs = [out_chs]
    assert len(out_chs)
    in_c = in_chans
    for i, out_c in enumerate(out_chs):
        conv_name = f'conv{i + 1}'
        stem.add_module(conv_name, ConvBnAct(
            in_c, out_c, kernel_size, stride=stride if i == 0 else 1,
            act_layer=act_layer, norm_layer=norm_layer))
        in_c = out_c
        last_conv = conv_name
    if pool:
        if aa_layer is not None:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
            stem.add_module('aa', aa_layer(channels=in_c, stride=2))
        else:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv]))
 

cerate_stem 함수에요. in_chans, out_chs, kernel_size, stride, pool, act_layer, norm_layer,aa_layer를 받네요

stem은 nn.Sequential()을 사용하죠.

nn.Sequential 클래스는 nn.Linear, nn.ReLU(활성화 함수) 같은 모듈들을 인수로 받아서 순서대로 정렬해놓고 입력값이 들어모면 순서대로 모듈을 실행해서 결과값을 리턴한답니다. 그럼 저기 들어갈 모듈이 무엇일까네요?

우선 out_chs가 tuple, list가 아니면 받아온 out_chs를 list로 받네요 그리고 0이면 아래코드는 돌아가지 않고요.

받아온 in_chans로 in_c를 초기화하고 for 문을 돕니다.

이 때 enumerate()를 사용하는데 i, out_c 는 각각 index와 out_chs 들 중에서 순서대로의 값을 의미합니다.

conv_name은 formating 을 통해서 conv의 이름을 붙여주고 그 아래서 부터 stem(), nn.Sequential()에 추가해줍니다

add_Module에 바로 위에서 정의한 conv_name과 ConvBnAct ()를 추가해 주는데 in_c는 바뀌지 않고 out_c 만 out_chs에 따라 바뀌게 되겠네요. ConvBnAct()는 이름으로도 유추가능하지만 conv + Bn + Act로 이루어진 친구겠지요?

class ConvBnAct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
                 bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None,
                 drop_block=None):
        super(ConvBnAct, self).__init__()
        use_aa = aa_layer is not None

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
            padding=padding, dilation=dilation, groups=groups, bias=bias)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        self.bn = norm_layer(out_channels)
        self.act = act_layer()
        self.aa = aa_layer(
            channels=out_channels) if stride == 2 and use_aa else None

    @property
    def in_channels(self):
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        if self.aa is not None:
            x = self.aa(x)
        return x

보기에는 복잡해보이지만 별거 아니에요. self.conv, self.bn, self.act가 있고 self.conv는 timm.models.layers에 있는 create_conv2d를 활용했을 뿐이네요.

계속 보자면 create_conv2d 에들어갈 in_channels, out_channels, kernel_size, strid, padding, dilation, groups는 다 받아오는 인수네요. self.bn의 기본은 nn.BatchNrom2d이고 act는 nn.ReLU 네요. property 데코레이터로 감싸서 in_chnnels와 out_channels에 접근할 수 있게 해주고 forward는 그냥 conv -> bn -> act를 실행하네요.

for i, out_c in enumerate(out_chs):
        conv_name = f'conv{i + 1}'
        stem.add_module(conv_name, ConvBnAct(
            in_c, out_c, kernel_size, stride=stride if i == 0 else 1,
            act_layer=act_layer, norm_layer=norm_layer))
        in_c = out_c
        last_conv = conv_name

다시 돌아가서, in_c = 3, out_c는 out_chs에 따라 바뀌고 네요. in_c = out_c를 통해 layer channel을 먼트롤하고 last_conv 에 conv_name을 주네

if pool:
        if aa_layer is not None:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
            stem.add_module('aa', aa_layer(channels=in_c, stride=2))
        else:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

pool 을 하는데 aa_layer가 있으면 nn.Maxpool2D 를 주고 aa_layer 역시 추가해주고 없으면 Maxpool2D layer만 추가하는데 stride 가 2 네요.

return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv]))

그리고 stem과 dict를 return 하는데 dict는 in_c 정보, stride 정보 module 이름 정보를 가지고 있네요. 이름에서 알 수 있다 시피 channels의 수는 in_c를 통해 줄어드는 정보는 stride를 통해 알 수 있겠네요.

 

self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args)

다시 돌아가볼까요? 우리는 이제 3번째 변수를 초기화 하는 부분까지 봤어요. 다음으로 넘어가 봐야겠죠?

자 다시 봐야할 코드먼저 보여드릴게요

self.feature_info = [stem_feat_info]
prev_chs = stem_feat_info['num_chs']
curr_stride = stem_feat_info['reduction']  

stem_feat_info 는 dict 였는데 self.feature_info에 list로 넣어주고 해당 부분에서 key를 입력해서 prev_chs와 curr_stride의 정보를 init하네요. 그 아래로  가볼게요.

if cfg['stem']['pool']:
            curr_stride *= 2

cfg파일을 받는데 cfg['stem']['pool'] 이 있으면 현재 stirde를 2배해주고요. 잠깐 상기하고 넘어가면 stride는 resolution reduction과 관련이 있었죠. 여튼 계속 봅시다

per_stage_args = _cfg_to_stage_args(
            cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate)

per_stage_args 는 _cfg_to_stage_args 라는 함수의 return 값이네요 해당 함수를 볼게요.

def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.):
    # get per stage args for stage and containing blocks, calculate strides to meet target output_stride
    num_stages = len(cfg['depth'])
    if 'groups' not in cfg:
        cfg['groups'] = (1,) * num_stages
    if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)):
        cfg['down_growth'] = (cfg['down_growth'],) * num_stages
    if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)):
        cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages
    cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \
        [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])]
    stage_strides = []
    stage_dilations = []
    stage_first_dilations = []
    dilation = 1
    for cfg_stride in cfg['stride']:
        stage_first_dilations.append(dilation)
        if curr_stride >= output_stride:
            dilation *= cfg_stride
            stride = 1
        else:
            stride = cfg_stride
            curr_stride *= stride
        stage_strides.append(stride)
        stage_dilations.append(dilation)
    cfg['stride'] = stage_strides
    cfg['dilation'] = stage_dilations
    cfg['first_dilation'] = stage_first_dilations
    stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())]
    return stage_args

cfg, curr_stride, output_stride, drop_path_rate 를 받아서 시작하는 함수입니다. 

num_stages = len( cfg['depth'] ) 이네요.

groups가 cfg에 없다면 (1,) * num_stages 가 cfg['group'] 이고

down_growth, cross_linear가 있느냐에 따라 num_stages가 곱해지는 군요

cfg['block_dpr'] 에 if 문이 들어가 있는데 drop_path_rate이 없으면 [None]*num_stages를 넣어주고 있으면 torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])) 를 생성해주네요. 음 의미로 보면 0 부터 drop_path_rate까지 depth를 모두더한 다음 해당 값별로 나눠준 만큼의 간격으로 생성해줍니다.

그 다음은 stage_strides, stage_dilations, stage_first_dilations 를 list로 초기화하고 dilation을 1로 초기화 한 뒤 for 문을 돌며, dilation에 값을 넣어주는데 받아온 output_stride에 따른 조건 문에 따라 cfg_stride를 곱해 dilation을 업데이트해주고 stride는 1로 바꿔주거나, stride를 cfg_stride로 update하고 stride를 곱해주네요. 그렇게 처리된 stride되 dilation은 위에 정의한 변수에 추가해주고요. for 문이 끝난 뒤에는 cfg의 내용을 업데이트해주고 stage_args를 list 속 dictionary 형태로 업데이트 해주는데 key,value로 묶인 쌍으로 업데이트 해주네요.

 

다시 돌아가

self.stages = nn.Sequential() 을 통해 만들어주고, out_channels와 out_strides 를 빈 list로 만들어주죠

for i, sa in enumerate(per_stage_args):
            self.stages.add_module(
                str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn))
            prev_chs = sa['out_chs']
            curr_stride *= sa['stride']
            self.feature_info += [dict(num_chs=prev_chs,
                                       reduction=curr_stride, module=f'stages.{i}')]
            out_channels.append(prev_chs)
            out_strides.append(curr_stride)

해당 for 문에선느 per_stages_args를 묶고 추가해 주는데요.

self.stages.add_module 에 인덱스를 위해 enumerate()로 값을 받는데 stage_fn ( )에, prev_chs, **sa, **layer_args, block_fn을 주네요. 여기서 stage_fn 은 class의 instance 였죠?

class CrossStage(nn.Module):
    """Cross Stage."""

    def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
                 groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
                 block_fn=ResBottleneck, **block_kwargs):
        super(CrossStage, self).__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        exp_chs = int(round(out_chs * exp_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'),
                           norm_layer=block_kwargs.get('norm_layer'))

        if stride != 1 or first_dilation != dilation:
            self.conv_down = ConvBnAct(
                in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
            prev_chs = down_chs
        else:
            self.conv_down = None
            prev_chs = in_chs

        # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
        # there is also special case for the first stage for some of the model that results in uneven split
        # across the two paths. I did it this way for simplicity for now.
        self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1,
                                  apply_act=not cross_linear, **conv_kwargs)
        prev_chs = exp_chs // 2  # output of conv_exp is always split in two

        self.blocks = nn.Sequential()
        for i in range(depth):
            drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
            self.blocks.add_module(str(i), block_fn(
                prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
            prev_chs = block_out_chs

        # transition convs
        self.conv_transition_b = ConvBnAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
        self.conv_transition = ConvBnAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)

    def forward(self, x):
        if self.conv_down is not None:
            x = self.conv_down(x)
        x = self.conv_exp(x)
        split = x.shape[1] // 2
        xs, xb = x[:, :split], x[:, split:]
        xb = self.blocks(xb)
        xb = self.conv_transition_b(xb).contiguous()
        out = self.conv_transition(torch.cat([xs, xb], dim=1))
        return out

음... 설명하다보니 일일히 다보고 있고 이걸 다 설명하려니까 조금은 그렇지만.. 핵심만 짚어볼게요. 다른 부분들은 크지 않고 핵심은 prev_chs = exp_chs // 2 이부분이에요. 이 부분이 channel을 반으로 쪼개 cross하는 부분의 핵심이고 그 뒤에 self.conv_transition_b, self.conv_transition을 정의한다는 것이 init 부분의 기본이고 forward 함수가 중요하겠네요.

x 가 들어가고 반으로 나뉘어진 다음 blocks를 걸친 절반의 b가 trainsition을 한번 거치고난뒤에 concatenate 되고 다시한번 transition을 통과하게 되네요. 해당 부분이 논문에 나왔던 그림 중 아래의 그림을 구현했다고 보면 되겠네요.

다시 돌아가서 이렇게 module을 만들고 prev_chs와 curr_stride를 업데이트 한 이후에 해당 info를 추가해주고 out_channels와 out_strides를 업데이트하네요.

그 다음 아래에서는 nn.Conv2d와, nn,BatchNorm2d, nn.Linear를 초기화 해주는 과정이고, 마지막으로 per_stageargs를 통해 out_features_names를 넣어주고 _out_feature_strides와 _out_feature_channels를 정의해주면서 초기화는 끝납니다.

 

init하나 쓰는데 이렇게 오래 걸리다니..ㅎㅎ 여튼 여기까지가 첫번째 class가 call 되면 init해주는 부분에 대한 설명이었습니다.

728x90
반응형