MobileNet V2原理解析与Pytorch代码实现

Table of Contents

正文

1.原理解析

Linear Bottleneck

Inverted Residual Blocks

2.网络结构

3.代码实现

3.1 BottleNeck

3.2 整体结构


论文名:MobileNetV2: Inverted Residuals and Linear Bottlenecks

下载地址:https://arxiv.org/pdf/1801.04381.pdf

正文

MobileNet V2是为了解决MobileNet V1在训练中特征退化的问题而提出的,为了解决这个问题,论文提出了线性瓶颈的倒残差结构the inverted residual with linear bottleneck,它主要包含linear bottleneck和Inverted residuals。

1.原理解析

  MobileNet V2的设计思想都包含在倒残差结构中,它主要包含两个内容linear bottleneckInverted residuals,后面分别对它们进行讲解.

MobileNetV1为什么会出现特征退化呢?论文作者认为这是ReLU导致的,并给出了解释.这块我没太读懂,就简单的说一下自己的理解吧.论文中给出下面这个图表达的意思是:输入数据经过conv处理得到输出数据,输出数据的维度越高,经过ReLU的处理后数据的损失才越小.那么现在有两种解决办法:要么提高输出数据的维度,要么去掉ReLU函数采用线性单元.

Linear Bottleneck

由上述分析可知,为了解决特征退化的最直接的方法就是将ReLU函数换成线性函数(其实就是去掉ReLU函数,不做处理),为了保证深度可分离卷积结构(Depthwise Separable Convolutions,简称DS conv)的非线性拟合能力,只将结构中1*1卷积(pointwise conv)后面的ReLU函数换成线性函数,由此得到了Linear Bottleneck.

Inverted Residual Blocks

特征维度越高,经过ReLU函数处理后信息丢失越少,因此在DS conv结构之前增加一个1*1卷积将feature map的维度提升,从而保证在经过ReLU处理后信息丢失的更少;同时借鉴ResNet的跨层连接shortcut的思想,提高特征的利用率。ResNet的操作是先降维再升维,而本文是先升维再将维,因此叫做倒残差模型Inverted Residual Blocks,其结构图如下所示:

bottleneck

将Linear Bottleneck和Inverted Residual Blocks结合到一起,得到MobileNet V2的结构基元bottleneck,其结构图如下所示(方块的高度即代表通道数):

  当stride=2时,输入输出的维度不同,不进行short cut连接操作。

2.网络结构

  其中t表示扩增的倍数,c表示输出特征图的channel,n表示层的重复次数,s表示stride。

需要注意的是:

1) 当n>1时(即该瓶颈层重复的次数>1),只在第一个瓶颈层stride为对应的s,其他重复的瓶颈层stride均为1

2) 当n>1时,只在第一个瓶颈层特征维度为c,其他时候channel不变。

3.代码实现

3.1 BottleNeck

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#Linear +Inverted Residuals=BottleNeck
class BottleNeck(nn.Modul):
    def __init__(self,in_channels,out_channels,stride,t):
        self.residual=nn.Sequential(

            nn.Conv2d(in_channels,in_channels*t,1,),
            nn.BatchNorm2d(in_channels*t),
            nn.ReLU6(inplace=True),

            nn.Conv2d(in_channels*t,in_channels*t,3,stride=stride,padding=1,groups=in_channels*t),
            nn.BatchNorm2d(in_channels*t),
            nn.ReLU6(inplace=True),

            nn.Conv2d(in_channels*t, out_channels , 1, ),
            nn.BatchNorm2d(out_channels),
        )
        self.in_channels=in_channels
        self.out_channels=out_channels
        self.stride=stride

    def forward(self,x):
        residual=self.residual(x)

        if self.stride==1 and self.in_channels==self.out_channels:
            residual+=x

        return residual

3.2 整体结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class MobileNetV2(nn.Modul):
    def __init__(self,class_num=settings.NUM_CLASSES):
        super().__init__()

        self.conv1=nn.Sequential(
            nn.Conv2d(3,32,3,padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )

        self.bottleneck1 = self.make_layer(1, 32, 16, 1, 1)
        self.bottleneck2 = self.make_layer(2, 16, 24, 2, 6)
        self.bottleneck3 = self.make_layer(3, 24, 32, 2, 6)
        self.bottleneck4 = self.make_layer(4, 32, 64, 2, 6)
        self.bottleneck5 = self.make_layer(3, 64, 96, 1, 6)
        self.bottleneck6 = self.make_layer(3, 96, 160, 2, 6)
        self.bottleneck7 = self.make_layer(1, 160, 320, 1, 6)

        self.conv2=nn.Sequential(
            nn.Conv2d(320,1280,1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True)
        )

        self.conv3=nn.Conv2d(1280,class_num)

    def make_layer(self,repeat,in_channels,out_channels,stride,t):
        layers=[]

        layers.append(BottleNeck(in_channels,out_channels,stride,t))

        while repeat-1:
            layers.append(BottleNeck(in_channels, out_channels, stride, t))
            repeat-=1

    def forward(self,x):
        x=self.conv1(x)
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        x = self.bottleneck4(x)
        x = self.bottleneck5(x)
        x = self.bottleneck6(x)
        x = self.bottleneck7(x)
        x = self.conv2(x)
        x=F.adaptive_avg_pool2d(x,1)
        x=self.conv3(x)
        x=x.view(x,size(0),-1)

        return x