diff --git a/mindspore/explainer/ood/ood_net.py b/mindspore/explainer/ood/ood_net.py
index d927eb9ee57456bb9918c0fb50be7267ef33a40c..63d0875ce2c3c08335429ae6949fd75e96bdf0a2 100644
--- a/mindspore/explainer/ood/ood_net.py
+++ b/mindspore/explainer/ood/ood_net.py
@@ -15,7 +15,6 @@
 """Out Of Distribution Network."""
 
 import mindspore as ms
-from mindspore import context
 from mindspore import nn
 from mindspore import ops
 from mindspore.common.initializer import HeNormal
@@ -38,7 +37,7 @@ class OODUnderlying(nn.Cell):
 
     def construct_feature(self, x):
         """
-        Forward inference features.
+        Forward inferences the features.
 
         Returns:
             Tensor, feature tensor in the shape of [batch_size, feature_count]
@@ -53,6 +52,10 @@ class OODNet(nn.Cell):
     Args:
         underlying (OODUnderlying, optional): The underlying classifier. None means using OODResNet50 as underlying.
        num_classes (int): Number of classes for the classifier.
+
+    Returns:
+        Tensor, classification logits (if set_train(True) was called) or
+        OOD scores (if set_train(False) was called), in the shape of [batch_size, num_classes].
     """
 
     def __init__(self, underlying, num_classes):
@@ -68,18 +71,14 @@ class OODNet(nn.Cell):
                            out_channels=num_classes,
                            has_bias=False,
                            weight_init=HeNormal(nonlinearity='relu'))
-        if context.get_context('device_target') == 'GPU':
-            # BatchNorm1d is not working on GPU
-            self._g = nn.SequentialCell(
-                nn.Dense(in_channels=self._underlying.feature_count, out_channels=1),
-                nn.Sigmoid()
-            )
-        else:
-            self._g = nn.SequentialCell(
-                nn.Dense(in_channels=self._underlying.feature_count, out_channels=1),
-                nn.BatchNorm1d(num_features=1),
-                nn.Sigmoid()
-            )
+
+        # BatchNorm1d is not working on GPU, workaround with BatchNorm2d
+        self._expand_dims = ops.ExpandDims()
+        self._g_fc = nn.Dense(in_channels=self._underlying.feature_count, out_channels=1)
+        self._g_bn2d = nn.BatchNorm2d(num_features=1)
+        self._g_squeeze = ops.Squeeze(axis=2)
+        self._g_sigmoid = nn.Sigmoid()
+
         self._matmul_weight = nn.MatMul(transpose_x1=False, transpose_x2=True)
         self._norm = nn.Norm(axis=(1,))
         self._transpose = ops.Transpose()
@@ -109,26 +108,18 @@ class OODNet(nn.Cell):
         feat = self._underlying.construct_feature(x)
         scores = self._ood_scores(feat)
         if self._is_train:
-            denorm = self._g(feat)
+            feat = self._g_fc(feat)
+            feat = self._expand_dims(feat, 2)
+            feat = self._expand_dims(feat, 2)
+            feat = self._g_bn2d(feat)
+            feat = self._g_squeeze(feat)
+            feat = self._g_squeeze(feat)
+            denorm = self._g_sigmoid(feat)
             logits = scores / denorm
             return logits
-        return scores
 
-    def _ood_scores(self, feat):
-        """Forward inferences the OOD scores."""
-        norm_f = self._normalize(feat)
-        norm_w = self._normalize(self._h.weight)
-        scores = self._matmul_weight(norm_f, norm_w)
         return scores
 
-    def _normalize(self, x):
-        """Normalizes an tensor."""
-        norm = self._norm(x)
-        tiled_norm = self._tile((norm + 1e-4), (self._feature_count, 1))
-        tiled_norm = self._transpose(tiled_norm, (1, 0))
-        x = x / tiled_norm
-        return x
-
     def prepare_train(self,
                       multi_label,
                       learning_rate=0.1,
@@ -140,7 +131,7 @@ class OODNet(nn.Cell):
         Creates necessities for training.
 
         Args:
-            multi_label (bool): Samples are labeled into multiple classes.
+            multi_label (bool): Samples are labeled with multiple classes.
             learning_rate (float): The optimizer learning rate.
             momentum (float): The optimizer momentum.
             weight_decay (float): The optimizer weight decay.
@@ -153,7 +144,7 @@ class OODNet(nn.Cell):
         if self._train_partial:
             parameters = []
             parameters.extend(self._h.get_parameters())
-            parameters.extend(self._g.get_parameters())
+            parameters.extend(self._g_fc.get_parameters())
         else:
             parameters = list(self.get_parameters())
         scheduler = _EpochLrScheduler(learning_rate, lr_base_factor, lr_epoch_denom)
@@ -194,6 +185,21 @@ class OODNet(nn.Cell):
         model.train(epoch, dataset, callbacks=callbacks)
         self.set_train(False)
 
+    def _ood_scores(self, feat):
+        """Forward inferences the OOD scores."""
+        norm_f = self._normalize(feat)
+        norm_w = self._normalize(self._h.weight)
+        scores = self._matmul_weight(norm_f, norm_w)
+        return scores
+
+    def _normalize(self, x):
+        """Normalizes a tensor."""
+        norm = self._norm(x)
+        tiled_norm = self._tile((norm + 1e-4), (self._feature_count, 1))
+        tiled_norm = self._transpose(tiled_norm, (1, 0))
+        x = x / tiled_norm
+        return x
+
 
 class _EpochLrScheduler(LearningRateScheduler):
     """
diff --git a/mindspore/explainer/ood/ood_resnet.py b/mindspore/explainer/ood/ood_resnet.py
index 65b7747e56df90f5f8bce2a1260c1c9914ea3c30..44c03f0891954543d3ac6dd944b6012cd59809d2 100644
--- a/mindspore/explainer/ood/ood_resnet.py
+++ b/mindspore/explainer/ood/ood_resnet.py
@@ -203,6 +203,7 @@ class OODResNet(OODUnderlying):
         strides (list): Stride size in each layer.
         num_classes (int): The number of classes that the training images are belonging to.
         use_se (bool): enable SE-ResNet50 net. Default: False.
+
     Returns:
         Tensor, output tensor.
     """
diff --git a/mindspore/explainer/ood/test.py b/mindspore/explainer/ood/test.py
index 0a7f78f0be7993f4a19a98cdbde37ad5138842df..f656a0c1fd60a1c88a32345eb5987415828f0d8d 100644
--- a/mindspore/explainer/ood/test.py
+++ b/mindspore/explainer/ood/test.py
@@ -7,7 +7,7 @@ from mindspore.explainer.ood.ood_resnet import OODResNet50
 from mindspore.train.callback import Callback
 
 num_classes = 10
-num_samples = 12800
+num_samples = 128
 batch_size = 64
 image_size = 224
 channel = 3
@@ -30,7 +30,7 @@ def test_infer():
 
     batch_x = ms.Tensor(np.random.random((1, 3, image_size, image_size)), dtype=ms.float32)
     ood_scores = ood_net(batch_x)
-    print(f'ood_scores.shape {ood_scores.shape}')
+    print(f'ood_scores: {ood_scores}')
 
 
 def ds_generator():
diff --git a/mindspore/train/callback/_lr_scheduler_callback.py b/mindspore/train/callback/_lr_scheduler_callback.py
index cc43164c35ce333415b4005f97fc3aa16e5cd0c1..18608556ba3c6a491439d78a528387bbc6ab08b6 100644
--- a/mindspore/train/callback/_lr_scheduler_callback.py
+++ b/mindspore/train/callback/_lr_scheduler_callback.py
@@ -22,7 +22,6 @@
 from mindspore.common.tensor import Tensor
 from mindspore.train.callback._callback import Callback
 from mindspore.ops import functional as F
-
 class LearningRateScheduler(Callback):
     """
     Change the learning_rate during training.
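
Reviewer note: below is a minimal standalone sketch of the BatchNorm1d-on-GPU workaround this patch wires into OODNet, using the same ExpandDims/BatchNorm2d/Squeeze chain. The cell name BatchNorm1dWorkaround and the test shapes are illustrative, not part of the patch.

    import numpy as np
    import mindspore as ms
    from mindspore import nn, ops


    class BatchNorm1dWorkaround(nn.Cell):
        """Emulates nn.BatchNorm1d on a [N, C] input via nn.BatchNorm2d."""

        def __init__(self, num_features):
            super().__init__()
            self._expand_dims = ops.ExpandDims()
            # BatchNorm2d normalizes per channel, which matches BatchNorm1d's
            # per-feature normalization once the input is reshaped to [N, C, 1, 1].
            self._bn2d = nn.BatchNorm2d(num_features=num_features)
            self._squeeze = ops.Squeeze(axis=2)

        def construct(self, x):
            x = self._expand_dims(x, 2)  # [N, C] -> [N, C, 1]
            x = self._expand_dims(x, 2)  # [N, C, 1] -> [N, C, 1, 1]
            x = self._bn2d(x)
            x = self._squeeze(x)         # [N, C, 1, 1] -> [N, C, 1]
            x = self._squeeze(x)         # [N, C, 1] -> [N, C]
            return x


    bn = BatchNorm1dWorkaround(num_features=1)
    bn.set_train(True)  # BatchNorm only updates batch statistics in training mode
    batch = ms.Tensor(np.random.random((64, 1)), dtype=ms.float32)
    print(bn(batch).shape)  # expected: (64, 1)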
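On the relocated _ood_scores/_normalize pair: the math is cosine similarity between L2-normalized feature rows and L2-normalized rows of the classifier weight (nn.Dense stores its weight as [num_classes, feature_count]). A NumPy sketch of the same computation, assuming a hypothetical helper name cosine_ood_scores; the 1e-4 stabilizer matches the patch:

    import numpy as np


    def cosine_ood_scores(feat, weight, eps=1e-4):
        """Cosine similarity between feature rows and class-weight rows.

        feat:   [batch_size, feature_count]
        weight: [num_classes, feature_count]
        Returns a [batch_size, num_classes] score matrix.
        """
        # Row-wise L2 normalization; eps guards against zero norms,
        # mirroring the (norm + 1e-4) term in _normalize.
        norm_f = feat / (np.linalg.norm(feat, axis=1, keepdims=True) + eps)
        norm_w = weight / (np.linalg.norm(weight, axis=1, keepdims=True) + eps)
        # Same contraction as nn.MatMul(transpose_x1=False, transpose_x2=True)
        return norm_f @ norm_w.T


    scores = cosine_ood_scores(np.random.rand(64, 2048), np.random.rand(10, 2048))
    print(scores.shape)  # (64, 10)

In training mode the patch divides these scores by the sigmoid output of the g branch (denorm) to form the classification logits; in eval mode the raw cosine scores are returned as the OOD scores.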