From 32d1318894c35e890b795c837f9025cace9d2447 Mon Sep 17 00:00:00 2001 From: TonyNG Date: Wed, 13 Jul 2022 16:36:07 +0800 Subject: [PATCH] add limit tabdig bins arg --- mindspore_xai/tool/cli.py | 5 +++-- mindspore_xai/tool/tab/tab_sim.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mindspore_xai/tool/cli.py b/mindspore_xai/tool/cli.py index c5124c7..590a573 100644 --- a/mindspore_xai/tool/cli.py +++ b/mindspore_xai/tool/cli.py @@ -33,8 +33,9 @@ def cli_entry(): help='path of the real CSV table to be simulated.') parser_tabdig.add_argument(type=str, dest='digest_file', help='path of the digest file to be saved.') - parser_tabdig.add_argument('--bins', type=int, dest='num_bins', required=False, default=10, - help='[optional] number of bins for discretizing numeric columns, default: 10') + parser_tabdig.add_argument('--bins', type=int, dest='num_bins', required=False, choices=range(2, 33), + default=10, metavar="[2-32]", + help='[optional] number of bins (2-32) for discretizing numeric columns, default: 10') parser_tabdig.add_argument('--clip-sd', type=float, dest='clip_sd', required=False, default=3, help='[optional] number of standard deviations away from the mean that defines the ' 'outliers, outlier values will be clipped. default: 3, setting to 0 or less will ' diff --git a/mindspore_xai/tool/tab/tab_sim.py b/mindspore_xai/tool/tab/tab_sim.py index 0ce41a0..0e4db4e 100644 --- a/mindspore_xai/tool/tab/tab_sim.py +++ b/mindspore_xai/tool/tab/tab_sim.py @@ -27,7 +27,7 @@ _EPS = 1e-9 # max. no. of bins in a column group # group no. of bins = product of all member column no. of bins -_COL_GRP_MAX_BIN = 1000 +_COL_GRP_MAX_BIN = 1024 # min. information quality ratio in a column grouping _COL_GRP_MIN_IQR = 0.1 @@ -329,7 +329,7 @@ class CsvTabDigest(TabDigest): allowed. Args: - num_bins (int): Number of bins for numeric columns. Default: 10. + num_bins (int): Number of bins for numeric columns, must be in range of :math:`[2, 32]`. Default: 10. clip_sd (int, float): Number of standard deviations for clipping numeric column values. Disable the clipping by providing zero. Default: 3. @@ -340,6 +340,8 @@ class CsvTabDigest(TabDigest): def __init__(self, num_bins=10, clip_sd=3): if not isinstance(num_bins, int): raise TypeError(f'Argument "num_bins" must be in type of int.') + if num_bins < 2 or num_bins > 32: + raise ValueError(f'Argument "num_bins" must be in range of [2, 32], but got {num_bins}.') if not isinstance(clip_sd, (int, float)): raise TypeError(f'Argument "clip_sd" must be in type of int or float.') super().__init__() -- Gitee