Implementing DeepLab V3+ in PyTorch, with Pretrained Parameters Converted Directly from TensorFlow

This article implements the DeepLab V3+ semantic segmentation model in PyTorch. To use existing pretrained parameters without spending time on training, we convert them from the official TensorFlow implementation of DeepLab V3+. This article is heavily influenced by that official implementation and can largely be read as its PyTorch translation.

The feature extractor (i.e., the backbone) used by this DeepLab V3+ model is Xception (see the previous article, "Implementing Xception in PyTorch, with Pretrained Parameters Converted Directly from TensorFlow"; it is slightly modified here to record intermediate features in end_points, so that the DeepLab V3+ decoder can access them).

All code is on GitHub: deeplabv3_plus. PyTorch 1.1.0 or later is required, because training uses PyTorch's built-in tensorboard interface (torch.utils.tensorboard).

1. Implementation Details

The principles behind DeepLab V3+ will be covered in detail in a later article, so they are skipped here. The implementation is as follows:

# -*- coding: utf-8 -*-
"""
Created on Mon Sep 16 14:24:03 2019

@author: shirhe-lyh


Implementation of DeepLabV3+:
    Encoder-Decoder with atrous separable convolution for semantic image
    segmentation, Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian
    Schroff, Hartwig Adam, arxiv:1802.02611 (https://arxiv.org/abs/1802.02611).
   
Official implementation:
    https://github.com/tensorflow/models/tree/master/research/deeplab
"""

import os
import torch

import common
import core.feature_extractor as extractor


_BATCH_NORM_PARAMS = {
    # Note: TF's batch-norm decay of 0.9997 corresponds to a PyTorch momentum
    # of 1 - 0.9997 = 0.0003 (the two frameworks use opposite conventions).
    'momentum': 0.0003,
    'eps': 1e-5,
    'affine': True,
}


class DeepLab(torch.nn.Module):
    """Implementation of DeepLab V3+."""
   
    def __init__(self, feature_extractor, model_options):
        """Constructor.
       
        Args:
            feature_extractor: The backbone of DeepLab model.
            model_options: A ModelOptions instance to configure models.
        """
        super(DeepLab, self).__init__()
        self._model_options = model_options
       
        # Feature extractor
        self._feature_extractor = feature_extractor

        # Atrous spatial pyramid pooling
        self._aspp = AtrousSpatialPyramidPooling(
            in_channels=feature_extractor.out_channels,
            crop_size=model_options.crop_size,
            output_stride=model_options.output_stride,
            atrous_rates=model_options.atrous_rates,
            use_bounded_activation=model_options.use_bounded_activation,
            add_image_level_feature=model_options.add_image_level_feature,
            image_pooling_stride=model_options.image_pooling_stride,
            image_pooling_crop_size=model_options.image_pooling_crop_size,
            aspp_with_separable_conv=model_options.aspp_with_separable_conv)
       
        # Refine by decoder
        self._refine_decoder = None
        if model_options.decoder_output_stride:
            self._refine_decoder = RefineDecoder(
                feature_extractor=self._feature_extractor,
                crop_size=model_options.crop_size,
                decoder_output_stride=model_options.decoder_output_stride[0],
                decoder_use_separable_conv=model_options.decoder_use_separable_conv,
                model_variant=model_options.model_variant,
                use_bounded_activation=model_options.use_bounded_activation)
           
        # Branch logits
        num_classes = model_options.outputs_to_num_classes[common.OUTPUT_TYPE]
        self._logits_layer = torch.nn.Conv2d(
            in_channels=256, out_channels=num_classes,
            kernel_size=model_options.logits_kernel_size)
       
    def forward(self, x):
        features = self._feature_extractor(x)
        features = self._aspp(features)
        if self._refine_decoder is not None:
            features = self._refine_decoder(features)
        logits = self._logits_layer(features)
        return logits
       
       
class AtrousSpatialPyramidPooling(torch.nn.Module):
    """Atrous Spatial Pyramid Pooling."""
   
    def __init__(self,
                 in_channels,
                 out_channels=256,
                 output_stride=16,
                 crop_size=[513, 513],
                 atrous_rates=[12, 24, 36],
                 use_bounded_activation=False,
                 add_image_level_feature=True,
                 image_pooling_stride=[1, 1],
                 image_pooling_crop_size=None,
                 aspp_with_separable_conv=True):
        """Constructor.
       
        Args:
            in_channels: Number of input filters.
            out_channels: Number of filters in the 1x1 pointwise convolution.
            atrous_rates: A list of atrous convolution rates for ASPP.
            use_bounded_activation: Whether or not to use bounded activations.
            crop_size: A tuple [crop_height, crop_width].
            image_pooling_crop_size: Image pooling crop size [height, width]
                used in the ASPP module.
        """
        super(AtrousSpatialPyramidPooling, self).__init__()
        activation_fn = (
            torch.nn.ReLU6(inplace=False) if use_bounded_activation else
            torch.nn.ReLU(inplace=False))
       
        depth = out_channels
        branches = []
        if add_image_level_feature:
            layers = []
            if crop_size is not None:
                # If image_pooling_crop_size is not specified, use crop_size
                if image_pooling_crop_size is None:
                    image_pooling_crop_size = crop_size
                pool_height = scale_dimension(image_pooling_crop_size[0],
                                              1. / output_stride)
                pool_width = scale_dimension(image_pooling_crop_size[1],
                                             1. / output_stride)
                layers += [torch.nn.AvgPool2d(
                    (pool_height, pool_width), image_pooling_stride)]
                resize_height = scale_dimension(
                    crop_size[0], 1. / output_stride)
                resize_width = scale_dimension(
                    crop_size[1], 1. / output_stride)
            else:
                # If crop_size is None, we simply do global pooling
                layers += [torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))]
                resize_height, resize_width = None, None
            self._resize_height = resize_height
            self._resize_width = resize_width
            layers += [torch.nn.Conv2d(in_channels, depth, 1, bias=False),
                       torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
                       activation_fn]
            branches.append(torch.nn.Sequential(*layers))
       
        # Employ a 1x1 convolution.
        branches.append(torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, depth, kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
            activation_fn))
       
        if atrous_rates:
            # Employ 3x3 convolutions with different atrous rates.
            for i, rate in enumerate(atrous_rates):
                layers = []
                if aspp_with_separable_conv:
                    layers += [
                        SplitSeparableConv2d(in_channels, depth, rate=rate,
                                             activation_fn=activation_fn)]
                else:
                    layers += [
                        torch.nn.Conv2d(in_channels, depth, kernel_size=3,
                                        dilation=rate, padding=rate,
                                        bias=False),
                        torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
                        activation_fn]
                branches.append(torch.nn.Sequential(*layers))
        self._branches = torch.nn.Sequential(*branches)
       
        # Merge branch logits
        self._conv_concat = torch.nn.Sequential(
            torch.nn.Conv2d(len(self._branches) * depth,
                            depth, kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
            activation_fn)
        # TF's slim.dropout uses keep_prob=0.9; PyTorch's Dropout2d takes the
        # drop probability, hence p = 1 - 0.9 = 0.1.
        self._dropout = torch.nn.Dropout2d(p=0.1, inplace=True)
   
    def forward(self, x):
        branch_logits = []
        conv_branches = self._branches
        if len(self._branches) > 4:
            conv_branches = self._branches[1:]
            image_feature = self._branches[0](x)
            resize_height = self._resize_height
            resize_width = self._resize_width
            if resize_height is None:
                # Global pooling case: resize back to the input feature size
                # (kept in local variables so the module stays stateless).
                _, _, resize_height, resize_width = x.size()
            image_feature = resize_bilinear(
                image_feature, (resize_height, resize_width))
            branch_logits.append(image_feature)
        branch_logits += [branch(x) for branch in conv_branches]
        concat_logits = torch.cat(branch_logits, dim=1)
        concat_logits = self._conv_concat(concat_logits)
        concat_logits = self._dropout(concat_logits)
        return concat_logits
   
   
class SplitSeparableConv2d(torch.nn.Module):
    """Splits a seperable conv2d into depthwise and pointwise conv2d."""
   
    def __init__(self, in_channels, out_channels, kernel_size=3, rate=1,
                 use_batch_norm=True, activation_fn=None):
        """Constructor.
       
        Args:
            in_channels: Number of input filters.
            out_channels: Number of filters in the 1x1 pointwise convolution.
            kernel_size: A list of length 2: [kernel_height, kernel_width]
                of the filters. Can be an int if both values are the same.
            rate: Atrous convolution rate for the depthwise convolution.
            use_batch_norm: Whether or not to use batch normalization.
            activation_fn: The activation function to be applied.
        """
        super(SplitSeparableConv2d, self).__init__()
        # For the shape of output of Conv2d, see details at:
        # https://pytorch.org/docs/stable/nn.html#convolution-layers
        # Here, we assume that floor(padding) = padding
        padding = (kernel_size - 1) * rate // 2
        self._conv_depthwise = torch.nn.Conv2d(in_channels,
                                               in_channels,
                                               kernel_size=kernel_size,
                                               dilation=rate,
                                               padding=padding,
                                               groups=in_channels,
                                               bias=False)
        self._conv_pointwise = torch.nn.Conv2d(in_channels,
                                               out_channels,
                                               kernel_size=1,
                                               bias=False)
        self._batch_norm_depthwise = None
        self._batch_norm_pointwise = None
        if use_batch_norm:
            self._batch_norm_depthwise = torch.nn.BatchNorm2d(
                num_features=in_channels, **_BATCH_NORM_PARAMS)
            self._batch_norm_pointwise = torch.nn.BatchNorm2d(
                num_features=out_channels, **_BATCH_NORM_PARAMS)
        self._activation_fn = activation_fn
       
    def forward(self, x):
        x = self._conv_depthwise(x)
        if self._batch_norm_depthwise is not None:
            x = self._batch_norm_depthwise(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        x = self._conv_pointwise(x)
        if self._batch_norm_pointwise is not None:
            x = self._batch_norm_pointwise(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x
   
   
class RefineDecoder(torch.nn.Module):
    """Adds the decoder to obtain sharper segmentation results."""
   
    def __init__(self, feature_extractor, aspp_channels=256, crop_size=None,
                 decoder_output_stride=None, decoder_use_separable_conv=False,
                 model_variant=None, use_bounded_activation=False):
        """Constructor.
       
        Args:
            feature_extractor: The backbone of the DeepLab model.
            aspp_channels: The out channels of ASPP.
            crop_size: A tuple [crop_height, crop_width] specifying whole
                patch crop size.
            decoder_output_stride: An integer specifying the output stride of
                low-level features used in the decoder module.
            decoder_use_separable_conv: Employ separable convolution for
                decoder or not.
            model_variant: Model variant for feature extractor.
            use_bounded_activation: Whether or not to use bounded activations.
                Bounded activations better lend themselves to quantized
                inference.
           
        Raises:
            ValueError: If crop_size is None.
        """
        super(RefineDecoder, self).__init__()
       
        if crop_size is None:
            raise ValueError('crop_size must be provided when using decoder.')
           
        self._crop_size = crop_size
        self._output_stride = decoder_output_stride
        activation_fn = (
            torch.nn.ReLU6(inplace=False) if use_bounded_activation else
            torch.nn.ReLU(inplace=False))
        self._extractor = feature_extractor
        feature_names = extractor.networks_to_feature_maps[
            model_variant][extractor.DECODER_END_POINTS][decoder_output_stride]
        self._feature_name = '{}/{}'.format(model_variant, feature_names[0])
        self._decoder = torch.nn.Sequential(
            torch.nn.Conv2d(extractor.feature_out_channels_map[model_variant],
                            out_channels=48, kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(num_features=48, **_BATCH_NORM_PARAMS),
            activation_fn)
        concat_layers = []
        decoder_depth = 256
        if decoder_use_separable_conv:
            concat_layers += [
                SplitSeparableConv2d(in_channels=aspp_channels + 48,
                                     out_channels=decoder_depth, rate=1,
                                     activation_fn=activation_fn),
                SplitSeparableConv2d(in_channels=decoder_depth,
                                     out_channels=decoder_depth,
                                     rate=1, activation_fn=activation_fn)]
        else:
            concat_layers += [
                torch.nn.Conv2d(in_channels=aspp_channels + 48,
                                out_channels=decoder_depth, kernel_size=3,
                                padding=1, bias=False),
                torch.nn.BatchNorm2d(num_features=decoder_depth,
                                     **_BATCH_NORM_PARAMS),
                activation_fn,
                torch.nn.Conv2d(in_channels=decoder_depth,
                                out_channels=decoder_depth,
                                kernel_size=3, padding=1, bias=False),
                torch.nn.BatchNorm2d(num_features=decoder_depth,
                                     **_BATCH_NORM_PARAMS),
                activation_fn]
        self._concat_layers = torch.nn.Sequential(*concat_layers)
       
    def forward(self, x):
        decoder_features_list = [x]
        inter_feature = self._extractor.end_points()[self._feature_name]
        decoder_features_list.append(self._decoder(inter_feature))
        # Determine the output size
        decoder_height = scale_dimension(self._crop_size[0],
                                         1.0 / self._output_stride)
        decoder_width = scale_dimension(self._crop_size[1],
                                        1.0 / self._output_stride)
        # Resize to decoder_height/decoder_width
        for j, feature in enumerate(decoder_features_list):
            decoder_features_list[j] = resize_bilinear(
                feature, (decoder_height, decoder_width))
        x = self._concat_layers(torch.cat(decoder_features_list, dim=1))
        return x
   

def scale_dimension(dim, scale):
    """Scales the input dimension.
   
    Args:
        dim: Input dimension (a scalar).
        scale: The amount of scaling applied to the input.
       
    Returns:
        scaled dimension.
    """
    return int((float(dim) - 1.0) * scale + 1.0)


def resize_bilinear(images, size):
    """Returns resized images.
   
    Args:
        images: A tensor of size [batch, channels, height_in, width_in].
        size: A tuple (height, width).
        
    Returns:
        A tensor of shape [batch, channels, height_out, width_out].
    """
    return torch.nn.functional.interpolate(
        images, size, mode='bilinear', align_corners=True)
   
   
def deeplab(num_classes, crop_size=[513, 513], atrous_rates=[12, 24, 36],
            output_stride=16, pretrained=True, pretrained_num_classes=21,
            checkpoint_path='./pretrained_models/deeplabv3_pascal_trainval.pth'):
    """DeepLab v3+ for semantic segmentation."""
    outputs_to_num_classes = {'semantic': num_classes}
    model_options = common.ModelOptions(outputs_to_num_classes,
                                        crop_size=crop_size,
                                        atrous_rates=atrous_rates,
                                        output_stride=output_stride)
    feature_extractor = extractor.feature_extractor(
        model_options.model_variant, pretrained=False,
        output_stride=model_options.output_stride)
    model = DeepLab(feature_extractor, model_options)
    
    if pretrained:
        _load_state_dict(model, num_classes, pretrained_num_classes,
                         checkpoint_path)
    return model


def _load_state_dict(model, num_classes, pretrained_num_classes,
                     checkpoint_path):
    """Load pretrained weights."""
    if os.path.exists(checkpoint_path):
        state_dict = torch.load(checkpoint_path)
        if num_classes is None or num_classes != pretrained_num_classes:
            # The logits layer's shape depends on num_classes, so discard it
            # whenever the class count differs from the pretrained checkpoint.
            state_dict.pop('_logits_layer.weight')
            state_dict.pop('_logits_layer.bias')
        model.load_state_dict(state_dict, strict=False)
        print('Load pretrained weights successfully.')
    else:
        raise ValueError('`checkpoint_path` does not exist.')
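Two details in the listing above are worth checking by hand. First, `scale_dimension` uses the TF-style (align_corners) convention `(dim - 1) * scale + 1`; second, `SplitSeparableConv2d` pads with `(kernel_size - 1) * rate // 2`, which preserves the spatial size for odd kernels at any atrous rate. A minimal, self-contained sketch (plain Python, sizes only; `conv_out_size` is a hypothetical helper re-deriving PyTorch's Conv2d output formula for stride 1):

```python
def scale_dimension(dim, scale):
    # Same helper as in the listing: TF-style (align_corners) scaling.
    return int((float(dim) - 1.0) * scale + 1.0)

# With the default crop_size 513 and output_stride 16, ASPP operates on
# 33 x 33 feature maps, which is what the image-pooling branch resizes to.
print(scale_dimension(513, 1.0 / 16))  # 33

def conv_out_size(size, kernel_size, rate, padding):
    # PyTorch Conv2d output size for stride 1:
    # floor(size + 2 * padding - dilation * (kernel_size - 1) - 1) + 1
    return size + 2 * padding - rate * (kernel_size - 1) - 1 + 1

# padding = (kernel_size - 1) * rate // 2 keeps the size unchanged for
# every atrous rate used by ASPP.
for rate in (1, 6, 12, 24, 36):
    padding = (3 - 1) * rate // 2
    assert conv_out_size(33, 3, rate, padding) == 33
```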

2. Converting the Pretrained Parameters

To convert the official TensorFlow pretrained parameters to PyTorch, we assign them entirely by hand (a consequence of PyTorch not letting you name variables freely, the way TensorFlow's variable scopes do): for each PyTorch parameter, we look up the corresponding value in the TensorFlow checkpoint and copy it over. The details are as follows:
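Before the full script, the weight-layout conversions it relies on can be sketched with shapes alone (plain Python; the channel counts below are illustrative only): TF stores standard conv weights as [H, W, in, out] while PyTorch expects [out, in, H, W], and TF depthwise weights [H, W, in, multiplier] map to PyTorch's [in * multiplier, 1, H, W] for a grouped convolution.

```python
def transpose_shape(shape, perm):
    # Shape-level analogue of np.transpose(array, perm).
    return tuple(shape[p] for p in perm)

# Standard conv: TF [H, W, in, out] -> PyTorch [out, in, H, W].
tf_conv = (3, 3, 728, 256)                     # illustrative channel counts
print(transpose_shape(tf_conv, (3, 2, 0, 1)))  # (256, 728, 3, 3)

# Depthwise conv: TF [H, W, in, multiplier] -> PyTorch [in, 1, H, W]
# (for channel_multiplier = 1); the script does this as two transposes.
tf_depthwise = (3, 3, 728, 1)
step1 = transpose_shape(tf_depthwise, (3, 2, 0, 1))  # (1, 728, 3, 3)
print(transpose_shape(step1, (1, 0, 2, 3)))          # (728, 1, 3, 3)
```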

# -*- coding: utf-8 -*-
"""
Created on Wed Oct  9 17:46:13 2019

@author: shirhe-lyh


Convert tensorflow weights to pytorch weights for DeepLab V3+ models.

Reference:
    https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/
        tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py
"""

import numpy as np
import tensorflow as tf
import torch


_BLOCK_UNIT_COUNT_MAP = {
    'xception_41': [[3, 1], [1, 8], [2, 1]],
    'xception_65': [[3, 1], [1, 16], [2, 1]],
    'xception_71': [[5, 1], [1, 16], [2, 1]],
}


def load_param(checkpoint_path, conversion_map, model_name):
    """Load parameters according to conversion_map.
   
    Args:
        checkpoint_path: Path to tensorflow's checkpoint file.
        conversion_map: A dictionary with format
            {pytorch tensor in a model: checkpoint variable name}
        model_name: The name of Xception model, only supports 'xception_41',
            'xception_65', or 'xception_71'.
    """
    for pth_param, tf_param_name in conversion_map.items():
        param_name_strs = tf_param_name.split('_')
        if len(param_name_strs) > 1 and param_name_strs[1].startswith('flow'):
            tf_param_name = str(model_name) + '/' + tf_param_name
        tf_param = tf.train.load_variable(checkpoint_path, tf_param_name)
        if 'conv' in tf_param_name and 'weights' in tf_param_name:
            tf_param = np.transpose(tf_param, (3, 2, 0, 1))
            if 'depthwise' in tf_param_name:
                tf_param = np.transpose(tf_param, (1, 0, 2, 3))
        elif 'depthwise_weights' in tf_param_name:
            tf_param = np.transpose(tf_param, (3, 2, 0, 1))
            tf_param = np.transpose(tf_param, (1, 0, 2, 3))
        elif tf_param_name.endswith('weights'):
            tf_param = np.transpose(tf_param)
        assert pth_param.size() == tf_param.shape, ('Dimension mismatch: ' +
            '{} vs {}; {}'.format(pth_param.size(), tf_param.shape,
                 tf_param_name))
        pth_param.data = torch.from_numpy(tf_param)


def convert(model, checkpoint_path):
    """Loads a TensorFlow checkpoint into a PyTorch DeepLab model.
    
    Args:
        model: The PyTorch DeepLab model; its backbone must be one of
            'xception_41', 'xception_65', or 'xception_71'.
        checkpoint_path: Path to tensorflow's checkpoint file.
        
    Returns:
        The PyTorch DeepLab model with pretrained parameters.
    """
    block_unit_counts = _BLOCK_UNIT_COUNT_MAP.get(
        model._feature_extractor.scope, None)
    if block_unit_counts is None:
        raise ValueError('Unsupported Xception model name.')
    flow_names = []
    block_indices = []
    unit_indices = []
    flow_names_unique = ['entry_flow', 'middle_flow', 'exit_flow']
    for i, [block_count, unit_count] in enumerate(block_unit_counts):
        flow_names += [flow_names_unique[i]] * (block_count * unit_count)
        # Use a distinct name for the inner index so it does not shadow `i`.
        for block_index in range(block_count):
            block_indices += [block_index + 1] * unit_count
            unit_indices += [j + 1 for j in range(unit_count)]
   
    conversion_map = {}
    # Feature extractor: Root block
    conversion_map_for_root_block = {
        model._feature_extractor._layers[0]._conv.weight:
            'entry_flow/conv1_1/weights',
        model._feature_extractor._layers[0]._batch_norm.bias:
            'entry_flow/conv1_1/BatchNorm/beta',
        model._feature_extractor._layers[0]._batch_norm.weight:
            'entry_flow/conv1_1/BatchNorm/gamma',
        model._feature_extractor._layers[0]._batch_norm.running_mean:
            'entry_flow/conv1_1/BatchNorm/moving_mean',
        model._feature_extractor._layers[0]._batch_norm.running_var:
            'entry_flow/conv1_1/BatchNorm/moving_variance',
        model._feature_extractor._layers[1]._conv.weight:
            'entry_flow/conv1_2/weights',
        model._feature_extractor._layers[1]._batch_norm.bias:
            'entry_flow/conv1_2/BatchNorm/beta',
        model._feature_extractor._layers[1]._batch_norm.weight:
            'entry_flow/conv1_2/BatchNorm/gamma',
        model._feature_extractor._layers[1]._batch_norm.running_mean:
            'entry_flow/conv1_2/BatchNorm/moving_mean',
        model._feature_extractor._layers[1]._batch_norm.running_var:
            'entry_flow/conv1_2/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_root_block)
   
    # Feature extractor: Xception block
    for i in range(len(model._feature_extractor._layers[2]._blocks)):
        block = model._feature_extractor._layers[2]._blocks[i]
        ind = [1, 3, 5]
        if len(block._separable_conv_block) < 6:
            ind = [0, 1, 2]
        for j in range(3):
            conversion_map_for_separable_block = {
                block._separable_conv_block[ind[j]]._conv_depthwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/depthwise_weights').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._conv_pointwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/weights').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_depthwise.bias:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/beta').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_depthwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/gamma').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_depthwise.running_mean:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/moving_mean').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_depthwise.running_var:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/moving_variance').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_pointwise.bias:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/beta').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_pointwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/gamma').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_pointwise.running_mean:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/moving_mean').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
                block._separable_conv_block[ind[j]]._batch_norm_pointwise.running_var:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/moving_variance').format(
                        flow_names[i], block_indices[i], unit_indices[i], j+1),
            }
            conversion_map.update(conversion_map_for_separable_block)
           
            if getattr(block, '_conv_skip_connection', None) is not None:
                conversion_map_for_shortcut = {
                    block._conv_skip_connection.weight:
                       ('{}/block{}/unit_{}/xception_module/shortcut/' +
                        'weights').format(
                            flow_names[i], block_indices[i], unit_indices[i]),
                    block._batch_norm_shortcut.bias:
                        ('{}/block{}/unit_{}/xception_module/shortcut/' +
                         'BatchNorm/beta').format(
                            flow_names[i], block_indices[i], unit_indices[i]),
                    block._batch_norm_shortcut.weight:
                        ('{}/block{}/unit_{}/xception_module/shortcut/' +
                         'BatchNorm/gamma').format(
                            flow_names[i], block_indices[i], unit_indices[i]),
                    block._batch_norm_shortcut.running_mean:
                        ('{}/block{}/unit_{}/xception_module/shortcut/' +
                         'BatchNorm/moving_mean').format(
                            flow_names[i], block_indices[i], unit_indices[i]),
                    block._batch_norm_shortcut.running_var:
                        ('{}/block{}/unit_{}/xception_module/shortcut/' +
                         'BatchNorm/moving_variance').format(
                            flow_names[i], block_indices[i], unit_indices[i]),
                }
                conversion_map.update(conversion_map_for_shortcut)
       
    # Atrous Spatial Pyramid Pooling: Image feature
    branches = model._aspp._branches
    conversion_map_for_aspp_image_feature = {
        branches[0][1].weight: 'image_pooling/weights',
        branches[0][2].bias: 'image_pooling/BatchNorm/beta',
        branches[0][2].weight: 'image_pooling/BatchNorm/gamma',
        branches[0][2].running_mean: 'image_pooling/BatchNorm/moving_mean',
        branches[0][2].running_var: 'image_pooling/BatchNorm/moving_variance',
        branches[1][0].weight: 'aspp0/weights',
        branches[1][1].bias: 'aspp0/BatchNorm/beta',
        branches[1][1].weight: 'aspp0/BatchNorm/gamma',
        branches[1][1].running_mean: 'aspp0/BatchNorm/moving_mean',
        branches[1][1].running_var: 'aspp0/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_aspp_image_feature)
   
    # Atrous Spatial Pyramid Pooling: Atrous convolution
    for i in range(3):
        branch = branches[i+2][0]
        conversion_map_for_atrous_conv = {
            branch._conv_depthwise.weight:
                'aspp{}_depthwise/depthwise_weights'.format(i+1),
            branch._conv_pointwise.weight:
                'aspp{}_pointwise/weights'.format(i+1),
            branch._batch_norm_depthwise.bias:
                'aspp{}_depthwise/BatchNorm/beta'.format(i+1),
            branch._batch_norm_depthwise.weight:
                'aspp{}_depthwise/BatchNorm/gamma'.format(i+1),
            branch._batch_norm_depthwise.running_mean:
                'aspp{}_depthwise/BatchNorm/moving_mean'.format(i+1),
            branch._batch_norm_depthwise.running_var:
                'aspp{}_depthwise/BatchNorm/moving_variance'.format(i+1),
            branch._batch_norm_pointwise.bias:
                'aspp{}_pointwise/BatchNorm/beta'.format(i+1),
            branch._batch_norm_pointwise.weight:
                'aspp{}_pointwise/BatchNorm/gamma'.format(i+1),
            branch._batch_norm_pointwise.running_mean:
                'aspp{}_pointwise/BatchNorm/moving_mean'.format(i+1),
            branch._batch_norm_pointwise.running_var:
                'aspp{}_pointwise/BatchNorm/moving_variance'.format(i+1),
        }
        conversion_map.update(conversion_map_for_atrous_conv)
       
    # Atrous Spatial Pyramid Pooling: Concat projection
    conversion_map_for_concat_projection = {
        model._aspp._conv_concat[0].weight:
            'concat_projection/weights',
        model._aspp._conv_concat[1].bias:
            'concat_projection/BatchNorm/beta',
        model._aspp._conv_concat[1].weight:
            'concat_projection/BatchNorm/gamma',
        model._aspp._conv_concat[1].running_mean:
            'concat_projection/BatchNorm/moving_mean',
        model._aspp._conv_concat[1].running_var:
            'concat_projection/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_concat_projection)
   
    # Refine decoder: Feature projection
    conversion_map_for_decoder = {
        model._refine_decoder._decoder[0].weight:
            'decoder/feature_projection0/weights',
        model._refine_decoder._decoder[1].bias:
            'decoder/feature_projection0/BatchNorm/beta',
        model._refine_decoder._decoder[1].weight:
            'decoder/feature_projection0/BatchNorm/gamma',
        model._refine_decoder._decoder[1].running_mean:
            'decoder/feature_projection0/BatchNorm/moving_mean',
        model._refine_decoder._decoder[1].running_var:
            'decoder/feature_projection0/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_decoder)
   
    # Refine decoder: Concat
    layers = model._refine_decoder._concat_layers
    for i in range(2):
        layer = layers[i]
        conversion_map_decoder = {
            layer._conv_depthwise.weight:
                'decoder/decoder_conv{}_depthwise/depthwise_weights'.format(i),
            layer._conv_pointwise.weight:
                'decoder/decoder_conv{}_pointwise/weights'.format(i),
            layer._batch_norm_depthwise.bias:
                'decoder/decoder_conv{}_depthwise/BatchNorm/beta'.format(i),
            layer._batch_norm_depthwise.weight:
                'decoder/decoder_conv{}_depthwise/BatchNorm/gamma'.format(i),
            layer._batch_norm_depthwise.running_mean:
                'decoder/decoder_conv{}_depthwise/BatchNorm/moving_mean'.format(i),
            layer._batch_norm_depthwise.running_var:
                'decoder/decoder_conv{}_depthwise/BatchNorm/moving_variance'.format(i),
            layer._batch_norm_pointwise.bias:
                'decoder/decoder_conv{}_pointwise/BatchNorm/beta'.format(i),
            layer._batch_norm_pointwise.weight:
                'decoder/decoder_conv{}_pointwise/BatchNorm/gamma'.format(i),
            layer._batch_norm_pointwise.running_mean:
                'decoder/decoder_conv{}_pointwise/BatchNorm/moving_mean'.format(i),
            layer._batch_norm_pointwise.running_var:
                'decoder/decoder_conv{}_pointwise/BatchNorm/moving_variance'.format(i),
        }
        conversion_map.update(conversion_map_decoder)
       
    # Prediction logits
    conversion_map_for_logits = {
        model._logits_layer.weight: 'logits/semantic/weights',
        model._logits_layer.bias: 'logits/semantic/biases',
    }
    conversion_map.update(conversion_map_for_logits)
       
    # Load TensorFlow parameters into PyTorch model
    load_param(checkpoint_path, conversion_map, model._feature_extractor.scope)

As an example, let us convert xception65_coco_voc_trainval (model parameters using Xception65 as the feature extractor, pretrained on the COCO dataset and PASCAL VOC). The official TensorFlow pretrained parameters for this model can be downloaded here. After downloading and extracting, run

python3 tf_weights_to_pth.py --tf_checkpoint_path "xxxx/deeplabv3_pascal_trainval/model.ckpt"

where tf_checkpoint_path is the full path to the TensorFlow pretrained checkpoint you just downloaded. When execution finishes, a pretrained_models folder is created under the current project directory, containing the converted pretrained parameters in a file named deeplabv3_pascal_trainval.pth. Then run

python3 model_test.py

which generates semantic segmentation results for the test images in the test folder, as shown below:

COCO_val2014_000000000294.jpg

COCO_val2014_000000000294_seg.png

Running the following two commands respectively,

python3 model_test.py --image_path "./test/tf_vis2.jpg"
python3 model_test.py --image_path "./test/tf_vis2.jpg" --val_output_stride 16 \
    --val_atrous_rates 6 12 18

produces the results:


tf_vis2.jpg

tf_vis2_seg.png(output_stride=8)

tf_vis2_seg.png(output_stride=16)

If you want to convert other pretrained checkpoints (only models with an Xception feature extractor are supported), you need to specify the following four arguments in full:

python3 tf_weights_to_pth.py --tf_checkpoint_path "xxxx/model.ckpt" \
    --output_dir "xxx" --output_name "xxx.pth"  --pretained_num_classes xxx

These specify, in order: the path to the TensorFlow pretrained checkpoint, the folder in which to save the converted parameters, the file name to save them under, and the number of target segmentation classes of the model being converted. For their default values, see the Arguments for tf_weights_to_pth.py section of the common.py file. All conversions above use the default feature extractor, Xception65; if the pretrained model you downloaded uses Xception71 as its feature extractor, also pass the argument

--model_variant "xception_71"

3. Usage

Calling the DeepLab V3+ model is straightforward (see model_test.py for a usage example). The class instantiation prototype is:

import model

deeplab = model.deeplab(num_classes,
                        crop_size=[513, 513],
                        atrous_rates=[12, 24, 36],
                        output_stride=8,
                        pretrained=True,
                        pretained_num_classes=21,
                        checkpoint_path='./pretrained_models/deeplabv3_pascal_trainval.pth')

Here, num_classes specifies the number of target segmentation classes, background included (e.g. num_classes = 21 for PASCAL VOC, num_classes = 19 for Cityscapes). crop_size specifies the model's input size. atrous_rates specifies the dilation rates used by the ASPP atrous convolutions: for the xception_65 feature extractor, use atrous_rates = [12, 24, 36] when output_stride = 8, and atrous_rates = [6, 12, 18] when output_stride = 16. output_stride specifies the factor by which the feature extractor reduces spatial resolution; for example, with output_stride = 8, a 513 x 513 input yields a 65 x 65 feature map. pretrained specifies whether to load pretrained parameters (converted directly from the TensorFlow checkpoint); if loading, specify their path via checkpoint_path, or simply keep the default if you used the default conversion. pretained_num_classes specifies the num_classes of the model the pretrained parameters were trained with, e.g. 21 for PASCAL VOC and 19 for Cityscapes.
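The relationship between crop_size, output_stride, and the feature-map resolution can be computed directly. A minimal sketch, assuming the usual DeepLab convention of crop sizes of the form 32k + 1 (the function name is my own, not from the repo):

```python
def feature_size(crop_size: int, output_stride: int) -> int:
    """Spatial size of the feature extractor's output for a
    (crop_size x crop_size) input, assuming crop_size = 32k + 1."""
    return (crop_size - 1) // output_stride + 1

# For a 513 x 513 input:
print(feature_size(513, 8))   # 65
print(feature_size(513, 16))  # 33
```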

[Note] If you set pretrained = True and the model loads successfully, but the pretrained checkpoint's class count (pretained_num_classes) differs from the class count you actually use (num_classes), then the final convolution layer will be randomly initialized.
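The loading behavior this note implies can be sketched as a shape check: copy only the pretrained entries whose shapes still match, so a mismatched logits layer keeps its fresh initialization. A minimal sketch using plain dicts of (shape, value) pairs as stand-ins for real state dicts; all names below are illustrative, not the repo's actual parameter keys:

```python
def merge_matching(model_state, pretrained_state):
    """Both arguments map parameter name -> (shape, value).
    Pretrained values are copied only when the name exists and the
    shape matches; everything else (e.g. a logits layer whose
    num_classes changed) keeps the model's own random initialization."""
    merged = dict(model_state)
    for name, (shape, value) in pretrained_state.items():
        if name in merged and merged[name][0] == shape:
            merged[name] = (shape, value)
    return merged

# num_classes changed from 21 to 2: the logits shapes differ,
# so only the backbone entry is overwritten.
model = {"backbone.w": ((3, 3), "random"), "logits.w": ((2, 256), "random")}
pre   = {"backbone.w": ((3, 3), "coco"),   "logits.w": ((21, 256), "coco")}
out = merge_matching(model, pre)
print(out["backbone.w"][1], out["logits.w"][1])  # coco random
```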

4. A training example

The training data for this section comes from the dataset open-sourced by the company 爱分割 (Aisegment), containing 34,426 images in total with their corresponding alpha channels. All images are upper-body shots of models, for example:

Sample image from the Aisegment open-source dataset

Although the quality of the images' alpha channels is not great, it is more than sufficient for training a semantic segmentation model. In practice, pixels whose alpha value is greater than 50 are set to 1 and all others to 0, yielding the corresponding mask. Since the model (person) is our only target, we only need to distinguish two classes, model and background, where 1 corresponds to the model and 0 to the background (hence num_classes = 2).
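The alpha-to-mask rule above (alpha > 50 becomes 1, everything else 0) can be sketched in pure Python on a nested list standing in for the alpha image; a real pipeline would use NumPy or PIL:

```python
def alpha_to_mask(alpha, threshold=50):
    """Binarize an alpha channel (rows of 0-255 values) into a
    two-class mask: 1 = model (foreground), 0 = background."""
    return [[1 if a > threshold else 0 for a in row] for row in alpha]

alpha = [[0, 30, 200],
         [51, 50, 255]]
print(alpha_to_mask(alpha))  # [[0, 0, 1], [1, 0, 1]]
```

Note the strict comparison: a pixel with alpha exactly 50 maps to background.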

Data preparation

After downloading (and extracting) the Aisegment open-source dataset, we prepare the masks for all images in one pass: open the data/retrieve.py file, change root_dir to the path of your Matting_Human_Half folder, then run retrieve.py and wait for it to generate the alpha and mask images for every picture (inside the Matting_Human_Half folder), along with the train.txt and val.txt files used for training (inside the data folder; by default, 100 images are randomly selected for validation). If the path of the Matting_Human_Half folder does not change before training, no further processing is needed. If at training time the Matting_Human_Half path differs from the root_dir used when train.txt and val.txt were generated, you can use a tool such as Notepad++ to replace root_dir with an empty string, producing the following form:

Annotation file with root_dir removed

train.txt and val.txt record the paths of the training and validation images, respectively. Each line holds 4 paths for one image: the original image path (3 channels), the transparent image path (4 channels), the alpha-channel image path, and the mask path, separated by the @ symbol.
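A line of train.txt can therefore be parsed by splitting on @. A minimal sketch extracting the (image, mask) pairs from such lines (the example paths are made up for illustration):

```python
def parse_annotation_lines(lines):
    """Each line holds 4 '@'-separated paths: original image,
    transparent image, alpha image, mask.
    Return [[image_path, mask_path], ...]."""
    pairs = []
    for line in lines:
        image, _transparent, _alpha, mask = line.strip().split("@")
        pairs.append([image, mask])
    return pairs

lines = ["clip_img/a.jpg@matting/a.png@alpha/a.png@mask/a.png"]
print(parse_annotation_lines(lines))  # [['clip_img/a.jpg', 'mask/a.png']]
```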

Training

Run directly from the command line:

python3 train.py --root_dir "xxx/Matting_Human_Half" [--gpu_indices 0 1 ...]

to start training. If the path of the Matting_Human_Half folder has never changed since data preparation, the root_dir argument can be omitted (passing it does no harm). The optional [--gpu_indices ...] specifies the available GPU indices for your setup; the default uses 4 GPUs with indices 0, 1, 2, 3. If you use a single GPU, pass

--gpu_indices 0

To use several GPUs, e.g. GPUs 1 and 3, pass

--gpu_indices 1 3

and so on for other configurations. All hyperparameters used during training are in the Arguments for train.py section of the common.py file.

A few minutes after training starts, run the following from the project directory:

tensorboard --logdir ./models/logs

and open a browser to monitor the learning rate, the loss curves, and segmentation result images produced during training. TensorBoard is driven here by PyTorch's built-in class (from torch.utils.tensorboard import SummaryWriter), which requires PyTorch 1.1.0 or later. (Refreshing the browser does not seem to pick up new results; you may need to restart tensorboard repeatedly to watch training progress.)

After training finishes (by default only 3 epochs are trained), the model parameter files saved during training are in the models folder (see predict.py for how to use the model). Run directly:

python3 predict.py

to generate segmentation results for the test images in the data/test folder, as shown below:

Segmentation results on test images

To segment other images, modify the image path in the predict.py file.

Training on other datasets

To train a PyTorch model, you need to subclass torch.utils.data.Dataset to provide batched data. This project already includes a base class for this, SegDataset (see the core/dataset.py file), whose __getitem__ already implements data augmentation such as random scaling, padding, cropping, and horizontal flipping. The only method you must override (implement) is the abstract (empty) function get_image_mask_paths, which must return a list of the following format:

[[image_path, mask_path],
 [image_path, mask_path],
...
 [image_path, mask_path]]

Taking the Aisegment open-source dataset as an example, the class we actually use is MattingDataset in the project's dataset_matting.py file, which inherits from the SegDataset class above. As stated at the beginning, implementing the empty function get_image_mask_paths is usually all that's needed; but because this dataset is especially uniform and needs no extra data augmentation at all, only simple resizing to a fixed size, __getitem__ is also overridden with a simpler version.

For other datasets, you generally only need to subclass SegDataset in the core/dataset.py file according to your own annotation format and override (implement) the empty function get_image_mask_paths. If you need extra data augmentation on top of that, also override __getitem__.
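The subclassing pattern described above can be sketched without PyTorch as follows. The stand-in base class below only illustrates the contract; the real SegDataset in core/dataset.py additionally inherits torch.utils.data.Dataset and performs the augmentations in __getitem__:

```python
class SegDatasetBase:
    """Stand-in for the project's SegDataset: subclasses implement
    get_image_mask_paths(); __getitem__ then uses the returned pairs."""
    def __init__(self):
        self.pairs = self.get_image_mask_paths()

    def get_image_mask_paths(self):
        # Abstract (empty) in the real base class, too.
        raise NotImplementedError

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        image_path, mask_path = self.pairs[index]
        # The real class would load, augment, and return tensors here.
        return image_path, mask_path

class MyDataset(SegDatasetBase):
    def get_image_mask_paths(self):
        # Normally read from an annotation file; hard-coded for the sketch.
        return [["images/0001.jpg", "masks/0001.png"]]

ds = MyDataset()
print(len(ds), ds[0])  # 1 ('images/0001.jpg', 'masks/0001.png')
```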

All training hyperparameters, such as the learning rate and batch size, can either be passed on the command line at training time or have their default values edited directly in the Arguments for train.py section of the common.py file. Note: num_classes = number of valid target classes + 1, i.e. the background class must be counted; the background generally corresponds to class label 0 in the mask.