From 2171bc5e978a5d072ffffafbfa65f37152eeacb4 Mon Sep 17 00:00:00 2001 From: ttaohe828 Date: Sun, 29 Sep 2024 22:38:20 +0800 Subject: [PATCH] add yolo-world mindspore --- community/cv/yolo-world/README_CN.md | 90 ++++ community/cv/yolo-world/requirements.txt | 17 + community/cv/yolo-world/yolow/__init__.py | 0 community/cv/yolo-world/yolow/logger.py | 331 ++++++++++++ .../cv/yolo-world/yolow/model/__init__.py | 13 + .../cv/yolo-world/yolow/model/clip_text.py | 90 ++++ .../yolow/model/data_preprocessor.py | 129 +++++ .../cv/yolo-world/yolow/model/factory.py | 337 ++++++++++++ .../yolo-world/yolow/model/layers/__init__.py | 16 + .../cv/yolo-world/yolow/model/layers/attn.py | 94 ++++ .../yolow/model/layers/bottleneck.py | 109 ++++ .../cv/yolo-world/yolow/model/layers/conv.py | 94 ++++ .../yolow/model/layers/csp_layer.py | 104 ++++ community/cv/yolo-world/yolow/model/misc.py | 332 ++++++++++++ .../yolow/model/model_cfgs/yoloworld_l.json | 37 ++ .../yolow/model/model_cfgs/yoloworld_m.json | 37 ++ .../yolow/model/model_cfgs/yoloworld_n.json | 37 ++ .../yolow/model/model_cfgs/yoloworld_s.json | 37 ++ .../yolow/model/model_cfgs/yoloworld_x.json | 37 ++ .../yolow/model/model_cfgs/yoloworld_xl.json | 37 ++ .../yolow/model/task_utils/__init__.py | 11 + .../task_utils/distance_point_bbox_coder.py | 108 ++++ .../yolow/model/task_utils/point_generator.py | 229 ++++++++ .../cv/yolo-world/yolow/model/yolo_base.py | 147 +++++ .../cv/yolo-world/yolow/model/yolo_world.py | 144 +++++ .../yolow/model/yolo_world_backbone.py | 67 +++ .../yolo-world/yolow/model/yolo_world_head.py | 502 ++++++++++++++++++ .../yolow/model/yolo_world_pafpn.py | 187 +++++++ 28 files changed, 3373 insertions(+) create mode 100644 community/cv/yolo-world/README_CN.md create mode 100644 community/cv/yolo-world/requirements.txt create mode 100644 community/cv/yolo-world/yolow/__init__.py create mode 100644 community/cv/yolo-world/yolow/logger.py create mode 100644 community/cv/yolo-world/yolow/model/__init__.py create mode 100644 community/cv/yolo-world/yolow/model/clip_text.py create mode 100644 community/cv/yolo-world/yolow/model/data_preprocessor.py create mode 100644 community/cv/yolo-world/yolow/model/factory.py create mode 100644 community/cv/yolo-world/yolow/model/layers/__init__.py create mode 100644 community/cv/yolo-world/yolow/model/layers/attn.py create mode 100644 community/cv/yolo-world/yolow/model/layers/bottleneck.py create mode 100644 community/cv/yolo-world/yolow/model/layers/conv.py create mode 100644 community/cv/yolo-world/yolow/model/layers/csp_layer.py create mode 100644 community/cv/yolo-world/yolow/model/misc.py create mode 100644 community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_l.json create mode 100644 community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_m.json create mode 100644 community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_n.json create mode 100644 community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_s.json create mode 100644 community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_x.json create mode 100644 community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_xl.json create mode 100644 community/cv/yolo-world/yolow/model/task_utils/__init__.py create mode 100644 community/cv/yolo-world/yolow/model/task_utils/distance_point_bbox_coder.py create mode 100644 community/cv/yolo-world/yolow/model/task_utils/point_generator.py create mode 100644 community/cv/yolo-world/yolow/model/yolo_base.py create mode 100644 community/cv/yolo-world/yolow/model/yolo_world.py create mode 100644 community/cv/yolo-world/yolow/model/yolo_world_backbone.py create mode 100644 community/cv/yolo-world/yolow/model/yolo_world_head.py create mode 100644 community/cv/yolo-world/yolow/model/yolo_world_pafpn.py diff --git a/community/cv/yolo-world/README_CN.md b/community/cv/yolo-world/README_CN.md new file mode 100644 index 000000000..ee7f82a8d --- /dev/null +++ b/community/cv/yolo-world/README_CN.md @@ -0,0 +1,90 @@ + +# 目录 + +- [目录](#目录) +- [项目说明](#项目说明) + - [项目简介](#项目简介) + - [项目意义](#项目意义) +- [快速入门](#快速入门) +- [文件说明](#文件说明) + - [脚本及代码](#脚本及代码) +- [测试和推理过程](#测试和推理过程) +- [性能](#性能) + +# [项目说明](#目录) + +## [项目简介](#目录) + +**本项目已在Ai Gallery上架notebook脚本,可以查阅[代码](https://developer.huaweicloud.com/develop/aigallery/notebook/detail?id=9269fd3b-2235-403f-a764-87a0e27cd654)来进行体验。** +作为最基本的场景理解任务,目标检测和图像分割在深度学习时代取得了巨大的进步。但由于昂贵的人工标记成本,现有数据集中的标注类别通常是小规模和预定义的,即使最先进的完全监督的检测模型和分割模型无法超越封闭的类别词汇表进行推广。例如Pascal VOC中的20个类,COCO中的80个类,甚至是类别最广泛的LVIS数据集也只提供了1203个类,现有的所有基于CNN或者Transformer的模型都能在局限的语义类别(标签)取得稳定的检测性能。相反人类感知系统可以将任意的视觉概念与开放式类名或自然语言描述联系起来,封闭语义集极大限制了目标检测和图像分割两大高级视觉任务在现实中的应用。 + +## [项目意义](#目录) + +YOLO-World基于先进的目标检测实时方案YOLOv8,根据描述性文本检测图像中的任何物体。相较于以往的开放词汇目标检测模型,不需要大量计算资源的Transformer模型,而利用CNN的计算速度,采用视觉语言建模和在大量数据集上进行预训练的方法,可提供快速的开放词汇目标检测,满足各行业对即时结果的需求。 + +# [快速入门](#目录) + +可通过以下命令进行安装 + +```bash +conda create -n yolo-world-v2 python=3.9 +conda activate yolo-world-v2 +pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.2.14/MindSpore/unified/x86_64/mindspore-2.2.14-cp39-cp39-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple +cd community/cv/yolo-world +pip install -e . +``` + +# [文件说明](#目录) + +## [脚本及代码](#目录) + +```text +├── cv + ├── yolo-world + ├── yolow + ├── model + ├── layers + ├── model_cfgs + ├── task_utils + ├──logger.py + ├── README_CN.md + ├── requirements.txt +``` + +# [测试和推理过程](#目录) + +在val和minival数据集上测试。 +Test one model in val or minival set. + +```bash +bash scripts/test_single.sh val +bash scripts/test_single.sh minival +``` + +推理脚本。 + +```bash +python scripts/inference.py +``` + +# [性能](#目录) + +**我们比较了在LVIS-minival和LVIS-val数据集上测试的Mindspore模型与Torch模型的结果,并提供了以下结果,所有百度云(代码:ospp)链接均可下载。** + +## Zero-shot Inference on LVIS dataset + +We test our mindspore model on LVIS-minival and LVIS-val datasets compared with torch model, and we provide the results below, all BaiduYun(code: ospp) links are available for download. +
+ +|model|Size|APmini| |APr| |APc| |APf| |APval| |APr| |APc| |APf| | +|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----|:----| +| | |torch|mindspore|torch|mindspore|torch|mindspore|torch|mindspore|torch|mindspore|torch|mindspore|torch|mindspore|torch|mindspore| +|[YOLO-Worldv2-S](https://pan.baidu.com/s/1YaZN1H_zwkOUuPM1drXZRQ?pwd=ospp)|640|22.7|22.6|16.3|17.3|20.8|20.8|25.5|25.2|17.3|17.2|11.3|11.1|14.9|14.8|22.7|22.6| +|[YOLO-Worldv2-S](https://pan.baidu.com/s/1v2lzWHffBX9el7n3lcfhIA?pwd=ospp)|1280|24.1|23.8|18.7|17|22|21.9|26.9|26.6|18.8|18.7|14.1|14|16.3|16.1|23.8|23.6| +|[YOLO-Worldv2-M](https://pan.baidu.com/s/1Gv3hk8sk-Ipz74nE83D2GA?pwd=ospp)|640|30|29.9|25|25.5|27.2|26.8|33.4|33.4|23.5|23.5|17.1|17.4|20|20.2|30.1|30| +|[YOLO-Worldv2-M](https://pan.baidu.com/s/1X5ZkvzcfCUTXnPMVinA2dA?pwd=ospp)|1280|31.6|31.3|24.5|24.7|29|28.8|35.1|34.7|25.3|25.1|19.3|19.1|22|21.7|31.7|31.6| +|[YOLO-Worldv2-L](https://pan.baidu.com/s/15oMgAKsl48wUznJctBN6sw?pwd=ospp)|640|33|33.5|22.6|24.2|32|32.8|35.8|35.7|26|26|18.6|19|23|23|32.6|32.5| +|[YOLO-Worldv2-L](https://pan.baidu.com/s/1VQ6w1Q9z6MnJsY2J1aRtqA?pwd=ospp)|1280|34.6|34.7|29.2|29.3|32.8|33|37.2|37.1|27.6|27.5|21.9|22|24.2|24|34|33.8| +|[YOLO-Worldv2-X](https://pan.baidu.com/s/16eB9EAH1AHePD5d3n8OOlA?pwd=ospp)|640|35.4|34.9|28.7|27.6|32.9|32.4|38.7|38.5|28.4|28.3|20.6|21.2|25.6|25.3|35|34.8| +|[YOLO-Worldv2-X](https://pan.baidu.com/s/1wTo_6SGj51L0wrlwnYn9Ag?pwd=ospp)|1280|37.4|37.6|30.5|28.2|35.2|36.4|40.7|40.3|29.8|29.6|21.1|21.5|26.8|26.5|37|36.7| +|[YOLO-Worldv2-XL](https://pan.baidu.com/s/1Y3q_MqMXPlGp2R3brODOSg?pwd=ospp)|640|36|35.6|25.8|24.8|34.1|34.2|39.5|38.7|29.1|29|21.1|21.8|26.3|26.3|35.8|35.2| diff --git a/community/cv/yolo-world/requirements.txt b/community/cv/yolo-world/requirements.txt new file mode 100644 index 000000000..0e47a6d2e --- /dev/null +++ b/community/cv/yolo-world/requirements.txt @@ -0,0 +1,17 @@ +addict +albumentations +iopath +kiwisolver +opencv-python==4.8.1.78 +pycocotools +supervision +tabulate +yapf +numpy==1.24.4 +lvis +mindnlp==0.3.1 +troubleshooter==1.0.18 +kiwisolver==1.4.5 +matplotlib==3.7.5 +supervision==0.19.0 +pandas==2.0.3 \ No newline at end of file diff --git a/community/cv/yolo-world/yolow/__init__.py b/community/cv/yolo-world/yolow/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/community/cv/yolo-world/yolow/logger.py b/community/cv/yolo-world/yolow/logger.py new file mode 100644 index 000000000..650315fd4 --- /dev/null +++ b/community/cv/yolo-world/yolow/logger.py @@ -0,0 +1,331 @@ +""" +Filename: logger.py +Author: hetao +""" +import atexit +import datetime +import functools +import logging +import os +import sys +import time +from collections import Counter +from typing import Optional +from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler +from iopath.common.file_io import PathManager as PathManagerBase +from tabulate import tabulate +from termcolor import colored + + +__all__ = ['setup_logger', 'log_first_n', 'log_every_n', 'log_every_n_seconds'] + +PathManager = PathManagerBase() +PathManager.register_handler(HTTPURLHandler()) +PathManager.register_handler(OneDrivePathHandler()) + +class _ColorfulFormatter(logging.Formatter): + """ + class: PathManagerBase + """ + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop('root_name') + '.' + self._abbrev_name = kwargs.pop('abbrev_name', '') + if self._abbrev_name > 0: + self._abbrev_name = self._abbrev_name + '.' + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored('WARNING', 'red', attrs=['blink']) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored('ERROR', 'red', attrs=['blink', 'underline']) + else: + return log + return prefix + ' ' + log + + +@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers +def setup_logger(output=None, distributed_rank=0, *, color=True, name='yolow', abbrev_name=None): + """ + Initialize the yolow logger and set its verbosity level to "DEBUG". + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name (str): the root module name of this logger + abbrev_name (str): an abbreviation of the module, to avoid long names in logs. + Set to "" to not log the root module in logs. + By default, will abbreviate "yolow" to "f" and leave other + modules unchanged. + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = False + + if abbrev_name is None: + abbrev_name = 'yw' if name == 'yolow' else name + + plain_formatter = logging.Formatter('[%(asctime)s] %(name)s %(levelname)s: %(message)s', datefmt='%m/%d %H:%M:%S') + # stdout logging: master only + if distributed_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + if color: + formatter = _ColorfulFormatter( + colored('[%(asctime)s %(name)s]: ', 'green') + '%(message)s', + datefmt='%m/%d %H:%M:%S', + root_name=name, + abbrev_name=str(abbrev_name), + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + # file logging: all workers + if output is not None: + if output.endswith('.txt') or output.endswith('.log'): + filename = output + else: + filename = os.path.join(output, 'log.txt') + if distributed_rank > 0: + filename = filename + '.rank{}'.format(distributed_rank) + PathManager.mkdirs(os.path.dirname(filename)) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + return logger + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + # use 1K buffer if writing to cloud storage + io = PathManager.open(filename, 'a', buffering=1024 if '://' in filename else -1) + atexit.register(io.close) + return io + + +def _find_caller(): + """ + Returns: + str: module name of the caller + tuple: a hashable key to be used to identify different callers + """ + frame = sys._getframe(2) + while frame: + code = frame.f_code + if os.path.join('utils', 'logger.') not in code.co_filename: + mod_name = frame.f_globals['__name__'] + if mod_name == '__main__': + mod_name = 'yolow' + return mod_name, (code.co_filename, frame.f_lineno, code.co_name) + frame = frame.f_back + + +_LOG_COUNTER = Counter() +_LOG_TIMER = {} + + +def log_first_n(lvl, msg, n=1, *, name=None, key='caller'): + """ + Log only for the first n times. + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by default. + key (str or tuple[str]): the string(s) can be one of "caller" or + "message", which defines how to identify duplicated logs. + For example, if called with `n=1, key="caller"`, this function + will only log the first call from the same caller, regardless of + the message content. + If called with `n=1, key="message"`, this function will log the + same content only once, even if they are called from different places. + If called with `n=1, key=("caller", "message")`, this function + will not log only if the same caller has logged the same message before. + """ + if isinstance(key, str): + key = (key,) + assert key + + caller_module, caller_key = _find_caller() + hash_key = () + if 'caller' in key: + hash_key = hash_key + caller_key + if 'message' in key: + hash_key = hash_key + (msg,) + + _LOG_COUNTER[hash_key] += 1 + if _LOG_COUNTER[hash_key] <= n: + logging.getLogger(name or caller_module).log(lvl, msg) + + +def log_every_n(lvl, msg, n=1, *, name=None): + """ + Log once per n times. + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by default. + """ + caller_module, key = _find_caller() + _LOG_COUNTER[key] += 1 + if n == 1 or _LOG_COUNTER[key] % n == 1: + logging.getLogger(name or caller_module).log(lvl, msg) + + +def log_every_n_seconds(lvl, msg, n=1, *, name=None): + """ + Log no more than once per n seconds. + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by default. + """ + caller_module, key = _find_caller() + last_logged = _LOG_TIMER.get(key, None) + current_time = time.time() + if last_logged is None or current_time - last_logged >= n: + logging.getLogger(name or caller_module).log(lvl, msg) + _LOG_TIMER[key] = current_time + + +def create_small_table(small_dict): + """ + Create a small table using the keys of small_dict as headers. This is only + suitable for small dictionaries. + Args: + small_dict (dict): a result dictionary of only a few items. + Returns: + str: the table as a string. + """ + keys, values = tuple(zip(*small_dict.items())) + table = tabulate( + [values], + headers=keys, + tablefmt='pipe', + floatfmt='.3f', + stralign='center', + numalign='center', + ) + return table + + + +_CURRENT_STORAGE_STACK = [] + + +def get_event_storage(): + """ + Returns: + The :class:`EventStorage` object that's currently being used. + Throws an error if no :class:`EventStorage` is currently enabled. + """ + assert _CURRENT_STORAGE_STACK, "get_event_storage() has to be called inside a 'with EventStorage(...)' context!" + return _CURRENT_STORAGE_STACK[-1] + + +class CommonMetricPrinter: + """ + Print **common** metrics to the terminal, including + iteration time, ETA, memory, all losses, and the learning rate. + It also applies smoothing using a window of 20 elements. + + It's meant to print common metrics in common ways. + To print something in more customized ways, please implement a similar printer by yourself. + """ + + def __init__(self, iters_per_epoch: int, max_iter: Optional[int] = None): + """ + Args: + max_iter: the maximum number of iterations to train. + Used to compute ETA. If not given, ETA will not be printed. + """ + # self.logger = Logger(logfile="", level=logging.DEBUG) #logging.getLogger(__name__) + self.logger = logging.getLogger(__name__) + self.logger.setLevel(logging.DEBUG) + self._max_iter = max_iter + self._iters_per_epoch = iters_per_epoch + self._last_write = None # (step, time) of last call to write(). Used to compute ETA + def _get_eta(self, storage) -> Optional[str]: + """ + Filename: __init__.py + Author: hetao + """ + if self._max_iter is None: + return '' + iteration = storage.iter + try: + eta_seconds = storage.history('time').median(1000) * (self._max_iter - iteration - 1) + storage.put_scalar('eta_seconds', eta_seconds, smoothing_hint=False) + return str(datetime.timedelta(seconds=int(eta_seconds))) + except KeyError: + # estimate eta on our own - more noisy + eta_string = None + if self._last_write is not None: + estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (iteration - self._last_write[0]) + eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + self._last_write = (iteration, time.perf_counter()) + return eta_string + + def write(self): + """ + log write. + """ + # storage = get_event_storage() + # iteration = storage.iter + # epoch = (iteration + 0) // self._iters_per_epoch + # if iteration == self._max_iter: + # # This hook only reports training progress (loss, ETA, etc) but not other data, + # # therefore do not write anything after training succeeds, even if this method + # # is called. + # return + + # try: + # data_time = storage.history('data_time').avg(20) + # except KeyError: + # # they may not exist in the first few iterations (due to warmup) + # # or when SimpleTrainer is not used + # data_time = None + # try: + # iter_time = storage.history('time').global_avg() + # except KeyError: + # iter_time = None + # try: + # lr = '{:.5g}'.format(storage.history('lr').latest()) + # except KeyError: + # lr = 'N/A' + + # eta_string = self._get_eta(storage) + + + + # NOTE: max_mem is parsed by grep in "dev/parse_results.sh" + # self.logger.info('%s epoch: %d iter: %d %s %s%s lr: %s %s' % + # ( 'eta: %s ' % eta_string if eta_string else '', + # epoch, + # iteration, + # ' '.join([ + # '%s: %.6f' % (k, v.latest()) for k, v in storage.histories().items() + # if ('loss' in k) or ('Cider' in k) or ('RewardCriterion' in k) + # ]), + # 'time: %.4f ' % iter_time if iter_time is not None else '', + # 'data_time: %.4f ' % data_time if data_time is not None else '', + # lr, + # 'memory_usage_info' + # )) + + def close(self): + pass diff --git a/community/cv/yolo-world/yolow/model/__init__.py b/community/cv/yolo-world/yolow/model/__init__.py new file mode 100644 index 000000000..d4ffab2b1 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/__init__.py @@ -0,0 +1,13 @@ +""" +Filename: __init__.py +Author: hetao +""" +from .data_preprocessor import YOLOWDetDataPreprocessor +from .factory import (build_yolov8_backbone, build_yoloworld_data_preprocessor, + build_yoloworld_head, build_yoloworld_neck, + build_yoloworld_text, build_yoloworld_backbone, + build_yoloworld_detector) +from .yolo_world_backbone import MultiModalYOLOBackbone + + +__all__ = [k for k in globals().keys() if not k.startswith('_')] diff --git a/community/cv/yolo-world/yolow/model/clip_text.py b/community/cv/yolo-world/yolow/model/clip_text.py new file mode 100644 index 000000000..f7239c044 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/clip_text.py @@ -0,0 +1,90 @@ +""" +Filename: clip_text.py +author: hetao + +""" +import itertools +import warnings +from typing import List, Sequence + +from mindnlp.transformers import CLIPTextModelWithProjection as CLIPTP +from mindnlp.transformers import CLIPTextConfig +from mindnlp.transformers import AutoTokenizer + +from mindspore import Tensor +import mindspore.nn as nn + + +warnings.simplefilter(action='ignore', category=FutureWarning) + +__all__ = ('HuggingCLIPLanguageBackbone',) + + +class HuggingCLIPLanguageBackbone(nn.Cell): + """ + HuggingCLIPLanguageBackbone + """ + + def __init__(self, model_name: str, frozen_modules: Sequence[str] = (), dropout: float = 0.0) -> None: + super().__init__() + + self.frozen_modules = frozen_modules + + # mindspore tokenizer + self.ms_tokenizer = AutoTokenizer.from_pretrained( + model_name, cache_dir=".mindnlp", local_files_only=True) + # mindspore config + ms_configuration = CLIPTextConfig.from_pretrained(model_name, + attention_dropout=dropout, + cache_dir=".mindnlp", + local_files_only=True) + + # mindspore model + self.ms_model = CLIPTP.from_pretrained(model_name, config=ms_configuration, + cache_dir=".mindnlp", local_files_only=True) + + # self._freeze_modules() + + def construct(self, text: List[List[str]]) -> Tensor: + """ + construct + """ + num_per_batch = [len(t) for t in text] + assert max(num_per_batch) == min( + num_per_batch), ('number of sequences not equal in batch') + ms_text = list(itertools.chain(*text)) + + ms_text = self.ms_tokenizer( + text=ms_text, return_tensors='ms', padding=True) + + ms_txt_outputs = self.ms_model(**ms_text) + + # ms_txt_features 和 txt_features 对齐 + ms_txt_features = ms_txt_outputs.text_embeds + ms_txt_features = ms_txt_features / \ + ms_txt_features.norm(dim=-1, keepdim=True) + ms_txt_features = ms_txt_features.reshape( + -1, num_per_batch[0], ms_txt_features.shape[-1]) + return ms_txt_features + + def _freeze_modules(self): + """ + _freeze_modules + """ + if self.frozen_modules: + # not freeze + return + if self.frozen_modules[0] == "all": + self.model.eval() + for _, module in self.model.named_modules(): + module.eval() + for param in module.parameters(): + param.requires_grad = False + return + for name, module in self.model.named_modules(): + for frozen_name in self.frozen_modules: + if name.startswith(frozen_name): + module.eval() + for param in module.parameters(): + param.requires_grad = False + break diff --git a/community/cv/yolo-world/yolow/model/data_preprocessor.py b/community/cv/yolo-world/yolow/model/data_preprocessor.py new file mode 100644 index 000000000..e56b97404 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/data_preprocessor.py @@ -0,0 +1,129 @@ +""" +data_preprocessor.py +""" +import math +from numbers import Number +from typing import List, Optional, Sequence, Union +import numpy as np + +import mindspore.nn as nn +import mindspore as ms + +__all__ = ('YOLOWDetDataPreprocessor',) + + +class YOLOWDetDataPreprocessor(nn.Cell): + """Image pre-processor for detection tasks. + """ + + def __init__(self, + mean: Sequence[Number] = None, + std: Sequence[Number] = None, + pad_size_divisor: int = 1, + pad_value: Union[float, int] = 0, + bgr_to_rgb: bool = False, + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = True): + super().__init__() + self._non_blocking = non_blocking + + assert not (bgr_to_rgb and rgb_to_bgr), ('`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') + assert (mean is None) == (std is None), ('mean and std should be both None or tuple') + + if mean is not None: + assert len(mean) == 3 or len(mean) == 1, ('`mean` should have 1 or 3 values, to be compatible with ' + f'RGB or gray image, but got {len(mean)} values') + assert len(std) == 3 or len(std) == 1, ( # type: ignore + '`std` should have 1 or 3 values, to be compatible with RGB ' # type: ignore # noqa: E501 + f'or gray image, but got {len(std)} values') # type: ignore + self._enable_normalize = True + self.mean = ms.Parameter(ms.Tensor(mean).view(-1, 1, 1)) + self.std = ms.Parameter(ms.Tensor(std).view(-1, 1, 1)) + else: + self._enable_normalize = False + self._channel_conversion = rgb_to_bgr or bgr_to_rgb + self.pad_size_divisor = pad_size_divisor + self.pad_value = pad_value + + def _samplelist_boxtype2tensor(self, batch_data_samples) -> list: + """ + _samplelist_boxtype2tensor + """ + new_batch_data_samples = [] + for data_samples in batch_data_samples: + new_data_samples = {'img_metas': {}} + for k, v in data_samples['img_metas'].items(): + # ['texts', 'ori_shape', 'img_id', 'img_path', 'scale_factor', 'img_shape', 'pad_param'] + new_data_samples['img_metas'][k] = v + new_data_samples[k] = v # TODO removed, for debug + if 'gt_instances' in data_samples: + new_data_samples['gt_instances'] = data_samples['gt_instances'] + if 'pred_instances' in data_samples: + new_data_samples['pred_instances'] = data_samples['pred_instances'] + if 'ignored_instances' in data_samples: + new_data_samples['ignored_instances'] = data_samples['ignored_instances'] + new_batch_data_samples.append(new_data_samples) + return new_batch_data_samples + + def construct(self, data: dict, training: bool = False) -> dict: + """Perform normalization, padding and bgr2rgb conversion + """ + inputs, data_samples = data['inputs'], data['data_samples'] + + if not training: + if isinstance(inputs, list): + assert len(inputs) == 1, 'only support batch_size=1 for test' + inputs = ms.ops.stack(inputs) + # elif len(inputs.shape) == 3: + # inputs = ms.ops.unsqueeze(inputs, dim=0) + # import pdb;pdb.set_trace() + # data_samples = self._samplelist_boxtype2tensor(data_samples) + + + assert isinstance(inputs, ms.Tensor) + assert isinstance(data_samples, (dict, list)) + + # TODO: Supports multi-scale training + if self._channel_conversion and inputs.shape[1] == 3: + inputs = inputs[:, [2, 1, 0], ...] + if self._enable_normalize: + inputs = (inputs - self.mean) / self.std + + # not used here + h, w = inputs.shape[2:] + target_h = math.ceil(h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil(w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + inputs = ms.ops.pad(inputs, (0, pad_w, 0, pad_h), 'constant', self.pad_value) + + assert tuple(inputs[0].shape[-2:]) == tuple(inputs.shape[2:]) # debug + if not training: + # for idx, pad_shape in enumerate(batch_pad_shape): + # import pdb;pdb.set_trace() + # data_samples[idx]['img_metas']['batch_input_shape'] = tuple(inputs.shape[2:]) + # data_samples[idx]['img_metas']['pad_shape'] = pad_shape + return {'inputs': inputs, 'data_samples': data_samples} + return {'inputs': inputs,} + # img_metas = [{'batch_input_shape': inputs.shape[2:]}] * len(inputs) + # data_samples_output = { + # 'bboxes_labels': data_samples['bboxes_labels'], + # 'texts': data_samples['texts'], + # 'img_metas': img_metas + # } + # if 'masks' in data_samples: + # data_samples_output['masks'] = data_samples['masks'] + # if 'is_detection' in data_samples: + # data_samples_output['is_detection'] = data_samples['is_detection'] + # return {'inputs': inputs, 'data_samples': data_samples_output} + + def _get_pad_shape(self, _batch_inputs: ms.Tensor) -> List[tuple]: + """Get the pad_shape of each image based on data and + pad_size_divisor.""" + assert _batch_inputs.dim() == 4, ('The input of `DataPreprocessor` should be a NCHW tensor ' + 'or a list of tensor, but got a tensor with shape: ' + f'{_batch_inputs.shape}') + pad_h = int(np.ceil(_batch_inputs.shape[2] / self.pad_size_divisor)) * self.pad_size_divisor + pad_w = int(np.ceil(_batch_inputs.shape[3] / self.pad_size_divisor)) * self.pad_size_divisor + batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0] + return batch_pad_shape diff --git a/community/cv/yolo-world/yolow/model/factory.py b/community/cv/yolo-world/yolow/model/factory.py new file mode 100644 index 000000000..6d8dc0f5e --- /dev/null +++ b/community/cv/yolo-world/yolow/model/factory.py @@ -0,0 +1,337 @@ +""" +factory.py +""" +import json +import os.path as osp + +from dataclasses import dataclass, field, replace +from typing import List, Optional, Tuple +import mindspore.nn as nn +from .clip_text import HuggingCLIPLanguageBackbone + +from .data_preprocessor import YOLOWDetDataPreprocessor +from .misc import yolow_dict # simply replace dict['key'] with dict.key + + +from .yolo_base import YOLOv8CSPDarknet +from .yolo_world import YOLOWorldDetector +from .yolo_world_backbone import MultiModalYOLOBackbone +from .yolo_world_head import YOLOWorldHeadModule, YOLOWorldHead +from .yolo_world_pafpn import YOLOWorldPAFPN + + +# from .clip_t import CLIPTextModel +# from mindformers import CLIPConfig + +# from mindformers import CLIPModel + +__all__ = ( + 'build_yoloworld_data_preprocessor', + 'build_yolov8_backbone', + 'build_yoloworld_text', + 'build_yoloworld_backbone', + 'build_yoloworld_neck', + 'build_yoloworld_head', + 'build_yoloworld_detector', +) + +# default config files for model architectures +# generally we do not need to modify these template files +# you can manually add arguments via `args` when calling functions +CFG_FILES = { + 'n': osp.join(osp.dirname(__file__), 'model_cfgs', 'yoloworld_n.json'), + 's': osp.join(osp.dirname(__file__), 'model_cfgs', 'yoloworld_s.json'), + 'm': osp.join(osp.dirname(__file__), 'model_cfgs', 'yoloworld_m.json'), + 'l': osp.join(osp.dirname(__file__), 'model_cfgs', 'yoloworld_l.json'), + 'x': osp.join(osp.dirname(__file__), 'model_cfgs', 'yoloworld_x.json'), + 'xl': osp.join(osp.dirname(__file__), 'model_cfgs', 'yoloworld_xl.json'), +} + + +def load_config(size, CfgClass=None, key=None, add_args=None): + """ + load_config + """ + assert size in CFG_FILES.keys(), \ + "YOLO-World only supports the following sizes: [n|s|m|l|x|xl]." + # read json file into dict + with open(CFG_FILES[size], 'r') as jf: + cfg = json.load(jf) + assert ((key is None) or (key in cfg.keys())), (f'Unknown key: {key}') + + if CfgClass is not None: + # read from json WITH default settings + if key is None: + cfg = CfgClass() # by default + else: + cfg = CfgClass(**cfg[key]) + else: + # read from json WITHOUT default settings + cfg = yolow_dict(cfg) + if key is not None: + cfg = cfg.key + + # update with manually added config + if add_args is not None: + if key is not None: + # add_args = add_args[key] + add_args = add_args.get(key, dict()) + if isinstance(cfg, CfgClass): + all_keys = cfg.__dataclass_fields__.keys() + elif isinstance(cfg, dict): + all_keys = cfg.keys() + else: + raise ValueError(f'Unknown type: {cfg.__class__}') + # TODO add warning for mismatched kwargs + filtered_kwargs = {k: add_args[k] + for k in all_keys if k in add_args.keys()} + # print (filtered_kwargs) + cfg = replace(cfg, **filtered_kwargs) + return cfg + + +@dataclass +class YOLOWorldDataPreCfg: + mean: List = field(default_factory=lambda: [0., 0., 0.]) + std: List = field(default_factory=lambda: [255., 255., 255.]) + pad_size_divisor: int = 1 + pad_value: float = 0 + bgr_to_rgb: bool = True + rgb_to_bgr: bool = False + non_blocking: bool = True + + +def build_yoloworld_data_preprocessor(size: str, args: Optional[dict] = None) -> nn.Cell: + cfg = load_config(size, YOLOWorldDataPreCfg, + 'yoloworld_data_preprocessor', args) + + return YOLOWDetDataPreprocessor( + mean=cfg.mean, + std=cfg.std, + pad_size_divisor=cfg.pad_size_divisor, + pad_value=cfg.pad_value, + bgr_to_rgb=cfg.bgr_to_rgb, + rgb_to_bgr=cfg.rgb_to_bgr, + non_blocking=cfg.non_blocking, + ) + + +@dataclass +class YOLOv8BackboneCfg: + """ + YOLOv8BackboneCfg + """ + arch: str = 'P5' + last_stage_out_channels: int = 1024 # vary among sizes + deepen_factor: float = 0.33 # vary among sizes + widen_factor: float = 0.5 # vary among sizes + input_channels: int = 3 + out_indices: Tuple[int] = (2, 3, 4) + frozen_stages: int = -1 + with_norm: bool = True + with_activation: bool = True + norm_eval: bool = False + + +def build_yolov8_backbone(size: str, args: Optional[dict] = None) -> nn.Cell: + """Exp. + >>> model = build_yolov8_backbone('s') + """ + cfg = load_config(size, YOLOv8BackboneCfg, 'yolov8_backbone', args) + ms_yolov8_backbone = YOLOv8CSPDarknet( + arch=cfg.arch, + last_stage_out_channels=cfg.last_stage_out_channels, + deepen_factor=cfg.deepen_factor, + widen_factor=cfg.widen_factor, + input_channels=cfg.input_channels, + out_indices=cfg.out_indices, + frozen_stages=cfg.frozen_stages, + with_norm=cfg.with_norm, # BN + with_activation=cfg.with_activation, # SiLU + norm_eval=cfg.norm_eval, + ) + + return ms_yolov8_backbone + + +@dataclass +class YOLOWorldTextCfg: + model_name: str = 'openai/clip-vit-base-patch32' + channels: int = 512 # for `YOLOWorldPAFPN.guide_channels` + frozen_modules: List = field(default_factory=lambda: ['all']) + dropout: float = 0.0 + + +def build_yoloworld_text(size: str, args: Optional[dict] = None) -> nn.Cell: + """ + build_yoloworld_text + """ + cfg = load_config(size, YOLOWorldTextCfg, 'yoloworld_text', args) + + ms_text = HuggingCLIPLanguageBackbone( + model_name=cfg.model_name, + frozen_modules=cfg.frozen_modules, + dropout=cfg.dropout, + ) + # config = CLIPConfig.from_pretrained("clip_vit_b_32") + # ms_text = CLIPTextModel(config) + + return ms_text + + +@dataclass +class YOLOWorldBackboneCfg: + frozen_stages: int = -1 + with_text_model: bool = True + + +def build_yoloworld_backbone( + size: str, + image_model: nn.Cell, # from `build_yolov8_backbone` + text_model: nn.Cell, # from `build_yoloworld_text` + args: Optional[dict] = None) -> nn.Cell: + """ + build_yoloworld_backbone + """ + cfg = load_config(size, YOLOWorldBackboneCfg, 'yoloworld_backbone', args) + ms_yoloworld_backbone = MultiModalYOLOBackbone( + image_model=image_model, + text_model=text_model, + frozen_stages=cfg.frozen_stages, + with_text_model=cfg.with_text_model, + ) + return ms_yoloworld_backbone + + +@dataclass +class YOLOWorldNeckCfg: + in_channels: List = field(default_factory=lambda: [256, 512, 1024]) + out_channels: List = field(default_factory=lambda: [256, 512, 1024]) + embed_channels: List = field(default_factory=lambda: [128, 256, 512]) + num_heads: List = field(default_factory=lambda: [4, 8, 16]) + num_csp_blocks: int = 3 + freeze_all: bool = False + with_norm: bool = True + with_activation: bool = True + + +def build_yoloworld_neck(size: str, args: Optional[dict] = None) -> nn.Cell: + """ + build_yoloworld_neck + """ + cfg = load_config(size, YOLOWorldNeckCfg, 'yoloworld_neck', args) + img_cfg = load_config(size, YOLOv8BackboneCfg, 'yolov8_backbone', args) + text_cfg = load_config(size, YOLOWorldTextCfg, 'yoloworld_text', args) + + ms_neck = YOLOWorldPAFPN( + in_channels=cfg.in_channels, + out_channels=cfg.out_channels, + guide_channels=text_cfg.channels, # determined by text encoder + embed_channels=cfg.embed_channels, + num_heads=cfg.num_heads, + deepen_factor=img_cfg.deepen_factor, + widen_factor=img_cfg.widen_factor, + num_csp_blocks=cfg.num_csp_blocks, + freeze_all=cfg.freeze_all, + with_norm=cfg.with_norm, + with_activation=cfg.with_activation, + ) + + return ms_neck + + +@dataclass +class YOLOWorldHeadModuleCfg: + use_bn_head: bool = True + use_einsum: bool = True + freeze_all: bool = False + num_base_priors: int = 1 + featmap_strides: List = field(default_factory=lambda: [8, 16, 32]) + reg_max: int = 16 + with_norm: bool = True + with_activation: bool = True + + +def build_yoloworld_head( + size: str, + multi_label: bool = True, # test_cfg + nms_pre: int = 30000, # test_cfg + score_thr: float = 0.001, # test_cfg + nms_iou_threshold: float = 0.7, # test_cfg + max_per_img: int = 300, # test_cfg + args: Optional[dict] = None) -> nn.Cell: + """ + build_yoloworld_head + """ + cfg = load_config(size, YOLOWorldHeadModuleCfg, + 'yoloworld_head_module', args) + text_cfg = load_config(size, YOLOWorldTextCfg, 'yoloworld_text', args) + yolo_cfg = load_config(size, YOLOv8BackboneCfg, 'yolov8_backbone', args) + neck_cfg = load_config(size, YOLOWorldNeckCfg, 'yoloworld_neck', args) + det_cfg = load_config(size, YOLOWorldDetectorCfg, + 'yoloworld_detector', args) + + # `test_cfg` is not kind of model arch + # so we did not define it in the json file + # determine the arguments when calling `build_yoloworld_head()` + test_cfg = yolow_dict( + # The config of multi-label for multi-class prediction. + multi_label=multi_label, + # The number of boxes before NMS + nms_pre=nms_pre, + score_thr=score_thr, # Threshold to filter out boxes. + # NMS type and threshold + nms=yolow_dict(type='nms', iou_threshold=nms_iou_threshold), + max_per_img=max_per_img) # Max number of detections of each image + + ms_head_module = YOLOWorldHeadModule( + num_classes=det_cfg.num_train_classes, + in_channels=neck_cfg.in_channels, # determined by neck + embed_dims=text_cfg.channels, # determined by text encoder + use_bn_head=cfg.use_bn_head, + use_einsum=cfg.use_einsum, + freeze_all=cfg.freeze_all, + widen_factor=yolo_cfg.widen_factor, # determined by yolov8 + num_base_priors=cfg.num_base_priors, + featmap_strides=cfg.featmap_strides, + reg_max=cfg.reg_max, + with_norm=cfg.with_norm, + with_activation=cfg.with_activation, + ) + + ms_head = YOLOWorldHead( + ms_head_module, + test_cfg=test_cfg, + ) + return ms_head + + +@dataclass +class YOLOWorldDetectorCfg: + mm_neck: bool = True + use_syncbn: bool = True + num_train_classes: int = 80 + num_test_classes: int = 80 + + +def build_yoloworld_detector(size: str, + backbone: nn.Cell, + neck: nn.Cell, + bbox_head: nn.Cell, + data_preprocessor: Optional[nn.Cell] = None, + args: Optional[dict] = None) -> nn.Cell: + """ + build_yoloworld_detector + """ + cfg = load_config(size, YOLOWorldDetectorCfg, 'yoloworld_detector', args) + ms_detector = YOLOWorldDetector( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mm_neck=cfg.mm_neck, + num_train_classes=cfg.num_train_classes, + num_test_classes=cfg.num_test_classes, + data_preprocessor=data_preprocessor, + ) + + return ms_detector diff --git a/community/cv/yolo-world/yolow/model/layers/__init__.py b/community/cv/yolo-world/yolow/model/layers/__init__.py new file mode 100644 index 000000000..b188c1657 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/layers/__init__.py @@ -0,0 +1,16 @@ +""" +__init__.py +""" +from .attn import MaxSigmoidAttnBlock +from .bottleneck import Bottleneck, SPPFBottleneck +from .conv import Conv +from .csp_layer import CSPLayer, MaxSigmoidCSPLayer + +__all__ = ( + 'MaxSigmoidAttnBlock', + 'Conv', + 'CSPLayer', + 'Bottleneck', + 'SPPFBottleneck', + 'MaxSigmoidCSPLayer', +) diff --git a/community/cv/yolo-world/yolow/model/layers/attn.py b/community/cv/yolo-world/yolow/model/layers/attn.py new file mode 100644 index 000000000..b22ae82f2 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/layers/attn.py @@ -0,0 +1,94 @@ +""" +attn.py +""" +import mindspore.nn as nn +from mindspore import Tensor +import mindspore as ms +from .conv import Conv +__all__ = ('MaxSigmoidAttnBlock',) + + +class MaxSigmoidAttnBlock(nn.Cell): + """Max Sigmoid attention block.""" + + def __init__(self, + in_channels: int, + out_channels: int, + guide_channels: int, + embed_channels: int, + kernel_size: int = 3, + padding: int = 1, + num_heads: int = 1, + with_scale: bool = False, + with_norm: bool = True, + use_einsum: bool = True) -> None: + super().__init__() + + assert (out_channels % num_heads == 0 and + embed_channels % num_heads == 0), \ + 'out_channels and embed_channels should be divisible by num_heads.' + self.num_heads = num_heads + self.head_channels = out_channels // num_heads + self.use_einsum = use_einsum + + self.embed_conv = Conv( + in_channels, embed_channels, 1, with_norm=with_norm, + with_activation=False) if embed_channels != in_channels else None + + self.guide_fc = nn.Dense(guide_channels, embed_channels) + self.bias = ms.Parameter(ms.ops.zeros(num_heads)) + if with_scale: + self.scale = ms.Parameter(ms.ops.ones(1, num_heads, 1, 1)) + else: + self.scale = 1.0 + + self.project_conv = Conv( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=padding, + with_norm=with_norm, + with_activation=False) + + def construct(self, x: Tensor, guide: Tensor) -> Tensor: + """ + construct + """ + B, _, H, W = x.shape + + guide = self.guide_fc(guide) + # guide = guide.reshape(B, -1, self.num_heads, self.head_channels) + guide = ms.ops.reshape( + guide, (B, -1, self.num_heads, self.head_channels)) + embed = self.embed_conv(x) if self.embed_conv is not None else x + # embed = embed.reshape(B, self.num_heads, self.head_channels, H, W) + embed = ms.ops.reshape( + embed, (B, self.num_heads, self.head_channels, H, W)) + + if self.use_einsum: + # attn_weight = torch.einsum('bmchw,bnmc->bmhwn', embed, guide) + attn_weight = ms.ops.einsum('bmchw,bnmc->bmhwn', embed, guide) + else: + batch, m, channel, height, width = embed.shape + _, n, _, _ = guide.shape + embed = embed.permute(0, 1, 3, 4, 2) + embed = embed.reshape(batch, m, -1, channel) + guide = guide.permute(0, 2, 3, 1) + # attn_weight = torch.matmul(embed, guide) + attn_weight = ms.ops.matmul(embed, guide) + attn_weight = attn_weight.reshape(batch, m, height, width, n) + + attn_weight = attn_weight.max(axis=-1)[0] + attn_weight = attn_weight / (self.head_channels**0.5) + attn_weight = attn_weight + self.bias[None, :, None, None] + attn_weight = ms.ops.sigmoid(attn_weight) * self.scale + + x = self.project_conv(x) + # x = x.reshape(B, self.num_heads, -1, H, W) + x = ms.ops.reshape(x, (B, self.num_heads, -1, H, W)) + # x = x * attn_weight.unsqueeze(2) + x = x * ms.ops.unsqueeze(attn_weight, 2) + # x = x.reshape(B, -1, H, W) + x = ms.ops.reshape(x, (B, -1, H, W)) + return x diff --git a/community/cv/yolo-world/yolow/model/layers/bottleneck.py b/community/cv/yolo-world/yolow/model/layers/bottleneck.py new file mode 100644 index 000000000..5d16ff0d9 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/layers/bottleneck.py @@ -0,0 +1,109 @@ +""" +bottleneck.py +""" +from typing import Sequence, Union + +import mindspore.nn as nn +import mindspore as ms +from mindspore import Tensor +from .conv import Conv +__all__ = ( + 'Bottleneck', + 'SPPFBottleneck', +) + + +class Bottleneck(nn.Cell): + """The basic bottleneck block used in Darknet. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + expansion: float = 0.5, + kernel_size: Sequence[int] = (1, 3), + padding: Sequence[int] = (0, 1), + add_identity: bool = True, + with_norm: bool = False, + with_activation: bool = True) -> None: + super().__init__() + + hidden_channels = int(out_channels * expansion) + assert isinstance(kernel_size, Sequence) and len(kernel_size) == 2 + + self.conv1 = Conv( + in_channels, + hidden_channels, + kernel_size[0], + padding=padding[0], + with_norm=with_norm, + with_activation=with_activation) + self.conv2 = Conv( + hidden_channels, + out_channels, + kernel_size[1], + stride=1, + padding=padding[1], + with_norm=with_norm, + with_activation=with_activation) + self.add_identity = \ + add_identity and in_channels == out_channels + + def construct(self, x: Tensor) -> Tensor: + identity = x + out = self.conv1(x) + out = self.conv2(out) + + if self.add_identity: + return out + identity + return out + + +class SPPFBottleneck(nn.Cell): + """Spatial pyramid pooling + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_sizes: Union[int, Sequence[int]] = 5, + use_conv_first: bool = True, + mid_channels_scale: float = 0.5, + with_norm: bool = True, + with_activation: bool = True): + super().__init__() + + if use_conv_first: + mid_channels = int(in_channels * mid_channels_scale) + self.conv1 = Conv( + in_channels, mid_channels, 1, stride=1, with_norm=with_norm, with_activation=with_activation) + else: + mid_channels = in_channels + self.conv1 = None + self.kernel_sizes = kernel_sizes + if isinstance(kernel_sizes, int): + self.poolings = nn.MaxPool2d( + kernel_size=kernel_sizes, stride=1, pad_mode="pad", padding=kernel_sizes // 2) + conv2_in_channels = mid_channels * 4 + else: + self.poolings = nn.CellList( + [nn.MaxPool2d(kernel_size=ks, stride=1, pad_mode="pad", padding=ks // 2) for ks in kernel_sizes]) + conv2_in_channels = mid_channels * (len(kernel_sizes) + 1) + + self.conv2 = Conv(conv2_in_channels, out_channels, 1, + with_norm=with_norm, with_activation=with_activation) + + def construct(self, x: Tensor) -> Tensor: + """ + construct + """ + if self.conv1: + x = self.conv1(x) + if isinstance(self.kernel_sizes, int): + y1 = self.poolings(x) + y2 = self.poolings(y1) + x = ms.ops.cat([x, y1, y2, self.poolings(y2)], axis=1) + else: + x = ms.ops.cat([x] + [pooling(x) for pooling in self.poolings], axis=1) + x = self.conv2(x) + return x diff --git a/community/cv/yolo-world/yolow/model/layers/conv.py b/community/cv/yolo-world/yolow/model/layers/conv.py new file mode 100644 index 000000000..c6d0fd4db --- /dev/null +++ b/community/cv/yolo-world/yolow/model/layers/conv.py @@ -0,0 +1,94 @@ +""" +conv.py +""" +from typing import Tuple, Union +import mindspore.nn as nn +from mindspore import Tensor + +__all__ = ('Conv',) + + +class Conv(nn.Cell): + """A convolution block + composed of conv/norm/activation layers. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + with_norm: bool = False, + with_activation: bool = True, + bias: Union[bool, str] = 'auto'): + super().__init__() + self.with_norm = with_norm + self.with_activation = with_activation + # if the conv layer is before a norm layer, bias is unnecessary. + if bias == 'auto': + bias = not self.with_norm + self.with_bias = bias + + # pytorch与mindspore 的 nn.Conv2d 区别 + # build convolution layer + if padding > 0: + pad_mode = 'pad' + else: + pad_mode = 'valid' + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + pad_mode=pad_mode, + padding=padding, + dilation=dilation, + group=groups, + has_bias=self.with_bias, + ) + + # build normalization layers + if self.with_norm: + # 官方说 pytorch这里的BN中momentum_torch = 1 - momentum_mindspore + self.bn = nn.BatchNorm2d(out_channels, momentum=1-0.03, eps=0.001) + # build activation layer + if self.with_activation: + self.activate = nn.SiLU() # mindspore has no inplace + + # self.init_weights() + + def construct(self, x: Tensor) -> Tensor: + # fixed order: ('conv', 'norm', 'act') + x = self.conv(x) + if self.with_norm: + x = self.bn(x) + if self.with_activation: + x = self.activate(x) + return x + + # 推理按理说应该不看重init + def init_weights(self): + """ + init_weights + """ + from mindspore.common.initializer import initializer + # nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') + # self.conv.weight.set_data(initializer(HeNormal(negative_slope=0, mode='fan_out', nonlinearity='relu'), + # self.conv.weight.shape, self.conv.weight.dtype)) + if hasattr(self.conv, 'bias') and self.conv.bias is not None: + # nn.init.constant_(self.conv.bias, 0) + self.conv.bias.set_data(initializer( + "zeros", self.conv.bias.shape, self.conv.bias.dtype)) + + if self.with_norm: + # nn.init.constant_(self.bn.weight, 1) + self.bn.weight.set_data(initializer( + "ones", self.bn.weight.shape, self.bn.weight.dtype)) + + if hasattr(self.conv, 'bias') and self.conv.bias is not None: + # nn.init.constant_(self.bn.bias, 0) + self.bn.bias.set_data(initializer( + "zeros", self.bn.bias.shape, self.bn.bias.dtype)) diff --git a/community/cv/yolo-world/yolow/model/layers/csp_layer.py b/community/cv/yolo-world/yolow/model/layers/csp_layer.py new file mode 100644 index 000000000..121acf95e --- /dev/null +++ b/community/cv/yolo-world/yolow/model/layers/csp_layer.py @@ -0,0 +1,104 @@ +""" +csp_layer.py +""" +import mindspore.nn as nn +import mindspore as ms +from mindspore import Tensor + +from .attn import MaxSigmoidAttnBlock +from .bottleneck import Bottleneck +from .conv import Conv + +__all__ = ( + 'CSPLayer', + 'MaxSigmoidCSPLayer', +) + + +class CSPLayer(nn.Cell): + """Cross Stage Partial Layer. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + expand_ratio: float = 0.5, + num_blocks: int = 1, + add_identity: bool = True, + with_norm: bool = True, + with_activation: bool = True) -> None: + super().__init__() + + self.mid_channels = int(out_channels * expand_ratio) + self.main_conv = Conv( + in_channels, 2 * self.mid_channels, 1, with_norm=with_norm, with_activation=with_activation) + self.final_conv = Conv( + (2 + num_blocks) * self.mid_channels, out_channels, 1, with_norm=with_norm, with_activation=with_activation) + + self.blocks = nn.CellList( + [Bottleneck( + self.mid_channels, + self.mid_channels, + expansion=1, + kernel_size=(3, 3), + padding=(1, 1), + add_identity=add_identity, + with_norm=with_norm, + with_activation=with_activation) for _ in range(num_blocks)]) + + def construct(self, x: Tensor) -> Tensor: + x_main = self.main_conv(x) + x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1)) + x_main.extend(blocks(x_main[-1]) for blocks in self.blocks) + + out = self.final_conv(ms.ops.cat(x_main, 1)) + + return out + + +class MaxSigmoidCSPLayer(CSPLayer): + """Sigmoid-attention based CSP layer. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + guide_channels: int, + embed_channels: int, + num_heads: int = 1, + expand_ratio: float = 0.5, + num_blocks: int = 1, + with_scale: bool = False, + add_identity: bool = True, # shortcut + with_norm: bool = True, + with_activation: bool = True, + use_einsum: bool = True) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + expand_ratio=expand_ratio, + num_blocks=num_blocks, + add_identity=add_identity, + with_norm=with_norm, + with_activation=with_activation) + + self.final_conv = Conv( + (3 + num_blocks) * self.mid_channels, out_channels, 1, with_norm=with_norm, with_activation=with_activation) + + self.attn_block = MaxSigmoidAttnBlock( + self.mid_channels, + self.mid_channels, + guide_channels=guide_channels, + embed_channels=embed_channels, + num_heads=num_heads, + with_scale=with_scale, + with_norm=with_norm, + use_einsum=use_einsum) + + def construct(self, x: Tensor, guide: Tensor) -> Tensor: + x_main = self.main_conv(x) + x_main = list(x_main.split((self.mid_channels, self.mid_channels), 1)) + x_main.extend(blocks(x_main[-1]) for blocks in self.blocks) + x_main.append(self.attn_block(x_main[-1], guide)) + return self.final_conv(ms.ops.cat(x_main, 1)) diff --git a/community/cv/yolo-world/yolow/model/misc.py b/community/cv/yolo-world/yolow/model/misc.py new file mode 100644 index 000000000..c2f73160f --- /dev/null +++ b/community/cv/yolo-world/yolow/model/misc.py @@ -0,0 +1,332 @@ +""" +misc.py +""" +import math +from functools import partial +from typing import Any, Dict, Optional, Tuple, Type, Union +from collections import abc +import numpy as np + +import mindspore.nn as nn +import mindspore as ms +from mindspore import Tensor + +__all__ = ( + 'yolow_dict', + 'is_seq_of', + 'is_list_of', + 'make_divisible', + 'make_round', + 'multi_apply', + 'unpack_gt_instances', + 'get_prior_xy_info', + 'get_box_wh', + 'nms', +) + + +class yolow_dict(dict): + + def __getattr__(self, item): + try: + return self[item] + except KeyError as e: + raise AttributeError from e + + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +def is_seq_of(seq: Any, expected_type: Union[Type, tuple], seq_type: Type = None) -> bool: + """Check whether it is a sequence of some type. + """ + if seq_type is None: + exp_seq_type = abc.Sequence + else: + assert isinstance(seq_type, type) + exp_seq_type = seq_type + if not isinstance(seq, exp_seq_type): + return False + for item in seq: + if not isinstance(item, expected_type): + return False + return True + + +def is_list_of(seq, expected_type): + """Check whether it is a list of some type. + """ + return is_seq_of(seq, expected_type, seq_type=list) + + +def make_divisible(x: float, widen_factor: float = 1.0, divisor: int = 8) -> int: + """Make sure that x*widen_factor is divisible by divisor.""" + return math.ceil(x * widen_factor / divisor) * divisor + + +def make_round(x: float, deepen_factor: float = 1.0) -> int: + """Make sure that x*deepen_factor becomes an integer not less than 1.""" + return max(round(x * deepen_factor), 1) if x > 1 else x + + +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def unpack_gt_instances(batch_data_samples: list) -> tuple: + """Unpack ``gt_instances``, ``gt_instances_ignore`` and ``img_metas`` based + on ``batch_data_samples`` + """ + batch_gt_instances = [] + batch_gt_instances_ignore = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + if 'ignored_instances' in data_sample: + batch_gt_instances_ignore.append(data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + return batch_gt_instances, batch_gt_instances_ignore, batch_img_metas + + +def ms_filter_scores_and_topk(scores, score_thr, topk, results=None): + """Filter results using score threshold and topk candidates. + """ + valid_mask = scores > score_thr + scores = scores[valid_mask] + valid_idxs = ms.ops.nonzero(valid_mask) + + num_topk = min(topk, valid_idxs.shape[0]) + # torch.sort is actually faster than .topk (at least on GPUs) + scores, idxs = scores.sort(descending=True) + scores = scores[:num_topk] + topk_idxs = valid_idxs[idxs[:num_topk]] + + keep_idxs, labels = topk_idxs.unbind(dim=1) + + filtered_results = None + if results is not None: + if isinstance(results, dict): + filtered_results = {k: v[keep_idxs] for k, v in results.items()} + elif isinstance(results, list): + filtered_results = [result[keep_idxs] for result in results] + elif isinstance(results, ms.Tensor): + filtered_results = results[keep_idxs] + else: + raise NotImplementedError(f'Only supports dict or list or Tensor, ' + f'but get {type(results)}.') + return scores, labels, keep_idxs, filtered_results + + +def get_prior_xy_info(index: int, num_base_priors: int, featmap_sizes: int) -> Tuple[int, int, int]: + """Get prior index and xy index in feature map by flatten index.""" + _, featmap_w = featmap_sizes + priors = index % num_base_priors + xy_index = index // num_base_priors + grid_y = xy_index // featmap_w + grid_x = xy_index % featmap_w + return priors, grid_x, grid_y + + +def ms_scale_boxes(boxes: Union[Tensor, dict], scale_factor: Tuple[float, float]) -> Union[Tensor, dict]: + """Scale boxes with type of tensor or box type. + """ + if isinstance(boxes, dict): + boxes.rescale_(scale_factor) + return boxes + + # Tensor boxes will be treated as horizontal boxes + repeat_num = int(boxes.size(-1) / 2) + scale_factor = ms.Tensor(scale_factor).repeat((1, repeat_num)) + return boxes * scale_factor + + +def get_box_wh(boxes: Union[Tensor, dict]) -> Tuple[Tensor, Tensor]: + """Get the width and height of boxes with type of tensor or box type. + """ + if isinstance(boxes, dict): + w = boxes.widths + h = boxes.heights + else: + # Tensor boxes will be treated as horizontal boxes by defaults + w = boxes[:, 2] - boxes[:, 0] + h = boxes[:, 3] - boxes[:, 1] + return w, h + + +class NMSop(nn.Cell): + """ + NMSop + """ + + def __init__(self, iou_threshold): + super().__init__(iou_threshold) + self.nms_func = ms.ops.NMSWithMask(iou_threshold) + + # @staticmethod + + def construct(self, bboxes: Tensor, scores: Tensor, score_threshold: float, max_num: int) -> Tensor: + """ + construct + """ + is_filtering_by_score = score_threshold > 0 + if is_filtering_by_score: + valid_mask = scores > score_threshold + bboxes, scores = bboxes[valid_mask], scores[valid_mask] + valid_inds = ms.ops.nonzero(valid_mask).squeeze(axis=1) + + # inds = ext_module.nms( + # bboxes, scores, iou_threshold=float(iou_threshold), offset=offset) + # inds = box_ops.batched_nms(bboxes.float(), scores, torch.ones(bboxes.size(0)), iou_threshold) + + box_with_score = ms.ops.concat((bboxes, scores.unsqueeze(-1)), axis=1) + + _, indices, selected_mask = self.nms_func(box_with_score) + + inds = indices[selected_mask] + + if max_num > 0: + inds = inds[:max_num] + if is_filtering_by_score: + inds = valid_inds[inds] + return inds + + +def nms(nms_op, + cat_op, + boxes: Union[Tensor, np.ndarray], + scores: Union[Tensor, np.ndarray], + iou_threshold: float, + offset: int = 0, + score_threshold: float = 0, + max_num: int = -1, + ) -> Tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]: + """ + nms + """ + assert isinstance(boxes, (np.ndarray, Tensor)) + assert isinstance(scores, (np.ndarray, Tensor)) + is_numpy = False + if isinstance(boxes, np.ndarray): + is_numpy = True + boxes = ms.Tensor(boxes) + if isinstance(scores, np.ndarray): + scores = ms.Tensor(scores) + assert boxes.shape[1] == 4 + assert boxes.shape[0] == scores.shape[0] + assert offset in (0, 1) + + # if isinstance(boxes, Tensor): + + # start = time.perf_counter() + inds = nms_op(boxes, scores, iou_threshold, + offset, score_threshold, max_num) + # topk_time = time.perf_counter() + # print(f"nms op time: {topk_time - start}") + # start = time.perf_counter() + dets = cat_op((boxes[inds], scores[inds].unsqueeze(-1))) + # topk_time = time.perf_counter() + # print(f"cat time: {topk_time - start}") + if is_numpy: + dets = dets.asnumpy() + inds = inds.asnumpy() + return dets, inds + + +def ms_batched_nms(boxes: Tensor, + scores: Tensor, + idxs: Tensor, + nms_cfg: Optional[Dict], + class_agnostic: bool = False) -> Tuple[Tensor, Tensor]: + r"""Performs non-maximum suppression in a batched fashion. + """ + # skip nms when nms_cfg is None + if nms_cfg is None: + scores, inds = scores.sort(descending=True) + boxes = boxes[inds] + return ms.ops.cat([boxes, scores[:, None]], -1), inds + + nms_cfg_ = nms_cfg.copy() + class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic) + + if class_agnostic: + boxes_for_nms = boxes + else: + # When using rotated boxes, only apply offsets on center. + if boxes.shape[-1] == 5: + # Strictly, the maximum coordinates of the rotating box + # (x,y,w,h,a) should be calculated by polygon coordinates. + # But the conversion from rotated box to polygon will + # slow down the speed. + # So we use max(x,y) + max(w,h) as max coordinate + # which is larger than polygon max coordinate + # max(x1, y1, x2, y2,x3, y3, x4, y4) + max_coordinate = boxes[..., :2].max() + boxes[..., 2:4].max() + offsets = idxs.to(boxes) * (max_coordinate + + ms.Tensor(1).to(boxes)) + boxes_ctr_for_nms = boxes[..., :2] + offsets[:, None] + boxes_for_nms = ms.ops.cat( + [boxes_ctr_for_nms, boxes[..., 2:5]], axis=-1) + else: + max_coordinate = boxes.max() + offsets = idxs.to(ms.float32) * (max_coordinate + + ms.Tensor(1.).to(ms.float32)) + + boxes_for_nms = boxes + offsets[:, None] + + _ = nms_cfg_.pop('type', 'nms') + + nms_op = NMSop(nms_cfg_['iou_threshold']) + cat_op = ms.ops.Concat(1) + # if isinstance(nms_op, str): + # nms_op = eval(nms_op) + + split_thr = nms_cfg_.pop('split_thr', 100000) + # Won't split to multiple nms nodes when exporting to onnx + if boxes_for_nms.shape[0] < split_thr: + dets, keep = nms(nms_op, cat_op, boxes_for_nms, scores, **nms_cfg_) + boxes = boxes[keep] + + # This assumes `dets` has arbitrary dimensions where + # the last dimension is score. + # Currently it supports bounding boxes [x1, y1, x2, y2, score] or + # rotated boxes [cx, cy, w, h, angle_radian, score]. + + scores = dets[:, -1] + else: + max_num = nms_cfg_.pop('max_num', -1) + total_mask = scores.new_zeros(scores.shape, dtype=ms.bool_) + # Some type of nms would reweight the score, such as SoftNMS + scores_after_nms = scores.new_zeros(scores.shape) + for idx in ms.ops.unique(idxs)[0]: + mask = (idxs == idx).nonzero().view(-1) + # start = time.perf_counter() + dets, keep = nms( + nms_op, cat_op, boxes_for_nms[mask], scores[mask], **nms_cfg_) + # nms_time = time.perf_counter() + # time_nms += nms_time - start + total_mask = ms.ops.tensor_scatter_elements(total_mask, mask[keep], + ms.ops.full_like(mask[keep], True, dtype=ms.bool_)) + scores_after_nms = ms.ops.tensor_scatter_elements( + scores_after_nms, mask[keep], dets[:, -1]) + + # print(f"nms time: {time_nms}") + + keep = total_mask.nonzero().view(-1) + + scores, inds = scores_after_nms[keep].sort(descending=True) + keep = keep[inds] + boxes = boxes[keep] + + if max_num > 0: + keep = keep[:max_num] + boxes = boxes[:max_num] + scores = scores[:max_num] + + boxes = ms.ops.cat([boxes, scores[:, None]], -1) + return boxes, keep diff --git a/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_l.json b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_l.json new file mode 100644 index 000000000..6e4cf384a --- /dev/null +++ b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_l.json @@ -0,0 +1,37 @@ +{ + "yoloworld_data_preprocessor": { + "mean": [0, 0, 0], + "std": [255, 255, 255], + "bgr_to_rgb": true, + "non_blocking": true + }, + "yolov8_backbone": { + "last_stage_out_channels": 512, + "deepen_factor": 1.00, + "widen_factor": 1.00 + }, + "yoloworld_text": { + "model_name": "openai/clip-vit-base-patch32", + "frozen_modules": ["all"], + "channels": 512 + }, + "yoloworld_neck": { + "in_channels": [256, 512, 512], + "out_channels": [256, 512, 512], + "embed_channels": [128, 256, 256], + "num_heads": [4, 8, 8] + }, + "yoloworld_backbone": { + "with_text_model": true + }, + "yoloworld_head_module": { + "use_bn_head": true, + "use_einsum": true + }, + "yoloworld_detector": { + "num_train_classes": 80, + "num_test_classes": 1203, + "mm_neck": true, + "use_syncbn": true + } +} diff --git a/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_m.json b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_m.json new file mode 100644 index 000000000..525ee8f3d --- /dev/null +++ b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_m.json @@ -0,0 +1,37 @@ +{ + "yoloworld_data_preprocessor": { + "mean": [0, 0, 0], + "std": [255, 255, 255], + "bgr_to_rgb": true, + "non_blocking": true + }, + "yolov8_backbone": { + "last_stage_out_channels": 768, + "deepen_factor": 0.67, + "widen_factor": 0.75 + }, + "yoloworld_text": { + "model_name": "openai/clip-vit-base-patch32", + "frozen_modules": ["all"], + "channels": 512 + }, + "yoloworld_neck": { + "in_channels": [256, 512, 768], + "out_channels": [256, 512, 768], + "embed_channels": [128, 256, 384], + "num_heads": [4, 8, 12] + }, + "yoloworld_backbone": { + "with_text_model": true + }, + "yoloworld_head_module": { + "use_bn_head": true, + "use_einsum": true + }, + "yoloworld_detector": { + "num_train_classes": 80, + "num_test_classes": 1203, + "mm_neck": true, + "use_syncbn": true + } +} diff --git a/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_n.json b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_n.json new file mode 100644 index 000000000..de8c865e7 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_n.json @@ -0,0 +1,37 @@ +{ + "yoloworld_data_preprocessor": { + "mean": [0, 0, 0], + "std": [255, 255, 255], + "bgr_to_rgb": true, + "non_blocking": true + }, + "yolov8_backbone": { + "last_stage_out_channels": 1024, + "deepen_factor": 0.33, + "widen_factor": 0.25 + }, + "yoloworld_text": { + "model_name": "openai/clip-vit-base-patch32", + "frozen_modules": ["all"], + "channels": 512 + }, + "yoloworld_neck": { + "in_channels": [256, 512, 1024], + "out_channels": [256, 512, 1024], + "embed_channels": [128, 256, 512], + "num_heads": [4, 8, 16] + }, + "yoloworld_backbone": { + "with_text_model": true + }, + "yoloworld_head_module": { + "use_bn_head": true, + "use_einsum": true + }, + "yoloworld_detector": { + "num_train_classes": 80, + "num_test_classes": 1203, + "mm_neck": true, + "use_syncbn": true + } +} diff --git a/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_s.json b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_s.json new file mode 100644 index 000000000..77dab9d1c --- /dev/null +++ b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_s.json @@ -0,0 +1,37 @@ +{ + "yoloworld_data_preprocessor": { + "mean": [0, 0, 0], + "std": [255, 255, 255], + "bgr_to_rgb": true, + "non_blocking": true + }, + "yolov8_backbone": { + "last_stage_out_channels": 1024, + "deepen_factor": 0.33, + "widen_factor": 0.5 + }, + "yoloworld_text": { + "model_name": "openai/clip-vit-base-patch32", + "frozen_modules": ["all"], + "channels": 512 + }, + "yoloworld_neck": { + "in_channels": [256, 512, 1024], + "out_channels": [256, 512, 1024], + "embed_channels": [128, 256, 512], + "num_heads": [4, 8, 16] + }, + "yoloworld_backbone": { + "with_text_model": true + }, + "yoloworld_head_module": { + "use_bn_head": true, + "use_einsum": true + }, + "yoloworld_detector": { + "num_train_classes": 80, + "num_test_classes": 1203, + "mm_neck": true, + "use_syncbn": true + } +} diff --git a/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_x.json b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_x.json new file mode 100644 index 000000000..2ac7b4ba1 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_x.json @@ -0,0 +1,37 @@ +{ + "yoloworld_data_preprocessor": { + "mean": [0, 0, 0], + "std": [255, 255, 255], + "bgr_to_rgb": true, + "non_blocking": true + }, + "yolov8_backbone": { + "last_stage_out_channels": 512, + "deepen_factor": 1.00, + "widen_factor": 1.25 + }, + "yoloworld_text": { + "model_name": "openai/clip-vit-base-patch32", + "frozen_modules": ["all"], + "channels": 512 + }, + "yoloworld_neck": { + "in_channels": [256, 512, 512], + "out_channels": [256, 512, 512], + "embed_channels": [128, 256, 256], + "num_heads": [4, 8, 8] + }, + "yoloworld_backbone": { + "with_text_model": true + }, + "yoloworld_head_module": { + "use_bn_head": true, + "use_einsum": true + }, + "yoloworld_detector": { + "num_train_classes": 80, + "num_test_classes": 1203, + "mm_neck": true, + "use_syncbn": true + } +} diff --git a/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_xl.json b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_xl.json new file mode 100644 index 000000000..3a01f5852 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/model_cfgs/yoloworld_xl.json @@ -0,0 +1,37 @@ +{ + "yoloworld_data_preprocessor": { + "mean": [0, 0, 0], + "std": [255, 255, 255], + "bgr_to_rgb": true, + "non_blocking": true + }, + "yolov8_backbone": { + "last_stage_out_channels": 512, + "deepen_factor": 1.00, + "widen_factor": 1.50 + }, + "yoloworld_text": { + "model_name": "openai/clip-vit-base-patch32", + "frozen_modules": ["all"], + "channels": 512 + }, + "yoloworld_neck": { + "in_channels": [256, 512, 512], + "out_channels": [256, 512, 512], + "embed_channels": [128, 256, 256], + "num_heads": [4, 8, 8] + }, + "yoloworld_backbone": { + "with_text_model": true + }, + "yoloworld_head_module": { + "use_bn_head": true, + "use_einsum": true + }, + "yoloworld_detector": { + "num_train_classes": 80, + "num_test_classes": 1203, + "mm_neck": true, + "use_syncbn": true + } +} diff --git a/community/cv/yolo-world/yolow/model/task_utils/__init__.py b/community/cv/yolo-world/yolow/model/task_utils/__init__.py new file mode 100644 index 000000000..7b40c3018 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/task_utils/__init__.py @@ -0,0 +1,11 @@ +""" +__init__.py +""" +from .distance_point_bbox_coder import DistancePointBBoxCoder +from .point_generator import MlvlPointGenerator + +__all__ = ( + 'MlvlPointGenerator', + 'DistancePointBBoxCoder', + # 'BatchTaskAlignedAssigner', +) diff --git a/community/cv/yolo-world/yolow/model/task_utils/distance_point_bbox_coder.py b/community/cv/yolo-world/yolow/model/task_utils/distance_point_bbox_coder.py new file mode 100644 index 000000000..5e44460ff --- /dev/null +++ b/community/cv/yolo-world/yolow/model/task_utils/distance_point_bbox_coder.py @@ -0,0 +1,108 @@ +""" +distance_point_bbox_coder +""" + +from typing import Optional, Sequence, Union + +import mindspore as ms +from mindspore import Tensor + +__all__ = ('DistancePointBBoxCoder',) + + +def ms_bbox2distance(points: Tensor, bbox: Tensor, max_dis: Optional[float] = None, eps: float = 0.1) -> Tensor: + """Decode bounding box based on distances. + """ + left = points[..., 0] - bbox[..., 0] + top = points[..., 1] - bbox[..., 1] + right = bbox[..., 2] - points[..., 0] + bottom = bbox[..., 3] - points[..., 1] + if max_dis is not None: + left = left.clamp(min=0, max=max_dis - eps) + top = top.clamp(min=0, max=max_dis - eps) + right = right.clamp(min=0, max=max_dis - eps) + bottom = bottom.clamp(min=0, max=max_dis - eps) + return ms.ops.stack([left, top, right, bottom], -1) + + +def ms_distance2bbox(points: Tensor, + distance: Tensor, + max_shape: Optional[Union[Sequence[int], Tensor, Sequence[Sequence[int]]]] = None) -> Tensor: + """Decode distance prediction to bounding box. + """ + + x1 = points[..., 0] - distance[..., 0] + y1 = points[..., 1] - distance[..., 1] + x2 = points[..., 0] + distance[..., 2] + y2 = points[..., 1] + distance[..., 3] + + bboxes = ms.ops.stack([x1, y1, x2, y2], -1) + + if max_shape is not None: + if bboxes.dim() == 2: + # speed up + bboxes[:, 0::2].clamp_(min=0, max=max_shape[1]) + bboxes[:, 1::2].clamp_(min=0, max=max_shape[0]) + return bboxes + + if not isinstance(max_shape, Tensor): + max_shape = x1.new_tensor(max_shape) + max_shape = max_shape[..., :2].type_as(x1) + if max_shape.ndim == 2: + assert bboxes.ndim == 3 + assert max_shape.size(0) == bboxes.size(0) + + min_xy = x1.new_tensor(0) + max_xy = ms.ops.cat([max_shape, max_shape], + axis=-1).flip(-1).unsqueeze(-2) + bboxes = ms.ops.where(bboxes < min_xy, min_xy, bboxes) + bboxes = ms.ops.where(bboxes > max_xy, max_xy, bboxes) + + return bboxes + + +class DistancePointBBoxCoder: + """Distance Point BBox coder. + + This coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left, + right) and decode it back to the original. + """ + + # The size of the last of dimension of the encoded tensor. + encode_size = 4 + + def __init__(self, clip_border: Optional[bool] = True, use_box_type: bool = False) -> None: + self.use_box_type = use_box_type + self.clip_border = clip_border + + def encode( + self, + points: Tensor, + gt_bboxes: Tensor, # modified + max_dis: float = 16., + eps: float = 0.01) -> Tensor: + """Encode bounding box to distances. + """ + assert points.size(-2) == gt_bboxes.size(-2) + assert points.size(-1) == 2 + assert gt_bboxes.size(-1) == 4 + return ms_bbox2distance(points, gt_bboxes, max_dis, eps) + + def decode( + self, + points: Tensor, + pred_bboxes: Tensor, + stride: Tensor, # modified + max_shape: Optional[Union[Sequence[int], Tensor, Sequence[Sequence[int]]]] = None) -> Tensor: + """Decode distance prediction to bounding box. + """ + + assert points.shape[-2] == pred_bboxes.shape[-2] + assert points.shape[-1] == 2 + assert pred_bboxes.shape[-1] == 4 + if self.clip_border is False: + max_shape = None + + bboxes = ms_distance2bbox( + points, pred_bboxes * stride[None, :, None], max_shape) + return bboxes diff --git a/community/cv/yolo-world/yolow/model/task_utils/point_generator.py b/community/cv/yolo-world/yolow/model/task_utils/point_generator.py new file mode 100644 index 000000000..71624093e --- /dev/null +++ b/community/cv/yolo-world/yolow/model/task_utils/point_generator.py @@ -0,0 +1,229 @@ +""" +point_generator.py +""" +from typing import List, Tuple, Union +import numpy as np + + +import mindspore as ms +from mindspore import Tensor + +__all__ = ( + 'MlvlPointGenerator', +) + + +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + + def __init__(self, strides: Union[List[int], List[Tuple[int, int]]], offset: float = 0.5) -> None: + + self.strides = [(stride, stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self) -> int: + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self) -> List[int]: + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x: Tensor, y: Tensor, row_major: bool = True) -> Tuple[Tensor, Tensor]: + # yy, xx = torch.meshgrid(y, x, indexing='ij') + yy, xx = ms.ops.meshgrid(y, x, indexing='ij') + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, + featmap_sizes: List[Tuple], + dtype: ms.dtype = ms.float32, + # device: DeviceType = 'cuda', + with_stride: bool = False) -> List[Tensor]: + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Defaults to torch.float32. + device (str | torch.device): The device where the anchors will be + put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[torch.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors( + featmap_sizes[i], level_idx=i, dtype=dtype, with_stride=with_stride) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, + featmap_size: Tuple[int], + level_idx: int, + dtype: ms.dtype = ms.float32, + # device: DeviceType = 'cuda', + with_stride: bool = False) -> Tensor: + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Defaults to torch.float32. + device (str | torch.device): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (ms.ops.arange(0, feat_w) + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (ms.ops.arange(0, feat_h) + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = ms.ops.stack([shift_xx, shift_yy], axis=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = shift_xx.new_full( + (shift_xx.shape[0],), stride_w).to(dtype) + stride_h = shift_xx.new_full( + (shift_yy.shape[0],), stride_h).to(dtype) + shifts = ms.ops.stack( + [shift_xx, shift_yy, stride_w, stride_h], axis=-1) + # all_points = shifts.to(device) + return shifts + + def valid_flags(self, + featmap_sizes: List[Tuple[int, int]], + pad_shape: Tuple[int],) -> List[Tensor]: + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str | torch.device): The device where the anchors will be + put on. + + Return: + list(torch.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags( + (feat_h, feat_w), (valid_feat_h, valid_feat_w)) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, + featmap_size: Tuple[int, int], + valid_size: Tuple[int, int],) -> Tensor: + """Generate the valid flags of points of a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). + valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str | torch.device): The device where the flags will be + put on. Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = ms.ops.zeros(feat_w, dtype=ms.bool_) + valid_y = ms.ops.zeros(feat_h, dtype=ms.bool_) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, + prior_idxs: Tensor, + featmap_size: Tuple[int], + level_idx: int, + dtype: ms.dtype = ms.float32,) -> Tensor: + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`torch.dtype`): Date type of points. Defaults to + ``torch.float32``. + device (str | torch.device): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). + """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + self.offset) * \ + self.strides[level_idx][1] + prioris = ms.ops.stack([x, y], axis=1).to(dtype) + return prioris diff --git a/community/cv/yolo-world/yolow/model/yolo_base.py b/community/cv/yolo-world/yolow/model/yolo_base.py new file mode 100644 index 000000000..bd5aa096f --- /dev/null +++ b/community/cv/yolo-world/yolow/model/yolo_base.py @@ -0,0 +1,147 @@ +""" +yolo_base.py +""" +from typing import Tuple + +import mindspore.nn as nn +from mindspore import Tensor +from .layers import Conv, CSPLayer, SPPFBottleneck +from .misc import make_divisible, make_round +__all__ = ('YOLOv8CSPDarknet',) + + +class YOLOv8CSPDarknet(nn.Cell): + """CSP-Darknet backbone used in YOLOv8. + """ + arch_settings = { # in_channels, out_channels, num_blocks, add_identity, use_spp + 'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False], + [256, 512, 6, True, False], [512, None, 3, True, True]], + } + + def __init__(self, + arch: str = 'P5', + last_stage_out_channels: int = 1024, + deepen_factor: float = 1.0, + widen_factor: float = 1.0, + input_channels: int = 3, + out_indices: Tuple[int] = (2, 3, 4), + frozen_stages: int = -1, + with_norm: bool = True, + with_activation: bool = True, + norm_eval: bool = False): + super().__init__() + self.arch_settings[arch][-1][1] = last_stage_out_channels + self.arch_settings = self.arch_settings[arch] + self.num_stages = len(self.arch_settings) + + assert set(out_indices).issubset( + i for i in range(len(self.arch_settings) + 1)) + + if frozen_stages not in range(-1, len(self.arch_settings) + 1): + raise ValueError('"frozen_stages" must be in range(-1, ' + 'len(arch_setting) + 1). But received ' + f'{frozen_stages}') + + self.input_channels = input_channels + self.out_indices = out_indices + self.frozen_stages = frozen_stages + self.widen_factor = widen_factor + self.deepen_factor = deepen_factor + self.norm_eval = norm_eval + self.with_norm = with_norm + self.with_activation = with_activation + + self.layers = [] + + # self.stem = self.build_stem_layer() + # self.layers.append(self.build_stem_layer()) + self.insert_child_to_cell('0', self.build_stem_layer()) + self.layers_name = ['0'] + + for idx, setting in enumerate(self.arch_settings): + stage = [] + stage += self.build_stage_layer(idx, setting) + + # self.layers.append(nn.SequentialCell(*stage)) + self.layers_name.append(f'{idx + 1}') + + self.insert_child_to_cell(f'{idx + 1}', nn.SequentialCell(*stage)) + + def build_stem_layer(self) -> nn.Cell: + """ + build_stem_layer + """ + stem_conv = Conv( + self.input_channels, + make_divisible(self.arch_settings[0][0], self.widen_factor), + kernel_size=3, + stride=2, + padding=1, + with_norm=self.with_norm, + with_activation=self.with_activation) + return stem_conv + + def build_stage_layer(self, idx: int, setting: list) -> list: + """ + build_stage_layer + """ + in_channels, out_channels, num_blocks, add_identity, use_spp = setting + + in_channels = make_divisible(in_channels, self.widen_factor) + out_channels = make_divisible(out_channels, self.widen_factor) + num_blocks = make_round(num_blocks, self.deepen_factor) + stage = [] + conv_layer = Conv( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + with_norm=self.with_norm, + with_activation=self.with_activation) + stage.append(conv_layer) + csp_layer = CSPLayer( + out_channels, + out_channels, + num_blocks=num_blocks, + add_identity=add_identity, + with_norm=self.with_norm, + with_activation=self.with_activation) + stage.append(csp_layer) + if use_spp: + spp = SPPFBottleneck( + out_channels, + out_channels, + kernel_sizes=5, + with_norm=self.with_norm, + with_activation=self.with_activation) + stage.append(spp) + return stage + + def init_weights(self): + from mindspore.common.initializer import initializer, HeNormal + for m in self.modules(): + if isinstance(m, nn.Conv2d): + m.weight.set_data(initializer(HeNormal(negative_slope=0, mode='fan_out', nonlinearity='relu'), + m.weight.shape, m.weight.dtype)) + if m.bias is not None: + m.bias.set_data(initializer('zeros', m.bias.shape)) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for i in range(self.frozen_stages + 1): + m = getattr(self, self.layers[i]) + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def construct(self, x: Tensor) -> tuple: + outs = [] + for i, layer_name in enumerate(self.layers_name): + layer = getattr(self, layer_name) + x = layer(x) + + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/community/cv/yolo-world/yolow/model/yolo_world.py b/community/cv/yolo-world/yolow/model/yolo_world.py new file mode 100644 index 000000000..0e84cf37e --- /dev/null +++ b/community/cv/yolo-world/yolow/model/yolo_world.py @@ -0,0 +1,144 @@ +""" +yolo_world.py +""" +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union +import time +import mindspore.nn as nn +from mindspore import Tensor + + +__all__ = ('YOLOWorldDetector',) + +class YOLOWorldDetector(nn.Cell): + """YOLO-World arch + + train_step(): forward() -> loss() -> extract_feat() + val_step(): forward() -> predict() -> extract_feat() + """ + + def __init__(self, + backbone: nn.Cell, + neck: nn.Cell, + bbox_head: nn.Cell, + mm_neck: bool = False, + num_train_classes: int = 80, + num_test_classes: int = 80, + data_preprocessor: Optional[nn.Cell] = None,) -> None: + super().__init__() + + self.mm_neck = mm_neck + self.num_train_classes = num_train_classes + self.num_test_classes = num_test_classes + + self.backbone = backbone + self.neck = neck + self.bbox_head = bbox_head + self.data_preprocessor = data_preprocessor + + + @property + def with_neck(self) -> bool: + return hasattr(self, 'neck') and self.neck is not None + + def val_step(self, data: Union[tuple, dict, list]) -> list: + data = self.data_preprocessor(data, False) + + return self(**data, mode='predict') # type: ignore + + def test_step(self, data: Union[dict, tuple, list]) -> list: + return self.val_step(data) + + def construct(self, data: Union[dict, tuple, list]) -> Union[dict, list, tuple, Tensor]: + data_info = self.data_preprocessor(data, False) + res, time_dict = self.predict(data_info["inputs"], data_info["data_samples"]) + return res, time_dict + + + def parse_losses(self, losses: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: + """ + parse_losses + """ + log_vars = [] + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, Tensor): + log_vars.append([loss_name, loss_value.mean()]) + elif isinstance(loss_value, Union[List[Tensor], Tuple[Tensor]]): + log_vars.append([loss_name, sum(_loss.mean() for _loss in loss_value)]) + else: + raise TypeError(f'{loss_name} is not a tensor or list of tensors') + + loss = sum(value for key, value in log_vars if 'loss' in key) + log_vars.insert(0, ['loss', loss]) + log_vars = OrderedDict(log_vars) # type: ignore + + return loss, log_vars # type: ignore + + + def predict(self, batch_inputs: Tensor, batch_data_samples: Union[List, dict], rescale: bool = True) -> list: + """ + predict + """ + start = time.perf_counter() + img_feats, txt_feats, time_dict = self.extract_feat(batch_inputs, batch_data_samples) + enc_time = time.perf_counter() + self.bbox_head.num_classes = txt_feats[0].shape[0] + # results_list = self.bbox_head.predict(img_feats, txt_feats, batch_data_samples, rescale=rescale) + results_list = self.bbox_head(img_feats, txt_feats, batch_data_samples, rescale=rescale) + pred_time = time.perf_counter() + + # print(f'enc_time: {(enc_time - start):.4f}\n', + # f'pred_time: {(pred_time - enc_time):.4f}\n', + # '_'*20) + + batch_data_samples = self.add_pred_to_datasample(batch_data_samples, results_list) + time_dict.update({"pred_time": pred_time-enc_time, "all_time": pred_time-start,}) + return batch_data_samples, time_dict + + def _forward(self, + batch_inputs: Tensor, + batch_data_samples: Optional[Union[List, dict]] = None) -> Tuple[List[Tensor]]: + img_feats, txt_feats = self.extract_feat(batch_inputs, batch_data_samples) + results = self.bbox_head.forward(img_feats, txt_feats) + return results + + def extract_feat(self, batch_inputs: Tensor, batch_data_samples: Union[List, dict]) -> Tuple[Tuple[Tensor], Tensor]: + """ + extract_feat + """ + txt_feats = None + if batch_data_samples is None: + texts = self.texts + txt_feats = self.text_feats + elif isinstance(batch_data_samples, dict) and 'texts' in batch_data_samples['img_metas']: + texts = batch_data_samples['img_metas']['texts'] + elif isinstance(batch_data_samples, list) and ('texts' in batch_data_samples[0]['img_metas']): + texts = [data_sample['img_metas']['texts'] for data_sample in batch_data_samples] + elif hasattr(self, 'text_feats'): + texts = self.texts + txt_feats = self.text_feats + else: + raise TypeError('batch_data_samples should be dict or list.') + if txt_feats is not None: + # forward image only + img_feats = self.backbone.forward_image(batch_inputs) + else: + img_feats, txt_feats, time_dict = self.backbone(batch_inputs, texts) + + if self.with_neck: + if self.mm_neck: + img_feats = self.neck(img_feats, txt_feats) + else: + img_feats = self.neck(img_feats) + return img_feats, txt_feats, time_dict + + def add_pred_to_datasample(self, data_samples: List, results_list: List) -> List: + for data_sample, pred_instances in zip(data_samples, results_list): + data_sample['pred_instances'] = pred_instances + # samplelist_boxtype2tensor(data_samples) + return data_samples + + def reparameterize(self, texts: List[List[str]]) -> None: + # encode text embeddings into the detector + self.texts = texts + self.text_feats = self.backbone.forward_text(texts) diff --git a/community/cv/yolo-world/yolow/model/yolo_world_backbone.py b/community/cv/yolo-world/yolow/model/yolo_world_backbone.py new file mode 100644 index 000000000..6562d0466 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/yolo_world_backbone.py @@ -0,0 +1,67 @@ +""" +yolo_world_backbone.py +""" +from typing import List, Tuple +import time +import mindspore.nn as nn +from mindspore import Tensor + +__all__ = ("MultiModalYOLOBackbone",) + +class MultiModalYOLOBackbone(nn.Cell): + """ + MultiModalYOLOBackbone + """ + + def __init__(self, + image_model: nn.Cell, + text_model: nn.Cell, + frozen_stages: int = -1, + with_text_model: bool = True) -> None: + super().__init__() + self.with_text_model = with_text_model + self.image_model = image_model + if self.with_text_model: + self.text_model = text_model + else: + self.text_model = None + self.frozen_stages = frozen_stages + self._freeze_stages() + + def _freeze_stages(self): + """Freeze the parameters of the specified stage so that they are no + longer updated.""" + if self.frozen_stages >= 0: + for i in range(self.frozen_stages + 1): + m = getattr(self.image_model, self.image_model.layers[i]) + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + self._freeze_stages() + + def construct(self, image: Tensor, text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]: + """ + construct + """ + start = time.perf_counter() + img_feat = self.image_model(image) + img_enc_time = time.perf_counter() + txt_feat = self.text_model(text) + txt_enc_time = time.perf_counter() + + # print('_'*20+'\n', + # f'img_enc_time: {(img_enc_time - start):.4f}\n' + # f'txt_enc_time: {(txt_enc_time - img_enc_time):.4f}') + + return img_feat, txt_feat, {'img_enc_time': img_enc_time - start, 'txt_enc_time': txt_enc_time - img_enc_time} + + def forward_text(self, text: List[List[str]]) -> Tensor: + if self.with_text_model: + return self.text_model(text) + return None # no text model + + def forward_image(self, image: Tensor) -> Tuple[Tensor]: + return self.image_model(image) diff --git a/community/cv/yolo-world/yolow/model/yolo_world_head.py b/community/cv/yolo-world/yolow/model/yolo_world_head.py new file mode 100644 index 000000000..2f3dc0412 --- /dev/null +++ b/community/cv/yolo-world/yolow/model/yolo_world_head.py @@ -0,0 +1,502 @@ +""" +yolo_world_head.py +""" +import copy +from typing import List, Optional, Sequence, Tuple, Union +import numpy as np + +import mindspore.nn as nn +import mindspore as ms +from mindspore import Tensor +from .layers import Conv + +from .misc import (get_box_wh, make_divisible, multi_apply, yolow_dict, + ms_filter_scores_and_topk, ms_scale_boxes, ms_batched_nms) +from .task_utils import MlvlPointGenerator, DistancePointBBoxCoder + + +__all__ = ( + 'YOLOWorldHeadModule', + 'YOLOWorldHead', +) + + +class ContrastiveHead(nn.Cell): + """Contrastive Head for YOLO-World + """ + + def __init__(self, use_einsum: bool = True) -> None: + super().__init__() + self.bias = ms.Parameter(ms.ops.zeros([])) + self.logit_scale = ms.Parameter(ms.ops.ones([]) * np.log(1 / 0.07)) + self.use_einsum = use_einsum + + def construct(self, x: Tensor, w: Tensor) -> Tensor: + """Forward function of contrastive learning.""" + x = ms.ops.L2Normalize(x, axis=1) + w = ms.ops.L2Normalize(w, axis=-1) + + if self.use_einsum: + x = ms.ops.einsum('bchw,bkc->bkhw', x, w) + else: + batch, channel, height, width = x.shape + _, k, _ = w.shape + x = x.permute(0, 2, 3, 1) # bchw->bhwc + x = x.reshape(batch, -1, channel) # bhwc->b(hw)c + w = w.permute(0, 2, 1) # bkc->bck + x = ms.ops.matmul(x, w) + x = x.reshape(batch, height, width, k) + x = x.permute(0, 3, 1, 2) + + x = x * self.logit_scale.exp() + self.bias + return x + + +class BNContrastiveHead(nn.Cell): + """ Batch Norm Contrastive Head for YOLO-World + using batch norm instead of l2-normalization + """ + + def __init__(self, embed_dims: int, use_einsum: bool = True) -> None: + super().__init__() + self.norm = nn.BatchNorm2d(embed_dims, momentum=0.97, eps=0.001) + self.bias = ms.Parameter(ms.ops.zeros([])) + # use -1.0 is more stable + self.logit_scale = ms.Parameter(-1.0 * ms.ops.ones([])) + self.use_einsum = use_einsum + + def construct(self, x: Tensor, w: Tensor) -> Tensor: + """Forward function of contrastive learning.""" + x = self.norm(x) + l2_normalize = ms.ops.L2Normalize(axis=-1) + w = l2_normalize(w) + + if self.use_einsum: + x = ms.ops.einsum('bchw,bkc->bkhw', x, w) + else: + batch, channel, height, width = x.shape + _, k, _ = w.shape + x = x.permute(0, 2, 3, 1) # bchw->bhwc + x = x.reshape(batch, -1, channel) # bhwc->b(hw)c + w = w.permute(0, 2, 1) # bkc->bck + x = ms.ops.matmul(x, w) + x = x.reshape(batch, height, width, k) + x = x.permute(0, 3, 1, 2) + + x = x * self.logit_scale.exp() + self.bias + return x + + +class YOLOWorldHeadModule(nn.Cell): + """Head Module for YOLO-World + """ + + def __init__(self, + num_classes: int, + in_channels: Union[int, Sequence], + embed_dims: int, + use_bn_head: bool = True, + use_einsum: bool = True, + freeze_all: bool = False, + widen_factor: float = 1.0, + num_base_priors: int = 1, + featmap_strides: Sequence[int] = (8, 16, 32), + reg_max: int = 16, + with_norm: bool = True, + with_activation: bool = True) -> None: + super().__init__() + + self.embed_dims = embed_dims + self.use_bn_head = use_bn_head + self.use_einsum = use_einsum + self.freeze_all = freeze_all + self.num_classes = num_classes + self.featmap_strides = featmap_strides + self.num_levels = len(self.featmap_strides) + self.num_base_priors = num_base_priors + self.with_norm = with_norm + self.with_activation = with_activation + self.in_channels = in_channels + self.reg_max = reg_max + + in_channels = [] + for channel in self.in_channels: + channel = make_divisible(channel, widen_factor) + in_channels.append(channel) + self.in_channels = in_channels + + self._init_layers() + +# TODO: init_weights in mindspore + # def init_weights(self, prior_prob=0.01): + # """Initialize the weight and bias of PPYOLOE head.""" + # for reg_pred, cls_pred, cls_contrast, stride in zip(self.reg_preds, self.cls_preds, self.cls_contrasts, + # self.featmap_strides): + # reg_pred[-1].bias.data[:] = 1.0 # box + # cls_pred[-1].bias.data[:] = 0.0 # reset bias + # if hasattr(cls_contrast, 'bias'): + # nn.init.constant_(cls_contrast.bias.data, math.log(5 / self.num_classes / (640 / stride)**2)) + + def _init_layers(self) -> None: + """initialize conv layers in YOLOv8 head.""" + # Init decouple head + self.cls_preds = nn.CellList(auto_prefix=False) + self.reg_preds = nn.CellList(auto_prefix=False) + self.cls_contrasts = nn.CellList(auto_prefix=False) # 避免自动获取前缀命名 + + reg_out_channels = max( + (16, self.in_channels[0] // 4, self.reg_max * 4)) + cls_out_channels = max(self.in_channels[0], self.num_classes) + + for i in range(self.num_levels): + self.reg_preds.append( + nn.SequentialCell( + Conv( + in_channels=self.in_channels[i], + out_channels=reg_out_channels, + kernel_size=3, + stride=1, + padding=1, + with_norm=self.with_norm, + with_activation=self.with_activation), + Conv( + in_channels=reg_out_channels, + out_channels=reg_out_channels, + kernel_size=3, + stride=1, + padding=1, + with_norm=self.with_norm, + with_activation=self.with_activation), + nn.Conv2d(in_channels=reg_out_channels, out_channels=4 * self.reg_max, + kernel_size=1, has_bias=True, pad_mode='valid'))) + + self.cls_preds.append( + nn.SequentialCell( + Conv( + in_channels=self.in_channels[i], + out_channels=cls_out_channels, + kernel_size=3, + stride=1, + padding=1, + with_norm=self.with_norm, + with_activation=self.with_activation), + Conv( + in_channels=cls_out_channels, + out_channels=cls_out_channels, + kernel_size=3, + stride=1, + padding=1, + with_norm=self.with_norm, + with_activation=self.with_activation), + nn.Conv2d(in_channels=cls_out_channels, out_channels=self.embed_dims, + kernel_size=1, has_bias=True, pad_mode='valid'))) + if self.use_bn_head: + self.cls_contrasts.append(BNContrastiveHead( + self.embed_dims, self.use_einsum)) + else: + self.cls_contrasts.append(ContrastiveHead(self.use_einsum)) + + # proj = ms.ops.arange(self.reg_max, dtype=torch.float) + # self.register_buffer('proj', proj, persistent=False) + self.proj = ms.Parameter(Tensor(ms.ops.arange( + self.reg_max, dtype=ms.float32)), requires_grad=False) + + def train(self, mode=True): + super().train(mode) + if self.freeze_all: + self._freeze_all() + + def construct(self, img_feats: Tuple[Tensor], txt_feats: Tensor) -> Tuple[List]: + """Forward features from the upstream network.""" + assert len(img_feats) == self.num_levels + txt_feats = [txt_feats for _ in range(self.num_levels)] + + res = multi_apply(self.forward_single, img_feats, txt_feats, self.cls_preds, self.reg_preds, + self.cls_contrasts) + + return res + + def forward_single(self, img_feat: Tensor, txt_feat: Tensor, cls_pred: nn.CellList, reg_pred: nn.CellList, + cls_contrast: nn.CellList) -> Tuple: + """Forward feature of a single scale level.""" + b, _, h, w = img_feat.shape + cls_embed = cls_pred(img_feat) + cls_logit = cls_contrast(cls_embed, txt_feat) + bbox_dist_preds = reg_pred(img_feat) + if self.reg_max > 1: + bbox_dist_preds = bbox_dist_preds.reshape( + [-1, 4, self.reg_max, h * w]).permute(0, 3, 1, 2) + + # TODO: The get_flops script cannot handle the situation of + # matmul, and needs to be fixed later + bbox_preds = ms.ops.matmul(ms.ops.softmax( + bbox_dist_preds, axis=3), self.proj.unsqueeze(1)).squeeze(-1) + bbox_preds = ms.ops.swapaxes(bbox_preds, 1, 2).reshape(b, -1, h, w) + else: + bbox_preds = bbox_dist_preds + if self.training: + return cls_logit, bbox_preds, bbox_dist_preds + return cls_logit, bbox_preds + + +class YOLOWorldHead(nn.Cell): + """YOLO-World Head + + - loss(): forward() -> loss_by_feat() + - predict(): forward() -> predict_by_feat() + - loss_and_predict(): forward() -> loss_by_feat() -> predict_by_feat() + """ + + def __init__(self, head_module: nn.Cell, test_cfg: Optional[dict] = None) -> None: + super().__init__() + + self.head_module = head_module + self.num_classes = self.head_module.num_classes + self.featmap_strides = self.head_module.featmap_strides + self.num_levels = len(self.featmap_strides) + + self.test_cfg = test_cfg + + # init task_utils + self.prior_generator = MlvlPointGenerator( + offset=0.5, strides=[8, 16, 32]) + self.bbox_coder = DistancePointBBoxCoder() + self.num_base_priors = self.prior_generator.num_base_priors[0] + # TODO later 0722 + self.featmap_sizes = [ms.numpy.empty([1])] * self.num_levels + + self.prior_match_thr = 4.0 + self.near_neighbor_thr = 0.5 + self.obj_level_weights = [4.0, 1.0, 0.4] + self.ignore_iof_thr = -1.0 + + # Add common attributes to reduce calculation + self.featmap_sizes_train = None + self.num_level_priors = None + self.flatten_priors_train = None + self.stride_tensor = None + + def construct(self, + img_feats: Tuple[Tensor], + txt_feats: Tensor, + batch_data_samples: Union[list, dict], + rescale: bool = False) -> list: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. + """ + batch_img_metas = [ + # changed `.metainfo` to ['img_metas'] + data_samples['img_metas'] for data_samples in batch_data_samples + ] + outs = self.head_module(img_feats, txt_feats) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions + + def predict(self, + img_feats: Tuple[Tensor], + txt_feats: Tensor, + batch_data_samples: Union[list, dict], + rescale: bool = False) -> list: + """Perform forward propagation of the detection head and predict + detection results on the features of the upstream network. + """ + batch_img_metas = [ + # changed `.metainfo` to ['img_metas'] + data_samples['img_metas'] for data_samples in batch_data_samples + ] + outs = self(img_feats, txt_feats) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + return predictions + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + objectnesses: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[dict] = None, + rescale: bool = True, + with_nms: bool = True) -> List: + """Transform a batch of output features extracted by the head into + bbox results. + """ + assert len(cls_scores) == len(bbox_preds) + if objectnesses is None: + with_objectnesses = False + else: + with_objectnesses = True + assert len(cls_scores) == len(objectnesses) + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + + multi_label = cfg.multi_label + multi_label &= self.num_classes > 1 + cfg.multi_label = multi_label + + num_imgs = len(batch_img_metas) + featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] + + # If the shape does not change, use the previous mlvl_priors + if featmap_sizes != self.featmap_sizes: + self.mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, dtype=cls_scores[0].dtype) + # featmap_sizes = [ms.ops.full_like(f) for f in featmap_sizes] + self.featmap_sizes = featmap_sizes + + flatten_priors = ms.ops.cat(self.mlvl_priors) + mlvl_strides = [] + # TODO: mindspore 获取Tensor的shape 比如 [160, 160],统计numel的时候返回2,不像torch 返回160*160 + for featmap_size, stride in zip(featmap_sizes, self.featmap_strides): + tmp = ms.ops.full( + (featmap_size[0]*featmap_size[1] * self.num_base_priors,), stride) + mlvl_strides.append(tmp) + # mlvl_strides = [ + # ms.ops.full((ms.ops.numel(featmap_size) * self.num_base_priors, ), stride) + # # flatten_priors.new_full((featmap_size.numel() * self.num_base_priors, ), stride) + # for featmap_size, stride in zip(featmap_sizes, self.featmap_strides) + # ] + + flatten_stride = ms.ops.cat(mlvl_strides) + + # flatten cls_scores, bbox_preds and objectness + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes) for cls_score in cls_scores + ] + flatten_bbox_preds = [bbox_pred.permute(0, 2, 3, 1).reshape( + num_imgs, -1, 4) for bbox_pred in bbox_preds] + + flatten_cls_scores = ms.ops.cat(flatten_cls_scores, axis=1).sigmoid() + flatten_bbox_preds = ms.ops.cat(flatten_bbox_preds, axis=1) + + flatten_decoded_bboxes = self.bbox_coder.decode( + flatten_priors[None], flatten_bbox_preds, flatten_stride) + + if with_objectnesses: + flatten_objectness = [objectness.permute(0, 2, 3, 1).reshape( + num_imgs, -1) for objectness in objectnesses] + flatten_objectness = ms.ops.cat( + flatten_objectness, axis=1).sigmoid() + else: + flatten_objectness = [None for _ in range(num_imgs)] + results_list = [] + + for (bboxes, scores, objectness, img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores, + flatten_objectness, batch_img_metas): + ori_shape = img_meta['ori_shape'] + scale_factor = img_meta['scale_factor'] + if 'pad_param' in img_meta: + pad_param = img_meta['pad_param'] + else: + pad_param = None + + score_thr = cfg.get('score_thr', -1) + # yolox_style does not require the following operations + if objectness is not None and score_thr > 0 and not cfg.get('yolox_style', False): + conf_inds = objectness > score_thr + bboxes = bboxes[conf_inds, :] + scores = scores[conf_inds, :] + objectness = objectness[conf_inds] + + if objectness is not None: + # conf = obj_conf * cls_conf + scores *= objectness[:, None] + + if scores.shape[0] == 0: + empty_results = yolow_dict() + empty_results.bboxes = bboxes + empty_results.scores = scores[:, 0] + empty_results.labels = scores[:, 0].int() + results_list.append(empty_results) + continue + + nms_pre = cfg.get('nms_pre', 10000) + + if cfg.multi_label is False: + scores, labels = scores.max(1, keepdim=True) + scores, _, keep_idxs, results = ms_filter_scores_and_topk( + scores, score_thr, nms_pre, results=dict(labels=labels[:, 0])) + labels = results['labels'] + else: + scores, labels, keep_idxs, _ = ms_filter_scores_and_topk( + scores, score_thr, nms_pre) + + results = yolow_dict( + scores=scores, labels=labels, bboxes=bboxes[keep_idxs]) + + if rescale: + if pad_param is not None: + results.bboxes -= ms.Tensor([pad_param[2], + pad_param[0], pad_param[2], pad_param[0]]) + + results.bboxes /= ms.Tensor( + scale_factor).repeat(2).unsqueeze(0) + + if cfg.get('yolox_style', False): + # do not need max_per_img + cfg.max_per_img = len(results) + + results = self._bbox_post_process( + results=results, cfg=cfg, rescale=False, with_nms=with_nms, img_meta=img_meta) + + results.bboxes[:, 0::2] = ms.ops.clamp( + results.bboxes[:, 0::2], 0, ori_shape[1]) + results.bboxes[:, 1::2] = ms.ops.clamp( + results.bboxes[:, 1::2], 0, ori_shape[0]) + + results_list.append(results) + + return results_list + + def _bbox_post_process(self, + results: dict, + cfg: dict, + rescale: bool = False, + with_nms: bool = True, + img_meta: Optional[dict] = None) -> dict: + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually `with_nms` is False is used for aug test. + """ + if rescale: + assert img_meta.get('scale_factor') is not None + scale_factor = [1 / s for s in img_meta['scale_factor']] + results.bboxes = ms_scale_boxes(results.bboxes, scale_factor) + + if hasattr(results, 'score_factors'): + # TODO: Add sqrt operation in order to be consistent with + # the paper. + score_factors = results.pop('score_factors') + results.scores = results.scores * score_factors + + # filter small size bboxes + if cfg.get('min_bbox_size', -1) >= 0: + w, h = get_box_wh(results.bboxes) + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + results = results[valid_mask] + + # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg + + if with_nms and results.bboxes.numel() > 0: + # bboxes = get_box_tensor(results.bboxes) + bboxes = results.bboxes + assert isinstance(bboxes, Tensor) + # start = time.perf_counter() + det_bboxes, keep_idxs = ms_batched_nms( + bboxes, results.scores, results.labels, cfg.nms) + # topk_time = time.perf_counter() + # print(f"ms_batched_nms time: {topk_time - start}") + # results = results[keep_idxs] + for k in results.keys(): + results[k] = results[k][keep_idxs] + # some nms would reweight the score, such as softnms + results.scores = det_bboxes[:, -1] + for k in results.keys(): + results[k] = results[k][:cfg.max_per_img] + + return results diff --git a/community/cv/yolo-world/yolow/model/yolo_world_pafpn.py b/community/cv/yolo-world/yolow/model/yolo_world_pafpn.py new file mode 100644 index 000000000..68404caec --- /dev/null +++ b/community/cv/yolo-world/yolow/model/yolo_world_pafpn.py @@ -0,0 +1,187 @@ +""" +yolo_world_pafpn.py +""" +from typing import List, Union + +import mindspore.nn as nn +from mindspore import Tensor +import mindspore as ms +from .layers import Conv, MaxSigmoidCSPLayer +from .misc import make_divisible, make_round +__all__ = ('YOLOWorldPAFPN',) + + +class YOLOWorldPAFPN(nn.Cell): + """Path Aggregation Network used in YOLO World + Following YOLOv8 PAFPN, including text to image fusion + """ + + def __init__(self, + in_channels: List[int], + out_channels: Union[List[int], int], + guide_channels: int, + embed_channels: List[int], + num_heads: List[int], + deepen_factor: float = 1.0, + widen_factor: float = 1.0, + num_csp_blocks: int = 3, + freeze_all: bool = False, + with_norm: bool = True, + with_activation: bool = True) -> None: + super().__init__() + self.guide_channels = guide_channels + self.embed_channels = embed_channels + self.num_heads = num_heads + self.num_csp_blocks = num_csp_blocks + + self.in_channels = in_channels + self.out_channels = out_channels + self.deepen_factor = deepen_factor + self.widen_factor = widen_factor + self.upsample_feats_cat_first = True + self.freeze_all = freeze_all + self.with_norm = with_norm + self.with_activation = with_activation + + self.reduce_layers = nn.CellList() + for idx in range(len(in_channels)): + self.reduce_layers.append(self.build_reduce_layer(idx)) + + # build top-down blocks + self.upsample_layers = nn.CellList() + self.top_down_layers = nn.CellList() + for idx in range(len(in_channels) - 1, 0, -1): + self.upsample_layers.append(self.build_upsample_layer(idx)) + self.top_down_layers.append(self.build_top_down_layer(idx)) + + # build bottom-up blocks + self.downsample_layers = nn.CellList() + self.bottom_up_layers = nn.CellList() + for idx in range(len(in_channels) - 1): + self.downsample_layers.append(self.build_downsample_layer(idx)) + self.bottom_up_layers.append(self.build_bottom_up_layer(idx)) + + self.out_layers = nn.CellList() + for idx in range(len(in_channels)): + self.out_layers.append(self.build_out_layer(idx)) + + # TODO:暂时没有完成freeze 的写法 + def _freeze_all(self): + for m in self.modules(): + # if isinstance(m, nn.modules.batchnorm): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super().train(mode) + if self.freeze_all: + self._freeze_all() + + def init_weights(self): + """init_weights""" + from mindspore.common.initializer import initializer, HeNormal + for m in self.cells(): + if isinstance(m, nn.Conv2d): + # In order to be consistent with the source code, + # reset the Conv2d initialization parameters + m.weight.set_data(initializer(HeNormal(negative_slope=0, mode='fan_out', nonlinearity='relu'), + m.weight.shape, m.weight.dtype)) + if hasattr(m, 'bias') and m.bias is not None: + m.bias.set_data(initializer( + "zeros", m.bias.shape, m.bias.dtype)) + + def build_reduce_layer(self, idx: int) -> nn.Cell: + return nn.Identity() + + def build_upsample_layer(self, *args, **kwargs) -> nn.Cell: + return nn.Upsample(scale_factor=2.0, mode='nearest', recompute_scale_factor=True) + + def build_downsample_layer(self, idx: int) -> nn.Cell: + return Conv( + make_divisible(self.in_channels[idx], self.widen_factor), + make_divisible(self.in_channels[idx], self.widen_factor), + kernel_size=3, + stride=2, + padding=1, + with_norm=self.with_norm, + with_activation=self.with_activation) + + def build_out_layer(self, *args, **kwargs) -> nn.Cell: + return nn.Identity() + + def build_top_down_layer(self, idx: int) -> nn.Cell: + return MaxSigmoidCSPLayer( + in_channels=make_divisible( + (self.in_channels[idx - 1] + self.in_channels[idx]), self.widen_factor), + out_channels=make_divisible( + self.out_channels[idx - 1], self.widen_factor), + guide_channels=self.guide_channels, + embed_channels=make_round( + self.embed_channels[idx - 1], self.widen_factor), + num_heads=make_round(self.num_heads[idx - 1], self.widen_factor), + num_blocks=make_round(self.num_csp_blocks, self.deepen_factor), + add_identity=False, + with_norm=self.with_norm, + with_activation=self.with_activation) + + def build_bottom_up_layer(self, idx: int) -> nn.Cell: + return MaxSigmoidCSPLayer( + in_channels=make_divisible( + (self.out_channels[idx] + self.out_channels[idx + 1]), self.widen_factor), + out_channels=make_divisible( + self.out_channels[idx + 1], self.widen_factor), + guide_channels=self.guide_channels, + embed_channels=make_round( + self.embed_channels[idx + 1], self.widen_factor), + num_heads=make_round(self.num_heads[idx + 1], self.widen_factor), + num_blocks=make_round(self.num_csp_blocks, self.deepen_factor), + add_identity=False, + with_norm=self.with_norm, + with_activation=self.with_activation) + + def construct(self, img_feats: List[Tensor], txt_feats: Tensor) -> tuple: + """Forward function. + including multi-level image features, text features: BxLxD + """ + assert len(img_feats) == len(self.in_channels) + # reduce layers + reduce_outs = [] + for idx in range(len(self.in_channels)): + reduce_outs.append(self.reduce_layers[idx](img_feats[idx])) + + # top-down path + inner_outs = [reduce_outs[-1]] + for idx in range(len(self.in_channels) - 1, 0, -1): + feat_high = inner_outs[0] + feat_low = reduce_outs[idx - 1] + upsample_feat = self.upsample_layers[len( + self.in_channels) - 1 - idx](feat_high) + if self.upsample_feats_cat_first: + top_down_layer_inputs = ms.ops.cat( + [upsample_feat, feat_low], 1) + else: + top_down_layer_inputs = ms.ops.cat( + [feat_low, upsample_feat], 1) + + inner_out = self.top_down_layers[len( + self.in_channels) - 1 - idx](top_down_layer_inputs, txt_feats) + inner_outs.insert(0, inner_out) + + # bottom-up path + outs = [inner_outs[0]] + for idx in range(len(self.in_channels) - 1): + feat_low = outs[-1] + feat_high = inner_outs[idx + 1] + downsample_feat = self.downsample_layers[idx](feat_low) + out = self.bottom_up_layers[idx](ms.ops.cat( + [downsample_feat, feat_high], 1), txt_feats) + outs.append(out) + + # out_layers + results = [] + for idx in range(len(self.in_channels)): + results.append(self.out_layers[idx](outs[idx])) + + return tuple(results) -- Gitee