1 Star 0 Fork 0

Kahsolt/PDB-analyze

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
cluster.py 2.49 KB
一键复制 编辑 原始数据 按行查看 历史
Kahsolt 提交于 2023-01-11 18:16 . merge repo
#!/usr/bin/env python3
# Author: Armit
# Create Time: 2022/11/19
from argparse import ArgumentParser
from sklearnex import patch_sklearn ; patch_sklearn()
from sklearn.cluster import *
import matplotlib.pylab as plt
from data import get_data, FEATURE_CAT, FEATURE_NUM, TARGET, cat_dict
from utils import get_cmap
from pca import _pca
# Registry mapping each command-line method name to a zero-argument factory
# that builds the corresponding sklearn.cluster estimator.
# NOTE: the factories look up the module-level `args` namespace lazily (when
# called, not when this dict is built), so `args` must be bound before any
# factory is invoked.
METHODS = {
    'kmeans': lambda: KMeans(n_clusters=args.n_cluster, verbose=2, random_state=42),
    'bs-kemans': lambda: BisectingKMeans(n_clusters=args.n_cluster, random_state=42, verbose=2),
    'mb-kmeans': lambda: MiniBatchKMeans(n_clusters=args.n_cluster, verbose=2, random_state=42, reassignment_ratio=0.03),
    'agg': lambda: AgglomerativeClustering(n_clusters=args.n_cluster),
    'fagg': lambda: FeatureAgglomeration(n_clusters=args.n_cluster),
    'ap': lambda: AffinityPropagation(),
    'brich': lambda: Birch(n_clusters=args.n_cluster),
    'meanshift': lambda: MeanShift(),
    'dbscan': lambda: DBSCAN(p=2),
    'optics': lambda: OPTICS(),
    'spec': lambda: SpectralClustering(n_clusters=args.n_cluster, random_state=42, verbose=True),
    'spec-b': lambda: SpectralBiclustering(n_clusters=args.n_cluster, random_state=42),
    'spec-c': lambda: SpectralCoclustering(n_clusters=args.n_cluster, random_state=42),
}
# The keys 'bs-kemans' and 'brich' are misspelled. They are kept so existing
# invocations keep working; the correctly spelled names are registered as
# aliases that share the same factory.
METHODS['bs-kmeans'] = METHODS['bs-kemans']
METHODS['birch'] = METHODS['brich']
def _fit_labels(model, X):
    """Fit `model` on `X` and return one cluster label per sample.

    Estimators that provide `fit_predict` are used directly. The spectral
    biclustering estimators do not provide it; they are fitted with `fit`
    and their per-row labels (`row_labels_`) are returned instead.

    Raises:
        AttributeError: if the estimator offers neither `fit_predict` nor
            `row_labels_` (e.g. FeatureAgglomeration, which clusters
            features rather than samples).
    """
    if hasattr(model, 'fit_predict'):
        return model.fit_predict(X)
    model.fit(X)
    if hasattr(model, 'row_labels_'):
        return model.row_labels_
    raise AttributeError(
        f'{type(model).__name__} does not produce per-sample cluster labels'
    )


def cluster(args):
    """Cluster the numeric features and plot predictions against the truth.

    Reads `args.limit`, `args.target` and `args.method`. The estimator is
    built by the `METHODS` factory for `args.method`; those factories read
    the module-level `args`, so it must be bound before calling this.

    The result is shown as two stacked scatter plots over the first two
    columns of `_pca(X)`: predicted labels on top, true labels below.
    """
    X, Y = get_data(limit=args.limit, features=FEATURE_NUM, target=args.target)
    # Number of distinct target values; passed to get_cmap for both panels.
    n_cluster = len(set(Y))

    print(f'[{args.method}] clustering')
    model = METHODS[args.method]()
    pred = _fit_labels(model, X)
    if hasattr(model, 'inertia_'):
        print(f' inertia: {model.inertia_}')

    X_hat = _pca(X)
    x_min, x_max = X_hat[:, 0].min(), X_hat[:, 0].max()
    y_min, y_max = X_hat[:, 1].min(), X_hat[:, 1].max()
    cmap = get_cmap(n_cluster)

    # Both panels share axis limits and colour map so they compare directly.
    for position, title, colors in ((211, 'pred', pred), (212, 'truth', Y)):
        plt.subplot(position)
        plt.title(title)
        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        plt.scatter(X_hat[:, 0], X_hat[:, 1], s=1, cmap=cmap, c=colors)
    plt.tight_layout()
    plt.show()
if __name__ == '__main__':
    # Script entry point. The parsed namespace has to stay bound to the
    # module-level name `args`, because the METHODS factories read it.
    cli = ArgumentParser()
    cli.add_argument(
        '-M', '--method',
        default='kmeans',
        choices=METHODS.keys(),
    )
    cli.add_argument(
        '-T', '--target',
        default=TARGET,
        choices=FEATURE_CAT,
    )
    cli.add_argument('--n_cluster', type=int)
    cli.add_argument(
        '-N', '--limit',
        default=20000,
        type=int,
        help='limit dataset size',
    )
    args = cli.parse_args()

    # With no (or a zero) --n_cluster, take the count from the target's
    # category dictionary instead.
    if not args.n_cluster:
        args.n_cluster = cat_dict.get_cat_ord(args.target)

    cluster(args)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/kahsolt/pdb-analyze.git
git@gitee.com:kahsolt/pdb-analyze.git
kahsolt
pdb-analyze
PDB-analyze
master

搜索帮助