1 Star 0 Fork 0

yasuo_hao/S-DCNet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
labels_counts_utils.py 4.02 KB
一键复制 编辑 原始数据 按行查看 历史
dmburd 提交于 2020-03-24 14:35 . Initial set of files
import sys
import os
import os.path
from sys import exit as e
from bisect import bisect_left
import numpy as np
import torch
import q
def apply_label2count(cls_labels_tensor, cls_label2count_tensor):
    """
    Map a tensor of class labels to a tensor of count values.

    Implementation details: torch.index_select() is applied to the flattened
    labels, which are used as indices into the label->count table.

    When CUDA is available both tensors are moved to the GPU (as before);
    otherwise the lookup runs on the devices the tensors already live on,
    so the function also works on CPU-only machines.

    Args:
        cls_labels_tensor: Tensor (of arbitrary shape in general)
            containing class labels (integers usable as indices).
        cls_label2count_tensor: 1-D tensor; element i is the count value
            (float) associated with class label i.
    Returns:
        Tensor containing count values (instead of labels).
        It has the same shape as `cls_labels_tensor`.
    """
    orig_shape = cls_labels_tensor.shape
    flat_labels = cls_labels_tensor.reshape((-1,))
    if torch.cuda.is_available():
        # Keep the original behaviour (result on the GPU) whenever possible;
        # an unconditional .cuda() crashes on machines without CUDA.
        cls_label2count_tensor = cls_label2count_tensor.cuda()
        flat_labels = flat_labels.cuda()
    t = torch.index_select(
        cls_label2count_tensor,  # input
        dim=0,
        index=flat_labels
    )
    # ^ DO NOT specify the 1st argument as input=<smth>!
    # TorchScript will throw `RuntimeError: Arguments for call are not valid`.
    # aten::index_select(Tensor self, int dim, Tensor index) -> (Tensor):
    # Argument self not provided.
    return t.reshape(orig_shape)
def apply_count2label(counts_tensor, interval_bounds):
    """
    Map a tensor of count values to a tensor of class labels
    (the inverse of apply_label2count()).

    The label of a count value c is the number of interval bounds strictly
    smaller than c, i.e. what bisect.bisect_left(interval_bounds, c) returns.
    The lookup is vectorized with np.searchsorted(side='left').

    Args:
        counts_tensor: Tensor containing count values (any shape, any device).
        interval_bounds: Interval boundaries for the count values (floats),
            sorted in ascending order.
    Returns:
        CPU tensor of integer class labels (instead of count values) with
        the same shape as `counts_tensor`.
    """
    orig_shape = counts_tensor.shape
    flat = counts_tensor.detach().reshape((-1,)).cpu()
    if flat.is_floating_point():
        # float32 -> float64 is exact, so comparisons against the bounds
        # match those of Python floats; this also covers half/bfloat16,
        # which NumPy cannot (or cannot always) represent directly.
        flat = flat.double()
    values = flat.numpy()
    labels = np.searchsorted(np.asarray(interval_bounds), values, side='left')
    if values.dtype.kind == 'f':
        # bisect_left() sends NaN to label 0 (every comparison is False),
        # whereas searchsorted() sorts NaN last; keep the former behaviour.
        labels[np.isnan(values)] = 0
    return torch.from_numpy(labels.reshape(orig_shape))
def _num_grid_points(start, stop, step, tol=1e-9):
    """
    Number of points of the grid start, start + step, ... that fit up to stop.

    Same as int((stop - start) / step) + 1, except that a quotient lying
    within `tol` of an integer is snapped to it first; otherwise float error
    such as (0.7 - 0.1) / 0.1 == 5.999999999999999 drops the last point.
    """
    ratio = (stop - start) / step
    nearest = round(ratio)
    if abs(ratio - nearest) < tol:
        ratio = nearest
    return int(ratio) + 1


def make_label2count_list(args_dict):
    """
    Construct the mapping between the class labels (int) and count values
    (float).

    Interval boundaries are the base for both class labels and count values:
        [1e-6] + ([0.05, 0.10, ..., 0.45] if partition_method == 2)
               + [step, 2*step, ..., num_intervals]
    Class labels are consecutive zero-based indices of the intervals. Label 0
    maps to count 0.0, every bounded interval maps to its midpoint, and the
    rightmost semi-open interval [C, +inf) maps to its left boundary C.

    Args:
        args_dict: Dictionary containing required configuration values.
            The keys required for this function are 'num_intervals'
            (despite the name, the largest interval bound), 'interval_step',
            and 'partition_method'.
    Returns:
        Tuple (interval_bounds, label2count_list): a NumPy array of ascending
        interval boundaries, and a list of count values whose indices are
        the class labels.
    Raises:
        AssertionError: if the configuration yields non-ascending bounds
            (e.g. partition_method == 2 with interval_step < 0.45).
    """
    step = args_dict['interval_step']
    c_max = args_dict['num_intervals']

    add_for_two_linear = np.array([])
    if args_dict['partition_method'] == 2:
        # Extra fine-grained bounds for small counts.
        add_for_two_linear = np.linspace(
            0.05, 0.45, _num_grid_points(0.05, 0.45, 0.05))
    bounds = np.linspace(step, c_max, _num_grid_points(step, c_max, step))
    very_1st_bnd = np.array([1e-6, ])
    interval_bounds = np.concatenate(
        [very_1st_bnd, add_for_two_linear, bounds])

    # Transform interval endpoints to count values; the list index is the
    # class label and the element is the count value.
    bnds = interval_bounds.tolist()
    label2count_list = [0.0]
    label2count_list.extend(
        (lo + hi) / 2.0 for lo, hi in zip(bnds[:-1], bnds[1:]))
    label2count_list.append(bnds[-1])

    assert all(a <= b for a, b in zip(bnds[:-1], bnds[1:])), \
        "interval bounds are not ascending"
    assert all(a <= b for a, b in
               zip(label2count_list[:-1], label2count_list[1:])), \
        "count values are not ascending"
    return interval_bounds, label2count_list
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yasuo_hao/S-DCNet.git
git@gitee.com:yasuo_hao/S-DCNet.git
yasuo_hao
S-DCNet
S-DCNet
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385