1 Star 0 Fork 0

伊拉克肥灵/identification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
1-train_layer_1.py 2.05 KB
一键复制 编辑 原始数据 按行查看 历史
Vivek Nair 提交于 2023-02-17 08:38 . training
# How many rows of each user's block are used per round, for training and
# for validation respectively.
NOTES_TRAIN = 150
NOTES_VALIDATE = 3

# Boosting iterations (LightGBM n_estimators) for every per-round model.
NUM_TREES = 200

# The user population is split into NUM_ROUNDS disjoint groups of users;
# one model is trained and saved per group.
NUM_ROUNDS = 10
NUM_USERS = 55_540
print("Importing Libraries...")
import time
import torch
import numpy as np
from lightgbm import LGBMClassifier, log_evaluation
from joblib import dump
print("Importing Data...")
# Load both splits up front (train first, then validate). Rows are later
# split by getClassifyData: column 0 is the label, the rest are features.
# NOTE(review): torch.load is pickle-based; only load files from a trusted
# source.
trainData, validateData = (
    torch.load('./data/train.pt'),
    torch.load('./data/validate.pt'),
)
def getClassifyData(data):
    """Split a labelled 2-D array into ``(features, labels)``.

    Column 0 of *data* holds the class label; every remaining column is a
    feature. Returns the feature columns first, then the label column.
    """
    labels = data[:, 0]
    features = data[:, 1:]
    return features, labels
# Train one LightGBM classifier per disjoint group of users, saving each model
# and its training wall-time under ./models/layer1 and ./stats/training/layer1.
import os

# Rows stored per user in each tensor file, i.e. the stride between consecutive
# users' blocks. These differ from NOTES_TRAIN / NOTES_VALIDATE, which say how
# many of those rows are actually used.
TRAIN_ROWS_PER_USER = 150
VALIDATE_ROWS_PER_USER = 5

MODEL_DIR = './models/layer1'
STATS_DIR = './stats/training/layer1'

# Fixed LightGBM hyper-parameters shared by every round's model.
_LGBM_PARAMS = dict(
    boosting_type='goss',
    colsample_bytree=0.6933333333333332,
    learning_rate=0.1,
    max_bin=63,
    max_depth=-1,
    min_child_weight=7,
    min_data_in_leaf=20,
    min_split_gain=0.9473684210526315,
    n_estimators=NUM_TREES,
    num_leaves=33,
    reg_alpha=0.7894736842105263,
    reg_lambda=0.894736842105263,
    subsample=1,
    n_jobs=16,
    objective='multiclass',
    device_type='gpu',
)


def _selectUserRows(data, firstUser, endUser, stride, rowsUsed):
    """Concatenate the first *rowsUsed* rows of each user's block.

    User ``uid``'s block starts at row ``stride * uid``; users are taken from
    the half-open range ``[firstUser, endUser)``.
    """
    return torch.cat([data[stride * uid:stride * uid + rowsUsed]
                      for uid in range(firstUser, endUser)])


# Create output directories before training so a long GPU run cannot fail
# at save time because a directory is missing.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(STATS_DIR, exist_ok=True)

users_per_round = NUM_USERS // NUM_ROUNDS
for round_idx in range(NUM_ROUNDS):
    progress = f"{round_idx + 1}/{NUM_ROUNDS}"
    print("Starting Round " + progress + "...")

    print("Selecting Data...")
    firstUser = users_per_round * round_idx
    # The final round also takes the NUM_USERS % NUM_ROUNDS leftover users,
    # which an even split would otherwise never train on.
    if round_idx == NUM_ROUNDS - 1:
        endUser = NUM_USERS
    else:
        endUser = firstUser + users_per_round
    trainBatch = _selectUserRows(trainData, firstUser, endUser,
                                 TRAIN_ROWS_PER_USER, NOTES_TRAIN)
    validateBatch = _selectUserRows(validateData, firstUser, endUser,
                                    VALIDATE_ROWS_PER_USER, NOTES_VALIDATE)

    print("Processing Data...")
    trainX, trainY = getClassifyData(trainBatch)
    validateX, validateY = getClassifyData(validateBatch)

    print("Training Model " + progress + "...")
    clf = LGBMClassifier(**_LGBM_PARAMS)
    start_time = time.perf_counter()  # monotonic clock for durations
    clf.fit(trainX, trainY.long(),
            # Cast validation labels like the training labels so both label
            # sets share an integer dtype.
            eval_set=[(validateX, validateY.long())],
            eval_metric='multi_error',
            callbacks=[log_evaluation()])
    elapsed = time.perf_counter() - start_time
    print("Training Finished in %s Minutes" % (elapsed / 60))

    print("Saving Model " + progress + "...")
    dump(clf, f'{MODEL_DIR}/model{round_idx}.pkl')
    # Training time in seconds, one file per round.
    with open(f"{STATS_DIR}/{round_idx}.txt", "w", encoding="utf-8") as stats_file:
        stats_file.write(str(elapsed))
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/estaryjl/identification.git
git@gitee.com:estaryjl/identification.git
estaryjl
identification
identification
main

搜索帮助

0d507c66 1850385 C8b1a773 1850385