This article implements the DeepLab V3+ semantic segmentation model in PyTorch. To use existing pretrained parameters without spending time on training, we convert the pretrained weights from the official TensorFlow implementation of DeepLab V3+. This article is heavily influenced by that official implementation and can largely be read as its PyTorch translation.
The feature extractor (backbone) used by this DeepLab V3+ model is the Xception model (see the previous article, "Implementing Xception in PyTorch and converting pretrained parameters directly from TensorFlow"; it is slightly modified here by adding `end_points` that save intermediate features for DeepLab V3+ to use).
All code is on GitHub: deeplabv3_plus. PyTorch 1.1.0 or later is required, because training uses PyTorch's built-in TensorBoard support.
1. Implementation details
The theory behind DeepLab V3+ will be covered in detail in a later article and is skipped here. The implementation is as follows:
```python
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 16 14:24:03 2019

@author: shirhe-lyh

Implementation of DeepLabV3+:
    Encoder-Decoder with atrous separable convolution for semantic image
    segmentation, Liang-Chieh Chen, YuKun Zhu, George Papandreou,
    Florian Schroff, Hartwig Adam, arxiv:1802.02611
    (https://arxiv.org/abs/1802.02611).

Official implementation:
    https://github.com/tensorflow/models/tree/master/research/deeplab
"""

import os
import torch

import common
import core.feature_extractor as extractor

_BATCH_NORM_PARAMS = {
    'momentum': 0.9997,
    'eps': 1e-5,
    'affine': True,
}


class DeepLab(torch.nn.Module):
    """Implementation of DeepLab V3+."""

    def __init__(self, feature_extractor, model_options):
        """Constructor.

        Args:
            feature_extractor: The backbone of DeepLab model.
            model_options: A ModelOptions instance to configure models.
        """
        super(DeepLab, self).__init__()
        self._model_options = model_options
        # Feature extractor
        self._feature_extractor = feature_extractor
        # Atrous spatial pyramid pooling
        self._aspp = AtrousSpatialPyramidPooling(
            in_channels=feature_extractor.out_channels,
            crop_size=model_options.crop_size,
            output_stride=model_options.output_stride,
            atrous_rates=model_options.atrous_rates,
            use_bounded_activation=model_options.use_bounded_activation,
            add_image_level_feature=model_options.add_image_level_feature,
            image_pooling_stride=model_options.image_pooling_stride,
            image_pooling_crop_size=model_options.image_pooling_crop_size,
            aspp_with_separable_conv=model_options.aspp_with_separable_conv)
        # Refine by decoder
        self._refine_decoder = None
        if model_options.decoder_output_stride:
            self._refine_decoder = RefineDecoder(
                feature_extractor=self._feature_extractor,
                crop_size=model_options.crop_size,
                decoder_output_stride=model_options.decoder_output_stride[0],
                decoder_use_separable_conv=(
                    model_options.decoder_use_separable_conv),
                model_variant=model_options.model_variant,
                use_bounded_activation=model_options.use_bounded_activation)
        # Branch logits
        num_classes = model_options.outputs_to_num_classes[common.OUTPUT_TYPE]
        self._logits_layer = torch.nn.Conv2d(
            in_channels=256,
            out_channels=num_classes,
            kernel_size=model_options.logits_kernel_size)

    def forward(self, x):
        features = self._feature_extractor(x)
        features = self._aspp(features)
        if self._refine_decoder is not None:
            features = self._refine_decoder(features)
        logits = self._logits_layer(features)
        return logits


class AtrousSpatialPyramidPooling(torch.nn.Module):
    """Atrous Spatial Pyramid Pooling."""

    def __init__(self,
                 in_channels,
                 out_channels=256,
                 output_stride=16,
                 crop_size=[513, 513],
                 atrous_rates=[12, 24, 36],
                 use_bounded_activation=False,
                 add_image_level_feature=True,
                 image_pooling_stride=[1, 1],
                 image_pooling_crop_size=None,
                 aspp_with_separable_conv=True):
        """Constructor.

        Args:
            in_channels: Number of input filters.
            out_channels: Number of filters in the 1x1 pointwise convolution.
            atrous_rates: A list of atrous convolution rates for ASPP.
            use_bounded_activation: Whether or not to use bounded activations.
            crop_size: A tuple [crop_height, crop_width].
            image_pooling_crop_size: Image pooling crop size [height, width]
                used in the ASPP module.
        """
        super(AtrousSpatialPyramidPooling, self).__init__()
        activation_fn = (
            torch.nn.ReLU6(inplace=False) if use_bounded_activation else
            torch.nn.ReLU(inplace=False))
        depth = out_channels
        branches = []
        if add_image_level_feature:
            layers = []
            if crop_size is not None:
                # If image_pooling_crop_size is not specified, use crop_size
                if image_pooling_crop_size is None:
                    image_pooling_crop_size = crop_size
                pool_height = scale_dimension(image_pooling_crop_size[0],
                                              1. / output_stride)
                pool_width = scale_dimension(image_pooling_crop_size[1],
                                             1. / output_stride)
                layers += [torch.nn.AvgPool2d(
                    (pool_height, pool_width), image_pooling_stride)]
                resize_height = scale_dimension(crop_size[0],
                                                1. / output_stride)
                resize_width = scale_dimension(crop_size[1],
                                               1. / output_stride)
            else:
                # If crop_size is None, we simply do global pooling
                layers += [torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))]
                resize_height, resize_width = None, None
            self._resize_height = resize_height
            self._resize_width = resize_width
            layers += [torch.nn.Conv2d(in_channels, depth, 1, bias=False),
                       torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
                       activation_fn]
            branches.append(torch.nn.Sequential(*layers))

        # Employ a 1x1 convolution.
        branches.append(torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, depth, kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
            activation_fn))

        if atrous_rates:
            # Employ 3x3 convolutions with different atrous rates.
            for i, rate in enumerate(atrous_rates):
                layers = []
                if aspp_with_separable_conv:
                    layers += [
                        SplitSeparableConv2d(in_channels, depth, rate=rate,
                                             activation_fn=activation_fn)]
                else:
                    # Note: Conv2d has no `rate` argument; the atrous rate is
                    # passed as `dilation`, with padding=rate to keep the
                    # spatial size unchanged.
                    layers += [
                        torch.nn.Conv2d(in_channels, depth, kernel_size=3,
                                        dilation=rate, padding=rate,
                                        bias=False),
                        torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
                        activation_fn]
                branches.append(torch.nn.Sequential(*layers))
        self._branches = torch.nn.Sequential(*branches)

        # Merge branch logits. TF-slim uses keep_prob=0.9 here, which
        # corresponds to a drop probability p=0.1 in PyTorch.
        self._conv_concat = torch.nn.Sequential(
            torch.nn.Conv2d(len(self._branches) * depth, depth,
                            kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(depth, **_BATCH_NORM_PARAMS),
            activation_fn)
        self._dropout = torch.nn.Dropout2d(p=0.1, inplace=True)

    def forward(self, x):
        branch_logits = []
        conv_branches = self._branches
        if len(self._branches) > 4:
            # The first branch is the image-level feature.
            conv_branches = self._branches[1:]
            image_feature = self._branches[0](x)
            if self._resize_height is None:
                _, _, height, width = x.size()
                self._resize_height, self._resize_width = height, width
            image_feature = resize_bilinear(
                image_feature, (self._resize_height, self._resize_width))
            branch_logits.append(image_feature)
        branch_logits += [branch(x) for branch in conv_branches]
        concat_logits = torch.cat(branch_logits, dim=1)
        concat_logits = self._conv_concat(concat_logits)
        concat_logits = self._dropout(concat_logits)
        return concat_logits


class SplitSeparableConv2d(torch.nn.Module):
    """Splits a separable conv2d into depthwise and pointwise conv2d."""

    def __init__(self, in_channels, out_channels, kernel_size=3, rate=1,
                 use_batch_norm=True, activation_fn=None):
        """Constructor.

        Args:
            in_channels: Number of input filters.
            out_channels: Number of filters in the 1x1 pointwise convolution.
            kernel_size: A list of length 2: [kernel_height, kernel_width] of
                the filters. Can be an int if both values are the same.
            rate: Atrous convolution rate for the depthwise convolution.
            use_batch_norm: Whether or not to use batch normalization.
            activation_fn: The activation function to be applied.
        """
        super(SplitSeparableConv2d, self).__init__()
        # For the shape of output of Conv2d, see details at:
        # https://pytorch.org/docs/stable/nn.html#convolution-layers
        # Here, we assume that floor(padding) = padding
        padding = (kernel_size - 1) * rate // 2
        self._conv_depthwise = torch.nn.Conv2d(in_channels, in_channels,
                                               kernel_size=kernel_size,
                                               dilation=rate,
                                               padding=padding,
                                               groups=in_channels,
                                               bias=False)
        self._conv_pointwise = torch.nn.Conv2d(in_channels, out_channels,
                                               kernel_size=1, bias=False)
        self._batch_norm_depthwise = None
        self._batch_norm_pointwise = None
        if use_batch_norm:
            self._batch_norm_depthwise = torch.nn.BatchNorm2d(
                num_features=in_channels, **_BATCH_NORM_PARAMS)
            self._batch_norm_pointwise = torch.nn.BatchNorm2d(
                num_features=out_channels, **_BATCH_NORM_PARAMS)
        self._activation_fn = activation_fn

    def forward(self, x):
        x = self._conv_depthwise(x)
        if self._batch_norm_depthwise is not None:
            x = self._batch_norm_depthwise(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        x = self._conv_pointwise(x)
        if self._batch_norm_pointwise is not None:
            x = self._batch_norm_pointwise(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x


class RefineDecoder(torch.nn.Module):
    """Adds the decoder to obtain sharper segmentation results."""

    def __init__(self, feature_extractor, aspp_channels=256, crop_size=None,
                 decoder_output_stride=None, decoder_use_separable_conv=False,
                 model_variant=None, use_bounded_activation=False):
        """Constructor.

        Args:
            feature_extractor: The backbone of the DeepLab model.
            aspp_channels: The out channels of ASPP.
            crop_size: A tuple [crop_height, crop_width] specifying whole
                patch crop size.
            decoder_output_stride: An integer specifying the output stride of
                low-level features used in the decoder module.
            decoder_use_separable_conv: Employ separable convolution for
                decoder or not.
            model_variant: Model variant for feature extractor.
            use_bounded_activation: Whether or not to use bounded activations.
                Bounded activations better lend themselves to quantized
                inference.

        Raises:
            ValueError: If crop_size is None.
        """
        super(RefineDecoder, self).__init__()
        if crop_size is None:
            raise ValueError('crop_size must be provided when using decoder.')

        self._crop_size = crop_size
        self._output_stride = decoder_output_stride
        activation_fn = (
            torch.nn.ReLU6(inplace=False) if use_bounded_activation else
            torch.nn.ReLU(inplace=False))
        self._extractor = feature_extractor
        feature_names = extractor.networks_to_feature_maps[
            model_variant][extractor.DECODER_END_POINTS][decoder_output_stride]
        self._feature_name = '{}/{}'.format(model_variant, feature_names[0])

        self._decoder = torch.nn.Sequential(
            torch.nn.Conv2d(extractor.feature_out_channels_map[model_variant],
                            out_channels=48, kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(num_features=48, **_BATCH_NORM_PARAMS),
            activation_fn)

        concat_layers = []
        decoder_depth = 256
        if decoder_use_separable_conv:
            concat_layers += [
                SplitSeparableConv2d(in_channels=aspp_channels + 48,
                                     out_channels=decoder_depth,
                                     rate=1,
                                     activation_fn=activation_fn),
                SplitSeparableConv2d(in_channels=decoder_depth,
                                     out_channels=decoder_depth,
                                     rate=1,
                                     activation_fn=activation_fn)]
        else:
            concat_layers += [
                torch.nn.Conv2d(in_channels=aspp_channels + 48,
                                out_channels=decoder_depth,
                                kernel_size=3, padding=1, bias=False),
                torch.nn.BatchNorm2d(num_features=decoder_depth,
                                     **_BATCH_NORM_PARAMS),
                activation_fn,
                torch.nn.Conv2d(in_channels=decoder_depth,
                                out_channels=decoder_depth,
                                kernel_size=3, padding=1, bias=False),
                torch.nn.BatchNorm2d(num_features=decoder_depth,
                                     **_BATCH_NORM_PARAMS),
                activation_fn]
        self._concat_layers = torch.nn.Sequential(*concat_layers)

    def forward(self, x):
        decoder_features_list = [x]
        inter_feature = self._extractor.end_points()[self._feature_name]
        decoder_features_list.append(self._decoder(inter_feature))
        # Determine the output size
        decoder_height = scale_dimension(self._crop_size[0],
                                         1.0 / self._output_stride)
        decoder_width = scale_dimension(self._crop_size[1],
                                        1.0 / self._output_stride)
        # Resize to decoder_height/decoder_width
        for j, feature in enumerate(decoder_features_list):
            decoder_features_list[j] = resize_bilinear(
                feature, (decoder_height, decoder_width))
        x = self._concat_layers(torch.cat(decoder_features_list, dim=1))
        return x


def scale_dimension(dim, scale):
    """Scales the input dimension.

    Args:
        dim: Input dimension (a scalar).
        scale: The amount of scaling applied to the input.

    Returns:
        Scaled dimension.
    """
    return int((float(dim) - 1.0) * scale + 1.0)


def resize_bilinear(images, size):
    """Returns resized images.

    Args:
        images: A tensor of size [batch, channels, height_in, width_in].
        size: A tuple (height, width).

    Returns:
        A tensor of shape [batch, channels, height_out, width_out].
    """
    return torch.nn.functional.interpolate(
        images, size, mode='bilinear', align_corners=True)


def deeplab(num_classes,
            crop_size=[513, 513],
            atrous_rates=[12, 24, 36],
            output_stride=16,
            pretrained=True,
            pretained_num_classes=21,
            checkpoint_path='./pretrained_models/'
                            'deeplabv3_pascal_trainval.pth'):
    """DeepLab v3+ for semantic segmentation."""
    outputs_to_num_classes = {'semantic': num_classes}
    model_options = common.ModelOptions(outputs_to_num_classes,
                                        crop_size=crop_size,
                                        atrous_rates=atrous_rates,
                                        output_stride=output_stride)
    feature_extractor = extractor.feature_extractor(
        model_options.model_variant,
        pretrained=False,
        output_stride=model_options.output_stride)
    model = DeepLab(feature_extractor, model_options)
    if pretrained:
        _load_state_dict(model, num_classes, pretained_num_classes,
                         checkpoint_path)
    return model


def _load_state_dict(model, num_classes, pretained_num_classes,
                     checkpoint_path):
    """Load pretrained weights."""
    if os.path.exists(checkpoint_path):
        state_dict = torch.load(checkpoint_path)
        if num_classes is None or num_classes != pretained_num_classes:
            # The logits layer has a different shape; let it be randomly
            # initialized and load everything else.
            state_dict.pop('_logits_layer.weight')
            state_dict.pop('_logits_layer.bias')
        model.load_state_dict(state_dict, strict=False)
        print('Load pretrained weights successfully.')
    else:
        raise ValueError('`checkpoint_path` does not exist.')
```
2. Converting the pretrained parameters
We convert the official TensorFlow pretrained parameters to PyTorch by fully manual assignment (a limitation of PyTorch: variable names cannot be specified freely, so names cannot be matched automatically). That is, for each PyTorch parameter we look up the corresponding value in the TensorFlow checkpoint and assign it. The details:
```python
# -*- coding: utf-8 -*-
"""
Created on Wed Oct 9 17:46:13 2019

@author: shirhe-lyh

Convert tensorflow weights to pytorch weights for DeepLab V3+ models.

Reference:
    https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/
        tf_to_pytorch/convert_tf_to_pt/load_tf_weights.py
"""

import numpy as np
import tensorflow as tf
import torch

_BLOCK_UNIT_COUNT_MAP = {
    'xception_41': [[3, 1], [1, 8], [2, 1]],
    'xception_65': [[3, 1], [1, 16], [2, 1]],
    'xception_71': [[5, 1], [1, 16], [2, 1]],
}


def load_param(checkpoint_path, conversion_map, model_name):
    """Load parameters according to conversion_map.

    Args:
        checkpoint_path: Path to tensorflow's checkpoint file.
        conversion_map: A dictionary with format
            {pytorch tensor in a model: checkpoint variable name}
        model_name: The name of Xception model, only supports 'xception_41',
            'xception_65', or 'xception_71'.
    """
    for pth_param, tf_param_name in conversion_map.items():
        param_name_strs = tf_param_name.split('_')
        if len(param_name_strs) > 1 and param_name_strs[1].startswith('flow'):
            tf_param_name = str(model_name) + '/' + tf_param_name
        tf_param = tf.train.load_variable(checkpoint_path, tf_param_name)
        if 'conv' in tf_param_name and 'weights' in tf_param_name:
            tf_param = np.transpose(tf_param, (3, 2, 0, 1))
            if 'depthwise' in tf_param_name:
                tf_param = np.transpose(tf_param, (1, 0, 2, 3))
        elif 'depthwise_weights' in tf_param_name:
            tf_param = np.transpose(tf_param, (3, 2, 0, 1))
            tf_param = np.transpose(tf_param, (1, 0, 2, 3))
        elif tf_param_name.endswith('weights'):
            tf_param = np.transpose(tf_param)
        assert pth_param.size() == tf_param.shape, (
            'Dimension mismatch: {} vs {}; {}'.format(
                pth_param.size(), tf_param.shape, tf_param_name))
        pth_param.data = torch.from_numpy(tf_param)


def convert(model, checkpoint_path):
    """Load Pytorch Xception from TensorFlow checkpoint file.

    Args:
        model: The pytorch Xception model, only supports 'xception_41',
            'xception_65', or 'xception_71'.
        checkpoint_path: Path to tensorflow's checkpoint file.

    Returns:
        The pytorch Xception model with pretrained parameters.
    """
    block_unit_counts = _BLOCK_UNIT_COUNT_MAP.get(
        model._feature_extractor.scope, None)
    if block_unit_counts is None:
        raise ValueError('Unsupported Xception model name.')

    flow_names = []
    block_indices = []
    unit_indices = []
    flow_names_unique = ['entry_flow', 'middle_flow', 'exit_flow']
    for i, [block_count, unit_count] in enumerate(block_unit_counts):
        flow_names += [flow_names_unique[i]] * (block_count * unit_count)
        for block_index in range(block_count):
            block_indices += [block_index + 1] * unit_count
            unit_indices += [j + 1 for j in range(unit_count)]

    conversion_map = {}

    # Feature extractor: Root block
    conversion_map_for_root_block = {
        model._feature_extractor._layers[0]._conv.weight:
            'entry_flow/conv1_1/weights',
        model._feature_extractor._layers[0]._batch_norm.bias:
            'entry_flow/conv1_1/BatchNorm/beta',
        model._feature_extractor._layers[0]._batch_norm.weight:
            'entry_flow/conv1_1/BatchNorm/gamma',
        model._feature_extractor._layers[0]._batch_norm.running_mean:
            'entry_flow/conv1_1/BatchNorm/moving_mean',
        model._feature_extractor._layers[0]._batch_norm.running_var:
            'entry_flow/conv1_1/BatchNorm/moving_variance',
        model._feature_extractor._layers[1]._conv.weight:
            'entry_flow/conv1_2/weights',
        model._feature_extractor._layers[1]._batch_norm.bias:
            'entry_flow/conv1_2/BatchNorm/beta',
        model._feature_extractor._layers[1]._batch_norm.weight:
            'entry_flow/conv1_2/BatchNorm/gamma',
        model._feature_extractor._layers[1]._batch_norm.running_mean:
            'entry_flow/conv1_2/BatchNorm/moving_mean',
        model._feature_extractor._layers[1]._batch_norm.running_var:
            'entry_flow/conv1_2/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_root_block)

    # Feature extractor: Xception blocks
    for i in range(len(model._feature_extractor._layers[2]._blocks)):
        block = model._feature_extractor._layers[2]._blocks[i]
        ind = [1, 3, 5]
        if len(block._separable_conv_block) < 6:
            ind = [0, 1, 2]
        for j in range(3):
            conversion_map_for_separable_block = {
                block._separable_conv_block[ind[j]]._conv_depthwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/depthwise_weights').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[ind[j]]._conv_pointwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/weights').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[ind[j]]._batch_norm_depthwise.bias:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/beta').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[
                    ind[j]]._batch_norm_depthwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/gamma').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[
                    ind[j]]._batch_norm_depthwise.running_mean:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/' +
                     'moving_mean').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[
                    ind[j]]._batch_norm_depthwise.running_var:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_depthwise/BatchNorm/' +
                     'moving_variance').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[ind[j]]._batch_norm_pointwise.bias:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/beta').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[
                    ind[j]]._batch_norm_pointwise.weight:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/gamma').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[
                    ind[j]]._batch_norm_pointwise.running_mean:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/' +
                     'moving_mean').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
                block._separable_conv_block[
                    ind[j]]._batch_norm_pointwise.running_var:
                    ('{}/block{}/unit_{}/xception_module/' +
                     'separable_conv{}_pointwise/BatchNorm/' +
                     'moving_variance').format(
                         flow_names[i], block_indices[i], unit_indices[i],
                         j+1),
            }
            conversion_map.update(conversion_map_for_separable_block)

        if getattr(block, '_conv_skip_connection', None) is not None:
            conversion_map_for_shortcut = {
                block._conv_skip_connection.weight:
                    ('{}/block{}/unit_{}/xception_module/shortcut/' +
                     'weights').format(
                         flow_names[i], block_indices[i], unit_indices[i]),
                block._batch_norm_shortcut.bias:
                    ('{}/block{}/unit_{}/xception_module/shortcut/' +
                     'BatchNorm/beta').format(
                         flow_names[i], block_indices[i], unit_indices[i]),
                block._batch_norm_shortcut.weight:
                    ('{}/block{}/unit_{}/xception_module/shortcut/' +
                     'BatchNorm/gamma').format(
                         flow_names[i], block_indices[i], unit_indices[i]),
                block._batch_norm_shortcut.running_mean:
                    ('{}/block{}/unit_{}/xception_module/shortcut/' +
                     'BatchNorm/moving_mean').format(
                         flow_names[i], block_indices[i], unit_indices[i]),
                block._batch_norm_shortcut.running_var:
                    ('{}/block{}/unit_{}/xception_module/shortcut/' +
                     'BatchNorm/moving_variance').format(
                         flow_names[i], block_indices[i], unit_indices[i]),
            }
            conversion_map.update(conversion_map_for_shortcut)

    # Atrous Spatial Pyramid Pooling: Image feature
    branches = model._aspp._branches
    conversion_map_for_aspp_image_feature = {
        branches[0][1].weight: 'image_pooling/weights',
        branches[0][2].bias: 'image_pooling/BatchNorm/beta',
        branches[0][2].weight: 'image_pooling/BatchNorm/gamma',
        branches[0][2].running_mean: 'image_pooling/BatchNorm/moving_mean',
        branches[0][2].running_var: 'image_pooling/BatchNorm/moving_variance',
        branches[1][0].weight: 'aspp0/weights',
        branches[1][1].bias: 'aspp0/BatchNorm/beta',
        branches[1][1].weight: 'aspp0/BatchNorm/gamma',
        branches[1][1].running_mean: 'aspp0/BatchNorm/moving_mean',
        branches[1][1].running_var: 'aspp0/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_aspp_image_feature)

    # Atrous Spatial Pyramid Pooling: Atrous convolution
    for i in range(3):
        branch = branches[i+2][0]
        conversion_map_for_atrous_conv = {
            branch._conv_depthwise.weight:
                'aspp{}_depthwise/depthwise_weights'.format(i+1),
            branch._conv_pointwise.weight:
                'aspp{}_pointwise/weights'.format(i+1),
            branch._batch_norm_depthwise.bias:
                'aspp{}_depthwise/BatchNorm/beta'.format(i+1),
            branch._batch_norm_depthwise.weight:
                'aspp{}_depthwise/BatchNorm/gamma'.format(i+1),
            branch._batch_norm_depthwise.running_mean:
                'aspp{}_depthwise/BatchNorm/moving_mean'.format(i+1),
            branch._batch_norm_depthwise.running_var:
                'aspp{}_depthwise/BatchNorm/moving_variance'.format(i+1),
            branch._batch_norm_pointwise.bias:
                'aspp{}_pointwise/BatchNorm/beta'.format(i+1),
            branch._batch_norm_pointwise.weight:
                'aspp{}_pointwise/BatchNorm/gamma'.format(i+1),
            branch._batch_norm_pointwise.running_mean:
                'aspp{}_pointwise/BatchNorm/moving_mean'.format(i+1),
            branch._batch_norm_pointwise.running_var:
                'aspp{}_pointwise/BatchNorm/moving_variance'.format(i+1),
        }
        conversion_map.update(conversion_map_for_atrous_conv)

    # Atrous Spatial Pyramid Pooling: Concat projection
    conversion_map_for_concat_projection = {
        model._aspp._conv_concat[0].weight: 'concat_projection/weights',
        model._aspp._conv_concat[1].bias: 'concat_projection/BatchNorm/beta',
        model._aspp._conv_concat[1].weight:
            'concat_projection/BatchNorm/gamma',
        model._aspp._conv_concat[1].running_mean:
            'concat_projection/BatchNorm/moving_mean',
        model._aspp._conv_concat[1].running_var:
            'concat_projection/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_concat_projection)

    # Refine decoder: Feature projection
    conversion_map_for_decoder = {
        model._refine_decoder._decoder[0].weight:
            'decoder/feature_projection0/weights',
        model._refine_decoder._decoder[1].bias:
            'decoder/feature_projection0/BatchNorm/beta',
        model._refine_decoder._decoder[1].weight:
            'decoder/feature_projection0/BatchNorm/gamma',
        model._refine_decoder._decoder[1].running_mean:
            'decoder/feature_projection0/BatchNorm/moving_mean',
        model._refine_decoder._decoder[1].running_var:
            'decoder/feature_projection0/BatchNorm/moving_variance',
    }
    conversion_map.update(conversion_map_for_decoder)

    # Refine decoder: Concat
    layers = model._refine_decoder._concat_layers
    for i in range(2):
        layer = layers[i]
        conversion_map_decoder = {
            layer._conv_depthwise.weight:
                'decoder/decoder_conv{}_depthwise/depthwise_weights'.format(i),
            layer._conv_pointwise.weight:
                'decoder/decoder_conv{}_pointwise/weights'.format(i),
            layer._batch_norm_depthwise.bias:
                'decoder/decoder_conv{}_depthwise/BatchNorm/beta'.format(i),
            layer._batch_norm_depthwise.weight:
                'decoder/decoder_conv{}_depthwise/BatchNorm/gamma'.format(i),
            layer._batch_norm_depthwise.running_mean:
                'decoder/decoder_conv{}_depthwise/BatchNorm/'
                'moving_mean'.format(i),
            layer._batch_norm_depthwise.running_var:
                'decoder/decoder_conv{}_depthwise/BatchNorm/'
                'moving_variance'.format(i),
            layer._batch_norm_pointwise.bias:
                'decoder/decoder_conv{}_pointwise/BatchNorm/beta'.format(i),
            layer._batch_norm_pointwise.weight:
                'decoder/decoder_conv{}_pointwise/BatchNorm/gamma'.format(i),
            layer._batch_norm_pointwise.running_mean:
                'decoder/decoder_conv{}_pointwise/BatchNorm/'
                'moving_mean'.format(i),
            layer._batch_norm_pointwise.running_var:
                'decoder/decoder_conv{}_pointwise/BatchNorm/'
                'moving_variance'.format(i),
        }
        conversion_map.update(conversion_map_decoder)

    # Prediction logits
    conversion_map_for_logits = {
        model._logits_layer.weight: 'logits/semantic/weights',
        model._logits_layer.bias: 'logits/semantic/biases',
    }
    conversion_map.update(conversion_map_for_logits)

    # Load TensorFlow parameters into PyTorch model
    load_param(checkpoint_path, conversion_map,
               model._feature_extractor.scope)
```
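The index gymnastics in `load_param` come from the different kernel layouts: TensorFlow stores convolution kernels as HWIO (height, width, in, out), while PyTorch's `Conv2d.weight` is OIHW, and depthwise kernels need one extra axis swap. A minimal NumPy sketch of both cases (the shapes are made up for illustration):

```python
import numpy as np

# A fake TF conv kernel: height 3, width 3, 16 in channels, 32 out (HWIO).
tf_kernel = np.zeros((3, 3, 16, 32), dtype=np.float32)

# PyTorch Conv2d.weight expects (out, in, height, width) = OIHW.
pth_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))
print(pth_kernel.shape)  # (32, 16, 3, 3)

# TF depthwise kernels are (H, W, in_channels, channel_multiplier); after
# the same transpose they are (multiplier, in, H, W), so swapping the first
# two axes yields PyTorch's grouped-conv layout (in_channels, 1, H, W).
tf_dw = np.zeros((3, 3, 16, 1), dtype=np.float32)
pth_dw = np.transpose(np.transpose(tf_dw, (3, 2, 0, 1)), (1, 0, 2, 3))
print(pth_dw.shape)  # (16, 1, 3, 3)
```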
As an example, we convert xception65_coco_voc_trainval (with Xception-65 as the feature extractor, pretrained on the COCO dataset), whose official TensorFlow pretrained parameters can be downloaded from the official TensorFlow model zoo. After downloading and extracting, run
```shell
python3 tf_weights_to_pth.py --tf_checkpoint_path "xxxx/deeplabv3_pascal_trainval/model.ckpt"
```
where tf_checkpoint_path is the full path to the TensorFlow pretrained parameters you just downloaded. When it finishes, a folder named pretrained_models is created under the project root, containing the converted parameters in a file named deeplabv3_pascal_trainval.pth. Now run
```shell
python3 model_test.py
```
which writes the semantic segmentation results for the test images to the test folder:
[Figure: COCO_val2014_000000000294.jpg (input image)]
[Figure: COCO_val2014_000000000294_seg.png (segmentation result)]
Running the following two commands,
```shell
python3 model_test.py --image_path "./test/tf_vis2.jpg"
python3 model_test.py --image_path "./test/tf_vis2.jpg" --val_output_stride 16 \
    --val_atrous_rates 6 12 18
```
gives:
[Figure: tf_vis2.jpg (input image)]
[Figure: tf_vis2_seg.png (output_stride=8)]
[Figure: tf_vis2_seg.png (output_stride=16)]
To convert other pretrained parameters (only Xception-based feature extractors are supported), you need to specify the following four arguments in full:
```shell
python3 tf_weights_to_pth.py --tf_checkpoint_path "xxxx/model.ckpt" \
    --output_dir "xxx" --output_name "xxx.pth" --pretained_num_classes xxx
```
These are, respectively: the path to the TensorFlow pretrained checkpoint file, the output folder for the converted parameters, the output file name, and the number of segmentation classes of the model being converted. For their default values, see the conversion script; the feature extractor is selected with an option such as

```shell
--model_variant "xception_71"
```
3. Usage
Calling the DeepLab V3+ model is straightforward (a usage example is included in the project's test script):
```python
import model

deeplab = model.deeplab(
    num_classes,
    crop_size=[513, 513],
    atrous_rates=[12, 24, 36],
    output_stride=8,
    pretrained=True,
    pretained_num_classes=21,
    checkpoint_path='./pretrained_models/deeplabv3_pascal_trainval.pth')
```
Here, num_classes is the number of segmentation classes including the background (for example, PASCAL VOC has num_classes = 21 and Cityscapes has num_classes = 19); crop_size is the model's input size; atrous_rates are the dilation rates used by ASPP and should match output_stride (the official models use [6, 12, 18] for output_stride=16 and [12, 24, 36] for output_stride=8).
[Note] If you set pretrained = True and the model loads successfully, but the pretrained model's number of classes (pretained_num_classes) differs from the number of classes you actually use (num_classes), the final convolution layer will be randomly initialized.
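This note corresponds to the key-popping logic in `_load_state_dict`: the incompatible logits weights are dropped before `load_state_dict(strict=False)` restores everything else. Sketched here with a plain dict instead of a real checkpoint (the parameter names other than `_logits_layer.*` are hypothetical):

```python
# A mock checkpoint with one backbone tensor and a 21-class logits layer.
state_dict = {
    '_feature_extractor.conv.weight': 'pretrained backbone tensor',
    '_logits_layer.weight': 'pretrained logits tensor (21 classes)',
    '_logits_layer.bias': 'pretrained logits bias (21 classes)',
}

num_classes, pretained_num_classes = 2, 21
if num_classes is None or num_classes != pretained_num_classes:
    # Shapes would not match, so drop the logits weights; with
    # load_state_dict(strict=False) the logits layer then simply keeps
    # its random initialization.
    state_dict.pop('_logits_layer.weight')
    state_dict.pop('_logits_layer.bias')

print(sorted(state_dict))  # ['_feature_extractor.conv.weight']
```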
4. A training example
The training data for this section comes from the dataset open-sourced by the company Aisegment (爱分割), containing 34,426 images with matching alpha channels. All images are upper-body shots of models, for example:
[Figure: a sample image from the Aisegment open dataset]
Although the alpha channels are not of the highest quality, they are more than good enough for training a semantic segmentation model. In practice, pixels whose alpha value is greater than 50 are set to 1 and all other pixels to 0, giving the mask. Since the person is the only target, we only distinguish two classes, where 1 is the person and 0 is the background (hence num_classes=2).
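The thresholding described above is a one-liner with NumPy; a minimal sketch, assuming the alpha channel has already been read into an array (the values here are made up):

```python
import numpy as np

# A fake 4x4 alpha channel with values in [0, 255].
alpha = np.array([[  0,  30,  60, 255],
                  [ 10,  51,  49, 200],
                  [  0,   0, 128, 128],
                  [255, 255,   0,   0]], dtype=np.uint8)

# Pixels with alpha > 50 become class 1 (person), the rest class 0.
mask = (alpha > 50).astype(np.uint8)
print(mask)
```

A boundary pixel with alpha exactly 50 lands in the background, matching the strict "greater than 50" rule.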
Data preparation
After downloading and extracting the Aisegment open dataset, we need to generate the masks for all images in one pass, so you need to open and run the data-preparation script first.
[Figure: annotation files with root_dir stripped]
train.txt and val.txt record the paths of the training and validation images. Each line holds four paths for one image, separated by the '@' character: the original image (3 channels), the transparent (matting) image (4 channels), the alpha-channel image, and the mask.
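Parsing such a line is a simple split on '@'; a sketch with a made-up annotation line (the paths below are illustrative, not from the real files):

```python
# One hypothetical annotation line: image @ matting @ alpha @ mask.
line = ('clip_img/0001/clip_00000000/0001-001.jpg@'
        'matting/0001/matting_00000000/0001-001.png@'
        'alpha/0001/alpha_00000000/0001-001.png@'
        'mask/0001/mask_00000000/0001-001.png\n')

image_path, matting_path, alpha_path, mask_path = line.strip().split('@')
print(image_path.endswith('.jpg'), mask_path.startswith('mask/'))  # True True
```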
Training
Run directly from the command line:
```shell
python3 train.py --root_dir "xxx/Matting_Human_Half" [--gpu_indices 0 1 ...]
```
to start training. If the Matting_Human_Half folder has stayed at the same path since you prepared the data, the root_dir argument can be omitted (specifying it does no harm). The optional [--gpu_indices ...] selects which GPUs to use according to your setup; the default uses 4 GPUs with indices 0, 1, 2, 3. With a single GPU, specify
```shell
--gpu_indices 0
```
With multiple GPUs, for example cards 1 and 3, specify
```shell
--gpu_indices 1 3
```
and so on for other configurations. All hyperparameters used during training are defined in the training script.
A few minutes after training starts, run the following from the project root:
```shell
tensorboard --logdir ./models/logs
```
then open the reported address in a browser to watch the learning rate, the loss curves, and segmentation results produced during training. This uses PyTorch's built-in torch.utils.tensorboard.SummaryWriter.
After training finishes (by default only 3 epochs are trained), the model parameter files saved during training are in the models folder (see the prediction script for how to load them). Running
```shell
python3 predict.py
```
writes the segmentation results for the test images into the data/test folder, as follows:
[Figure: segmentation results for the test images]
To segment other images, modify the image path in the prediction script accordingly.
Training on other datasets
Training a PyTorch model requires subclassing torch.utils.data.Dataset to generate batches of data. This project already provides a basic subclass, SegDataset, which expects annotations in the form:
```python
[[image_path, mask_path],
 [image_path, mask_path],
 ...
 [image_path, mask_path]]
```
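Under the assumption that annotations come as the `[image_path, mask_path]` pairs shown above, the dataset boils down to `__len__`/`__getitem__` over that list. A sketch of the pattern (not the project's actual `SegDataset`; in the real project the class subclasses `torch.utils.data.Dataset` and loads/transforms the images, while this stand-in stays dependency-free):

```python
class PairDataset:
    """A minimal stand-in for a [image_path, mask_path] dataset."""

    def __init__(self, annotation_pairs):
        # annotation_pairs: [[image_path, mask_path], ...]
        self._pairs = annotation_pairs

    def __len__(self):
        return len(self._pairs)

    def __getitem__(self, index):
        image_path, mask_path = self._pairs[index]
        # Real code would read the image and mask here and apply the
        # training transforms before returning tensors.
        return image_path, mask_path


dataset = PairDataset([['a.jpg', 'a_mask.png'], ['b.jpg', 'b_mask.png']])
print(len(dataset))   # 2
print(dataset[1][1])  # b_mask.png
```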
For the Aisegment open dataset, the dataset class we actually use is a subclass of SegDataset provided in the project.
For other datasets, you generally only need to subclass SegDataset according to your own annotation format.
All training hyperparameters, such as the learning rate and batch size, can either be specified on the command line at training time or changed in the training script's defaults.