diff --git a/mindspore/explainer/ood/__init__.py b/mindspore/explainer/ood/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05359aada047b8ea43138c5a0837abd5018ce7a9
--- /dev/null
+++ b/mindspore/explainer/ood/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Out Of Distribution Network."""
diff --git a/mindspore/explainer/ood/ood_net.py b/mindspore/explainer/ood/ood_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..d927eb9ee57456bb9918c0fb50be7267ef33a40c
--- /dev/null
+++ b/mindspore/explainer/ood/ood_net.py
@@ -0,0 +1,223 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Out Of Distribution Network."""
+
+import mindspore as ms
+from mindspore import context
+from mindspore import nn
+from mindspore import ops
+from mindspore.common.initializer import HeNormal
+from mindspore.train.callback import Callback
+from mindspore.train.callback import LearningRateScheduler
+
+
+class OODUnderlying(nn.Cell):
+    """The base class of underlying classifiers."""
+
+    @property
+    def feature_count(self):
+        """
+        The number of features.
+
+        Returns:
+            int, the number of features.
+        """
+        raise NotImplementedError
+
+    def construct_feature(self, x):
+        """
+        Forward inferences the feature tensor.
+
+        Returns:
+            Tensor, feature tensor in the shape of [batch_size, feature_count].
+        """
+        return None
+
+
+class OODNet(nn.Cell):
+    """
+    Out Of Distribution network.
+
+    Args:
+        underlying (OODUnderlying, optional): The underlying classifier. None means an OODResNet50 is used as the underlying classifier.
+        num_classes (int): Number of classes for the classifier.
+ """ + + def __init__(self, underlying, num_classes): + super(OODNet, self).__init__() + if underlying is None: + from mindspore.explainer.ood.ood_resnet import OODResNet50 + self._underlying = OODResNet50(num_classes) + self._train_partial = False + else: + self._underlying = underlying + self._train_partial = True + self._h = nn.Dense(in_channels=self._underlying.feature_count, + out_channels=num_classes, + has_bias=False, + weight_init=HeNormal(nonlinearity='relu')) + if context.get_context('device_target') == 'GPU': + # BatchNorm1d is not working on GPU + self._g = nn.SequentialCell( + nn.Dense(in_channels=self._underlying.feature_count, out_channels=1), + nn.Sigmoid() + ) + else: + self._g = nn.SequentialCell( + nn.Dense(in_channels=self._underlying.feature_count, out_channels=1), + nn.BatchNorm1d(num_features=1), + nn.Sigmoid() + ) + self._matmul_weight = nn.MatMul(transpose_x1=False, transpose_x2=True) + self._norm = nn.Norm(axis=(1,)) + self._transpose = ops.Transpose() + self._feature_count = self._underlying.feature_count + self._tile = ops.Tile() + self._is_train = False + self.set_train(False) + + def set_train(self, mode=True): + """ + Set training mode. + + Args: + mode (bool): It is in training mode. + """ + super(OODNet, self).set_train(mode) + self._is_train = mode + + def construct(self, x): + """ + Forward inferences the classification logits or OOD scores. + + Returns: + Tensor, classification logits (if set_train(True) was called) or + OOD scores (if set_train(False) was called). In the shape of [batch_size, num_classes]. + """ + feat = self._underlying.construct_feature(x) + scores = self._ood_scores(feat) + if self._is_train: + denorm = self._g(feat) + logits = scores / denorm + return logits + return scores + + def _ood_scores(self, feat): + """Forward inferences the OOD scores.""" + norm_f = self._normalize(feat) + norm_w = self._normalize(self._h.weight) + scores = self._matmul_weight(norm_f, norm_w) + return scores + + def _normalize(self, x): + """Normalizes an tensor.""" + norm = self._norm(x) + tiled_norm = self._tile((norm + 1e-4), (self._feature_count, 1)) + tiled_norm = self._transpose(tiled_norm, (1, 0)) + x = x / tiled_norm + return x + + def prepare_train(self, + multi_label, + learning_rate=0.1, + momentum=0.9, + weight_decay=0.0001, + lr_base_factor=0.1, + lr_epoch_denom=30): + """ + Creates necessities for training. + + Args: + multi_label (bool): Samples are labeled into multiple classes. + learning_rate (float): The optimizer learning rate. + momentum (float): The optimizer momentum. + weight_decay (float): The optimizer weight decay. + lr_base_factor (float): The base scaling factor of learning rate scheduler. + lr_epoch_denom (int): The epoch denominator of learning rate scheduler. + + Returns: + tuple[Cell, Optimizer, LearningRateScheduler], the loss function, optimizer and learning rate scheduler. + """ + if self._train_partial: + parameters = [] + parameters.extend(self._h.get_parameters()) + parameters.extend(self._g.get_parameters()) + else: + parameters = list(self.get_parameters()) + scheduler = _EpochLrScheduler(learning_rate, lr_base_factor, lr_epoch_denom) + optimizer = nn.SGD(parameters, learning_rate=learning_rate, momentum=momentum, weight_decay=weight_decay) + if multi_label: + loss_fn = nn.BCELoss() + else: + loss_fn = nn.SoftmaxCrossEntropyWithLogits() + return loss_fn, optimizer, scheduler + + def train(self, + dataset, + callbacks=None, + epoch=90, + **kwargs): + """ + Trains this OOD net. 
+
+        Args:
+            dataset (Dataset): The training dataset, expecting (data, one-hot label) items.
+            callbacks (Callback, list, optional): The training callbacks.
+            epoch (int): The number of epochs to train.
+            **kwargs (any): Keyword arguments for prepare_train().
+        """
+        self.set_train(True)
+        loss_fn, optimizer, scheduler = self.prepare_train(**kwargs)
+        model = ms.Model(self, loss_fn=loss_fn, optimizer=optimizer)
+        if callbacks is None:
+            callbacks = [scheduler]
+        elif isinstance(callbacks, list):
+            callbacks_ = [scheduler]
+            callbacks_.extend(callbacks)
+            callbacks = callbacks_
+        elif isinstance(callbacks, Callback):
+            callbacks = [scheduler, callbacks]
+        else:
+            raise ValueError('invalid callbacks type')
+        model.train(epoch, dataset, callbacks=callbacks)
+        self.set_train(False)
+
+
+class _EpochLrScheduler(LearningRateScheduler):
+    """
+    Epoch based learning rate scheduler.
+
+    Args:
+        base_lr (float): The base learning rate.
+        base_factor (float): The base scaling factor.
+        denominator (int): The epoch denominator.
+    """
+    def __init__(self, base_lr, base_factor, denominator):
+        super(_EpochLrScheduler, self).__init__(self._lr_function)
+        self.base_lr = base_lr
+        self.base_factor = base_factor
+        self.denominator = denominator
+        self._cur_epoch_num = 1
+
+    def epoch_end(self, run_context):
+        """Called when an epoch ends."""
+        cb_params = run_context.original_args()
+        self._cur_epoch_num = cb_params.cur_epoch_num
+
+    def _lr_function(self, lr, cur_step_num):
+        """Returns the dynamic learning rate."""
+        del lr
+        del cur_step_num
+        return self.base_lr * (self.base_factor ** (self._cur_epoch_num // self.denominator))
diff --git a/mindspore/explainer/ood/ood_resnet.py b/mindspore/explainer/ood/ood_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..65b7747e56df90f5f8bce2a1260c1c9914ea3c30
--- /dev/null
+++ b/mindspore/explainer/ood/ood_resnet.py
@@ -0,0 +1,343 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ResNet for Out Of Distribution Network."""
+import numpy as np
+from scipy.stats import truncnorm
+
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common.tensor import Tensor
+from mindspore.explainer.ood.ood_net import OODUnderlying
+
+
+def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
+    fan_in = in_channel * kernel_size * kernel_size
+    scale = 1.0
+    scale /= max(1., fan_in)
+    stddev = (scale ** 0.5) / .87962566103423978
+    mu, sigma = 0, stddev
+    weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
+    weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
+    return Tensor(weight, dtype=mstype.float32)
+
+
+def _weight_variable(shape, factor=0.01):
+    init_value = np.random.randn(*shape).astype(np.float32) * factor
+    return Tensor(init_value)
+
+
+def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
+    if use_se:
+        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
+    else:
+        weight_shape = (out_channel, in_channel, 3, 3)
+        weight = _weight_variable(weight_shape)
+    return nn.Conv2d(in_channel, out_channel,
+                     kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)
+
+
+def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
+    if use_se:
+        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
+    else:
+        weight_shape = (out_channel, in_channel, 1, 1)
+        weight = _weight_variable(weight_shape)
+    return nn.Conv2d(in_channel, out_channel,
+                     kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)
+
+
+def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
+    if use_se:
+        weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
+    else:
+        weight_shape = (out_channel, in_channel, 7, 7)
+        weight = _weight_variable(weight_shape)
+    return nn.Conv2d(in_channel, out_channel,
+                     kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
+
+
+def _bn(channel):
+    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
+                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
+
+
+def _bn_last(channel):
+    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
+                          gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
+
+
+def _fc(in_channel, out_channel, use_se=False):
+    if use_se:
+        weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel)
+        weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
+    else:
+        weight_shape = (out_channel, in_channel)
+        weight = _weight_variable(weight_shape)
+    return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
+
+
+class ResidualBlock(nn.Cell):
+    """
+    ResNet V1 residual block definition.
+
+    Args:
+        in_channel (int): Input channel.
+        out_channel (int): Output channel.
+        stride (int): Stride size for the first convolutional layer. Default: 1.
+        use_se (bool): Enable SE-ResNet50 net. Default: False.
+        se_block (bool): Use SE block in SE-ResNet50 net. Default: False.
+
+    Returns:
+        Tensor, output tensor.
+
+    Examples:
+        >>> ResidualBlock(3, 256, stride=2)
+    """
+    expansion = 4
+
+    def __init__(self,
+                 in_channel,
+                 out_channel,
+                 stride=1,
+                 use_se=False, se_block=False):
+        super(ResidualBlock, self).__init__()
+        self.stride = stride
+        self.use_se = use_se
+        self.se_block = se_block
+        channel = out_channel // self.expansion
+        self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
+        self.bn1 = _bn(channel)
+        if self.use_se and self.stride != 1:
+            self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
+                                         nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
+        else:
+            self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
+            self.bn2 = _bn(channel)
+
+        self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
+        self.bn3 = _bn_last(out_channel)
+        if self.se_block:
+            self.se_global_pool = P.ReduceMean(keep_dims=False)
+            self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se)
+            self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se)
+            self.se_sigmoid = nn.Sigmoid()
+            self.se_mul = P.Mul()
+        self.relu = nn.ReLU()
+
+        self.down_sample = False
+
+        if stride != 1 or in_channel != out_channel:
+            self.down_sample = True
+        self.down_sample_layer = None
+
+        if self.down_sample:
+            if self.use_se:
+                if stride == 1:
+                    self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
+                                                                         stride, use_se=self.use_se), _bn(out_channel)])
+                else:
+                    self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
+                                                                _conv1x1(in_channel, out_channel, 1,
+                                                                         use_se=self.use_se), _bn(out_channel)])
+            else:
+                self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
+                                                                     use_se=self.use_se), _bn(out_channel)])
+        self.add = P.Add()
+
+    def construct(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        if self.use_se and self.stride != 1:
+            out = self.e2(out)
+        else:
+            out = self.conv2(out)
+            out = self.bn2(out)
+            out = self.relu(out)
+        out = self.conv3(out)
+        out = self.bn3(out)
+        if self.se_block:
+            out_se = out
+            out = self.se_global_pool(out, (2, 3))
+            out = self.se_dense_0(out)
+            out = self.relu(out)
+            out = self.se_dense_1(out)
+            out = self.se_sigmoid(out)
+            out = F.reshape(out, F.shape(out) + (1, 1))
+            out = self.se_mul(out, out_se)
+
+        if self.down_sample:
+            identity = self.down_sample_layer(identity)
+
+        out = self.add(out, identity)
+        out = self.relu(out)
+
+        return out
+
+
+class OODResNet(OODUnderlying):
+    """
+    ResNet architecture.
+
+    Args:
+        block (Cell): Block for network.
+        layer_nums (list): Numbers of blocks in different layers.
+        in_channels (list): Input channel in each layer.
+        out_channels (list): Output channel in each layer.
+        strides (list): Stride size in each layer.
+        num_classes (int): The number of classes that the training images belong to.
+        use_se (bool): Enable SE-ResNet50 net. Default: False.
+    Returns:
+        Tensor, output tensor.
+ """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides, + num_classes, + use_se=False): + super(OODResNet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + self.use_se = use_se + self.se_block = False + if self.use_se: + self.se_block = True + + if self.use_se: + self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se) + self.bn1_0 = _bn(32) + self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se) + self.bn1_1 = _bn(32) + self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se) + else: + self.conv1 = _conv7x7(3, 64, stride=2) + self.bn1 = _bn(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0], + use_se=self.use_se) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1], + use_se=self.use_se) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2], + use_se=self.use_se, + se_block=self.se_block) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3], + use_se=self.use_se, + se_block=self.se_block) + + self.mean = P.ReduceMean(keep_dims=True) + self.flatten = nn.Flatten() + self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se) + + self._feature_count = out_channels[3] + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + se_block(bool): use se block in SE-ResNet50 net. Default: False. + Returns: + SequentialCell, the output layer. 
+ """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se) + layers.append(resnet_block) + if se_block: + for _ in range(1, layer_num - 1): + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se) + layers.append(resnet_block) + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block) + layers.append(resnet_block) + else: + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se) + layers.append(resnet_block) + return nn.SequentialCell(layers) + + @property + def feature_count(self): + return self._feature_count + + def construct_feature(self, x): + if self.use_se: + x = self.conv1_0(x) + x = self.bn1_0(x) + x = self.relu(x) + x = self.conv1_1(x) + x = self.bn1_1(x) + x = self.relu(x) + x = self.conv1_2(x) + else: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + out = self.mean(c5, (2, 3)) + out = self.flatten(out) + return out + + def construct(self, x): + out = self.construct_feature(x) + out = self.end_point(out) + return out + + +class OODResNet50(OODResNet): + """OOD underlying classifier in ResNet50 architecture.""" + def __init__(self, num_classes): + super(OODResNet50, self).__init__(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + num_classes) diff --git a/mindspore/explainer/ood/test.py b/mindspore/explainer/ood/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7f78f0be7993f4a19a98cdbde37ad5138842df --- /dev/null +++ b/mindspore/explainer/ood/test.py @@ -0,0 +1,63 @@ +import numpy as np +import mindspore as ms +import mindspore.dataset as de +from mindspore.explainer.ood.ood_net import OODNet +from mindspore.explainer.ood.ood_resnet import OODResNet50 + +from mindspore.train.callback import Callback + +num_classes = 10 +num_samples = 12800 +batch_size = 64 +image_size = 224 +channel = 3 + + +class Print_info(Callback): + def step_end(self, run_context): + cb_params = run_context.original_args() + print(f"epoch {cb_params.cur_epoch_num} step {cb_params.cur_step_num}") + + def epoch_end(self, _): + print("epoch_end") + + +def test_infer(): + + classifier = OODResNet50(num_classes) + ood_net = OODNet(classifier, num_classes) + ood_net.set_train(False) + + batch_x = ms.Tensor(np.random.random((1, 3, image_size, image_size)), dtype=ms.float32) + ood_scores = ood_net(batch_x) + print(f'ood_scores.shape {ood_scores.shape}') + + +def ds_generator(): + for i in range(num_samples): + image = np.random.random((channel, image_size, image_size)).astype(np.float32) + labels = np.random.randint(0, num_classes, 3) + one_hot = np.zeros(num_classes, dtype=np.float32) + for label in labels: + one_hot[label] = 1.0 + yield image, one_hot + + +def test_train(): + + ds = de.GeneratorDataset(source=ds_generator, num_samples=num_samples, + column_names=['data', 'label'], column_types=[ms.float32, ms.float32]) + ds = ds.batch(batch_size) + ds.dataset_size = int(num_samples / batch_size) + + classifier = OODResNet50(num_classes) + ood_net = OODNet(classifier, num_classes) + ood_net.train(ds, callbacks=Print_info(), epoch=60, multi_label=True) + + batch_x = ms.Tensor(np.random.random((1, channel, 224, 224)), dtype=ms.float32) + ood_scores = ood_net(batch_x) + print(f'ood_scores.shape {ood_scores.shape}') + + +if __name__ == "__main__": + test_train() diff --git 
index 18608556ba3c6a491439d78a528387bbc6ab08b6..cc43164c35ce333415b4005f97fc3aa16e5cd0c1 100644
--- a/mindspore/train/callback/_lr_scheduler_callback.py
+++ b/mindspore/train/callback/_lr_scheduler_callback.py
@@ -22,6 +22,7 @@
 from mindspore.common.tensor import Tensor
 from mindspore.train.callback._callback import Callback
 from mindspore.ops import functional as F
+
 class LearningRateScheduler(Callback):
     """
     Change the learning_rate during training.
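
Reviewer note: a minimal usage sketch distilled from the test.py added in this change. The module paths and API follow this patch; the class count and random NCHW input are purely illustrative.

import numpy as np
import mindspore as ms
from mindspore.explainer.ood.ood_net import OODNet
from mindspore.explainer.ood.ood_resnet import OODResNet50

num_classes = 10                            # illustrative class count
classifier = OODResNet50(num_classes)       # underlying classifier added in this patch
ood_net = OODNet(classifier, num_classes)   # wraps the classifier for OOD scoring
ood_net.set_train(False)                    # eval mode: construct() returns OOD scores

# one random image in NCHW layout, matching the ResNet input convention in test.py
batch_x = ms.Tensor(np.random.random((1, 3, 224, 224)), dtype=ms.float32)
ood_scores = ood_net(batch_x)               # Tensor of shape [batch_size, num_classes]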