1 Star 0 Fork 0

Kahsolt/PDB-analyze

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
utils.py 2.16 KB
一键复制 编辑 原始数据 按行查看 历史
Kahsolt 提交于 2023-01-11 18:16 . merge repo
#!/usr/bin/env python3
# Author: Armit
# Create Time: 2022/11/20
import os
from typing import Union, Iterable
from sklearn.metrics import *
from matplotlib.colors import ListedColormap
from data import cat_dict, TARGET
# Number of logical CPUs. os.cpu_count() returns None when the count cannot be
# determined; fall back to 1 so this constant is always a positive int.
CPU_COUNT = os.cpu_count() or 1
# Fixed random seed.
RAND_SEED = 42
def show_clf_metrics(y_true, y_pred, target_names=None):
    """Print classification metrics comparing ``y_true`` with ``y_pred``.

    Prints accuracy, balanced accuracy, macro-averaged precision / recall /
    F1, a per-class report and the confusion matrix, framed by separator
    lines.

    Args:
        y_true: ground-truth class ids.
        y_pred: predicted class ids.
        target_names: display name for each class id, indexed by id.
            Defaults to ['Fl', 'Ot', 'St'] (ids 0, 1, 2).
    """
    if target_names is None:
        # In the category enumeration, 0 is Fl, 1 is Ot and 2 is St.
        target_names = ['Fl', 'Ot', 'St']
    # Pin the label set to the named ids. Without it, classification_report
    # raises ValueError whenever the sample does not contain exactly
    # len(target_names) distinct classes, and the confusion matrix shape
    # would depend on which classes happen to be present.
    labels = list(range(len(target_names)))

    print('=' * 78)
    acc = accuracy_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    print(f'Accuracy: {acc:.3%}')
    print(f'Balanced Accuracy: {bacc:.3%}')
    # The 4th return value (support) is None when average='macro'.
    prec, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
    print(f'Precision: {prec:.3%}')
    print(f'Recall: {recall:.3%}')
    print(f'F1-Score: {f1:.3%}')
    print()
    print(classification_report(y_true, y_pred, labels=labels, target_names=target_names))
    print('Confusion Matrix:')
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    print(cm)
    print('=' * 78)
def get_cmap(n_colors: Union[int, Iterable]):
    """Pick a colormap suited to the given number of distinct colours.

    Args:
        n_colors: either the colour count itself, or an iterable of labels
            whose number of distinct values is used as the count.

    Returns:
        A matplotlib colormap name (str), or a ListedColormap built by
        ``_make_cmap_rgb`` (count 3) / ``_make_cmap_rgbx`` (count 4).
    """
    count = len(set(n_colors)) if isinstance(n_colors, Iterable) else n_colors

    # cmap ref: https://matplotlib.org/2.0.2/examples/color/colormaps_reference.html
    if n_colors_small := count <= 4:
        if count <= 2:
            return 'bwr'
        # Builders are only called on these branches, never eagerly.
        return _make_cmap_rgb() if count <= 3 else _make_cmap_rgbx()

    # Upper bound (inclusive) on the colour count -> named matplotlib colormap.
    thresholds = ((8, 'Accent'), (10, 'tab10'), (12, 'Paired'), (20, 'tab20'))
    for upper, name in thresholds:
        if count <= upper:
            return name
    return 'hsv'
def _make_cmap_rgbx():
    """Build a ListedColormap covering every TARGET category.

    Each category's colour is placed at that category's id:
    Fl -> red, Ot -> green, St -> blue, Undefined -> black.
    The map has ``cat_dict.get_cat_ord(TARGET)`` entries.
    """
    size = cat_dict.get_cat_ord(TARGET)
    slots = [None] * size
    # Assigned in the same order as before: Fl, Ot, St, Undefined.
    palette = (('Fl', 'r'), ('Ot', 'g'), ('St', 'b'), ('Undefined', 'black'))
    for label, color in palette:
        slots[cat_dict.get_cat_id(TARGET, label)] = color
    return ListedColormap(slots, name='rgbx', N=size)
def _make_cmap_rgb():
    """Build a ListedColormap for the Fl / Ot / St categories of TARGET.

    Each category's colour is placed at that category's id:
    Fl -> red, Ot -> green, St -> blue. The map has one entry fewer than
    ``cat_dict.get_cat_ord(TARGET)``.
    """
    # NOTE(review): dropping one slot assumes 'Undefined' holds the highest
    # category id, so the remaining slots are exactly Fl/Ot/St — confirm.
    size = cat_dict.get_cat_ord(TARGET) - 1
    slots = [None] * size
    for label, color in (('Fl', 'r'), ('Ot', 'g'), ('St', 'b')):
        slots[cat_dict.get_cat_id(TARGET, label)] = color
    return ListedColormap(slots, name='rgb', N=size)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/kahsolt/pdb-analyze.git
git@gitee.com:kahsolt/pdb-analyze.git
kahsolt
pdb-analyze
PDB-analyze
master

搜索帮助