From f6f80fae6431196adb38b5f5ace303e4c533fab2 Mon Sep 17 00:00:00 2001 From: luopengting Date: Sat, 23 May 2020 14:39:14 +0800 Subject: [PATCH] fix calculation of lineage dataset_num, get model name --- .../collection/model/model_lineage.py | 57 +++++++++---------- mindinsight/lineagemgr/querier/query_model.py | 1 - .../st/func/lineagemgr/api/test_model_api.py | 10 ++-- .../collection/model/test_model_lineage.py | 6 +- .../mindspore/dataset/engine/datasets.py | 4 ++ 5 files changed, 41 insertions(+), 37 deletions(-) diff --git a/mindinsight/lineagemgr/collection/model/model_lineage.py b/mindinsight/lineagemgr/collection/model/model_lineage.py index eeb5b89c..966a59b7 100644 --- a/mindinsight/lineagemgr/collection/model/model_lineage.py +++ b/mindinsight/lineagemgr/collection/model/model_lineage.py @@ -37,7 +37,7 @@ from mindinsight.lineagemgr.collection.model.base import Metadata try: from mindspore.common.tensor import Tensor from mindspore.train.callback import Callback, RunContext, ModelCheckpoint, SummaryStep - from mindspore.nn import Cell, Optimizer, WithLossCell, TrainOneStepWithLossScaleCell + from mindspore.nn import Cell, Optimizer from mindspore.nn.loss.loss import _Loss from mindspore.dataset.engine import Dataset, MindDataset import mindspore.dataset as ds @@ -412,27 +412,28 @@ class AnalyzeObject: Returns: str, the name of the backbone network. 
""" - with_loss_cell = False - backbone = None + backbone_name = None + has_network = False + network_key = 'network' + backbone_key = '_backbone' + net_args = vars(network) if network else {} net_cell = net_args.get('_cells') if net_args else {} - for _, value in net_cell.items(): - if isinstance(value, WithLossCell): - backbone = getattr(value, '_backbone') - with_loss_cell = True + for key, value in net_cell.items(): + if key == network_key: + network = value + has_network = True break - if with_loss_cell: - backbone_name = type(backbone).__name__ \ - if backbone else None - elif isinstance(network, TrainOneStepWithLossScaleCell): - backbone = getattr(network, 'network') - backbone_name = type(backbone).__name__ \ - if backbone else None - else: - backbone_name = type(network).__name__ \ - if network else None + if has_network: + while hasattr(network, network_key): + network = getattr(network, network_key) + if hasattr(network, backbone_key): + backbone = getattr(network, backbone_key) + backbone_name = type(backbone).__name__ + elif network is not None: + backbone_name = type(network).__name__ return backbone_name @staticmethod @@ -489,26 +490,24 @@ class AnalyzeObject: Returns: dict, the lineage metadata. 
""" - dataset_batch_size = dataset.get_dataset_size() - if dataset_batch_size is not None: - validate_int_params(dataset_batch_size, 'dataset_batch_size') - log.debug('dataset_batch_size: %d', dataset_batch_size) + batch_num = dataset.get_dataset_size() + batch_size = dataset.get_batch_size() + if batch_num is not None: + validate_int_params(batch_num, 'dataset_batch_num') + validate_int_params(batch_size, 'dataset_batch_size') + log.debug('dataset_batch_num: %d', batch_num) + log.debug('dataset_batch_size: %d', batch_size) dataset_path = AnalyzeObject.get_dataset_path_wrapped(dataset) if dataset_path: dataset_path = '/'.join(dataset_path.split('/')[:-1]) - step_num = lineage_dict.get('step_num') - validate_int_params(step_num, 'step_num') - log.debug('step_num: %d', step_num) - + dataset_size = int(batch_num * batch_size) if dataset_type == 'train': lineage_dict[Metadata.train_dataset_path] = dataset_path - epoch = lineage_dict.get('epoch') - train_dataset_size = dataset_batch_size * (step_num / epoch) - lineage_dict[Metadata.train_dataset_size] = int(train_dataset_size) + lineage_dict[Metadata.train_dataset_size] = dataset_size elif dataset_type == 'valid': lineage_dict[Metadata.valid_dataset_path] = dataset_path - lineage_dict[Metadata.valid_dataset_size] = dataset_batch_size * step_num + lineage_dict[Metadata.valid_dataset_size] = dataset_size return lineage_dict diff --git a/mindinsight/lineagemgr/querier/query_model.py b/mindinsight/lineagemgr/querier/query_model.py index 74051c99..007c2d4c 100644 --- a/mindinsight/lineagemgr/querier/query_model.py +++ b/mindinsight/lineagemgr/querier/query_model.py @@ -82,7 +82,6 @@ class LineageObj: self._lineage_info = { self._name_summary_dir: summary_dir } - self._filtration_result = None self._init_lineage() self.parse_and_update_lineage(**kwargs) diff --git a/tests/st/func/lineagemgr/api/test_model_api.py b/tests/st/func/lineagemgr/api/test_model_api.py index 2357c440..c8aea9db 100644 --- 
a/tests/st/func/lineagemgr/api/test_model_api.py +++ b/tests/st/func/lineagemgr/api/test_model_api.py @@ -50,10 +50,10 @@ LINEAGE_INFO_RUN1 = { 'network': 'ResNet' }, 'train_dataset': { - 'train_dataset_size': 731 + 'train_dataset_size': 1024 }, 'valid_dataset': { - 'valid_dataset_size': 10240 + 'valid_dataset_size': 1024 }, 'model': { 'path': '{"ckpt": "' @@ -89,9 +89,9 @@ LINEAGE_FILTRATION_RUN1 = { 'model_lineage': { 'loss_function': 'SoftmaxCrossEntropyWithLogits', 'train_dataset_path': None, - 'train_dataset_count': 731, + 'train_dataset_count': 1024, 'test_dataset_path': None, - 'test_dataset_count': 10240, + 'test_dataset_count': 1024, 'user_defined': {}, 'network': 'ResNet', 'optimizer': 'Momentum', @@ -115,7 +115,7 @@ LINEAGE_FILTRATION_RUN2 = { 'train_dataset_path': None, 'train_dataset_count': 1024, 'test_dataset_path': None, - 'test_dataset_count': 10240, + 'test_dataset_count': 1024, 'user_defined': {}, 'network': "ResNet", 'optimizer': "Momentum", diff --git a/tests/ut/lineagemgr/collection/model/test_model_lineage.py b/tests/ut/lineagemgr/collection/model/test_model_lineage.py index e6857f72..81a61ea6 100644 --- a/tests/ut/lineagemgr/collection/model/test_model_lineage.py +++ b/tests/ut/lineagemgr/collection/model/test_model_lineage.py @@ -334,12 +334,14 @@ class TestAnalyzer(TestCase): ) res1 = self.analyzer.analyze_dataset(dataset, {'step_num': 10, 'epoch': 2}, 'train') res2 = self.analyzer.analyze_dataset(dataset, {'step_num': 5}, 'valid') + + # batch_size is mocked as 32. 
assert res1 == {'step_num': 10, 'train_dataset_path': '/path/to', - 'train_dataset_size': 50, + 'train_dataset_size': 320, 'epoch': 2} assert res2 == {'step_num': 5, 'valid_dataset_path': '/path/to', - 'valid_dataset_size': 50} + 'valid_dataset_size': 320} def test_get_dataset_path_dataset(self): """Test get_dataset_path method with Dataset.""" diff --git a/tests/utils/mindspore/dataset/engine/datasets.py b/tests/utils/mindspore/dataset/engine/datasets.py index a9decbce..8fc28967 100644 --- a/tests/utils/mindspore/dataset/engine/datasets.py +++ b/tests/utils/mindspore/dataset/engine/datasets.py @@ -27,6 +27,10 @@ class Dataset: """Mocked get_dataset_size.""" return self.dataset_size + def get_batch_size(self): + """Mocked get_batch_size.""" + return 32 + class MindDataset(Dataset): """Mock the MindSpore MindDataset class.""" -- Gitee