代码拉取完成,页面将自动刷新
#!/usr/bin/env python3
# Author: Armit
# Create Time: 2022/11/19
from argparse import ArgumentParser
from sklearnex import patch_sklearn ; patch_sklearn()
from sklearn.cluster import *
import matplotlib.pylab as plt
from data import get_data, FEATURE_CAT, FEATURE_NUM, TARGET, cat_dict
from utils import get_cmap
from pca import _pca
# Registry of clustering back-ends, keyed by the value accepted for --method.
# Every entry is a zero-argument factory so the estimator is only built when
# selected.  The factories look up the module-level `args` at call time, so
# `args` must exist (it is assigned in the __main__ section) before any
# factory is invoked.
# NOTE(review): the keys 'bs-kemans' and 'brich' look like misspellings of
# "bs-kmeans" / "birch"; they are kept verbatim because they are CLI choices.
METHODS = {
    # --- k-means family ---
    'kmeans': lambda: KMeans(
        n_clusters=args.n_cluster,
        random_state=42,
        verbose=2,
    ),
    'bs-kemans': lambda: BisectingKMeans(
        n_clusters=args.n_cluster,
        random_state=42,
        verbose=2,
    ),
    'mb-kmeans': lambda: MiniBatchKMeans(
        n_clusters=args.n_cluster,
        random_state=42,
        reassignment_ratio=0.03,
        verbose=2,
    ),
    # --- hierarchical ---
    'agg': lambda: AgglomerativeClustering(n_clusters=args.n_cluster),
    'fagg': lambda: FeatureAgglomeration(n_clusters=args.n_cluster),
    'ap': lambda: AffinityPropagation(),
    'brich': lambda: Birch(n_clusters=args.n_cluster),
    # --- estimators built with library defaults (no n_cluster passed) ---
    'meanshift': lambda: MeanShift(),
    'dbscan': lambda: DBSCAN(p=2),
    'optics': lambda: OPTICS(),
    # --- spectral ---
    'spec': lambda: SpectralClustering(
        n_clusters=args.n_cluster,
        random_state=42,
        verbose=True,
    ),
    'spec-b': lambda: SpectralBiclustering(
        n_clusters=args.n_cluster,
        random_state=42,
    ),
    'spec-c': lambda: SpectralCoclustering(
        n_clusters=args.n_cluster,
        random_state=42,
    ),
}
def cluster(args):
    """Run the selected clustering method and plot prediction vs. truth.

    Loads the numeric features and the chosen target via ``get_data``,
    builds the estimator registered under ``args.method`` in ``METHODS``,
    obtains one cluster label per sample, and shows two stacked scatter
    plots (predicted labels on top, true labels below) over the first two
    columns of ``_pca(X)``.

    Args:
        args: namespace providing ``limit``, ``target`` and ``method``.
            The ``METHODS`` factories additionally read the module-level
            ``args.n_cluster`` when they are called.

    Raises:
        AttributeError: if the selected estimator offers neither
            ``fit_predict`` nor per-row labels after ``fit`` (e.g.
            feature-wise agglomeration, which labels columns, not samples).
    """
    X, Y = get_data(limit=args.limit, features=FEATURE_NUM, target=args.target)
    # Colour count follows the number of distinct ground-truth values.
    n_cluster = len(set(Y))

    print(f'[{args.method}] clustering')
    model = METHODS[args.method]()
    if hasattr(model, 'fit_predict'):
        pred = model.fit_predict(X)
    else:
        # Bicluster estimators have no fit_predict; after fit() they expose
        # the per-sample assignment as `row_labels_`.
        model.fit(X)
        pred = getattr(model, 'row_labels_', None)
        if pred is None:
            raise AttributeError(
                f'method {args.method!r} ({type(model).__name__}) does not '
                'produce per-sample cluster labels'
            )
    if hasattr(model, 'inertia_'): print(f' inertia: {model.inertia_}')

    # assumes _pca returns a 2-D array with at least two columns -- TODO confirm
    X_hat = _pca(X)
    x_min, x_max = X_hat[:, 0].min(), X_hat[:, 0].max()
    y_min, y_max = X_hat[:, 1].min(), X_hat[:, 1].max()
    cmap = get_cmap(n_cluster)

    # Same axes limits and colormap for both panels so they compare directly.
    for position, title, colors in ((211, 'pred', pred), (212, 'truth', Y)):
        plt.subplot(position)
        plt.title(title)
        plt.xlim(x_min, x_max)
        plt.ylim(y_min, y_max)
        plt.scatter(X_hat[:, 0], X_hat[:, 1], s=1, cmap=cmap, c=colors)
    plt.tight_layout()
    plt.show()
if __name__ == '__main__':
    # Command-line entry point: parse options, fill in the default cluster
    # count, then run the clustering/plotting routine.
    cli = ArgumentParser()
    cli.add_argument('-M', '--method', default='kmeans', choices=METHODS.keys())
    cli.add_argument('-T', '--target', default=TARGET, choices=FEATURE_CAT)
    cli.add_argument('--n_cluster', type=int)
    cli.add_argument('-N', '--limit', default=20000, type=int, help='limit dataset size')

    # Must be bound to the module-level name `args`: the METHODS factories
    # read it as a global when they are called.
    args = cli.parse_args()

    # A missing (None) or zero --n_cluster falls back to the value that
    # cat_dict reports for the chosen target.
    if not args.n_cluster:
        args.n_cluster = cat_dict.get_cat_ord(args.target)

    cluster(args)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。