# 代码拉取完成,页面将自动刷新 (web-page residue, not code: "Code pull complete; the page will refresh automatically.")
# ---- Run configuration -------------------------------------------------
# Rows taken per selected user from the training tensor in every round.
NOTES_TRAIN = 150
# Rows taken per selected user from the validation tensor in every round.
NOTES_VALIDATE = 3

# Boosting iterations (passed as n_estimators) for each classifier.
NUM_TREES = 200
# Number of rounds; each round fits and saves its own model.
NUM_ROUNDS = 10
# User ids iterated over are 0 .. NUM_USERS - 1.
NUM_USERS = 55_540

print("Importing Libraries...")
import os
import time
import torch
import numpy as np
from lightgbm import LGBMClassifier, log_evaluation
from joblib import dump
print("Importing Data...")
# Load the training split first, then the validation split.
# NOTE(review): torch.load unpickles its input; only point it at trusted files.
trainData, validateData = (
    torch.load(tensor_path)
    for tensor_path in ('./data/train.pt', './data/validate.pt')
)
def getClassifyData(data):
    """Split a 2-D table into a (features, target) pair.

    Column 0 is returned as the target; every column after it is returned
    as the feature matrix.

    Args:
        data: any object supporting 2-D ``[:, ...]`` slicing (e.g. a torch
            tensor or NumPy array).

    Returns:
        ``(data[:, 1:], data[:, 0])``.
    """
    target = data[:, 0]
    features = data[:, 1:]
    return features, target
# Train NUM_ROUNDS LightGBM multiclass classifiers. Round k uses user ids
# k, k + NUM_ROUNDS, k + 2 * NUM_ROUNDS, ... (so rounds see disjoint users),
# takes the first NOTES_TRAIN / NOTES_VALIDATE rows of each selected user,
# fits with multi_error evaluation on the validation rows, then saves the
# model with joblib and the training time (seconds) to a stats file.

# Row stride per user inside the loaded tensors: user `uid` is assumed to
# start at row ROWS_PER_USER * uid. NOTE(review): confirm against the script
# that produced train.pt / validate.pt.
TRAIN_ROWS_PER_USER = 150
VALIDATE_ROWS_PER_USER = 5
MODEL_DIR = './models/layer2/'
STATS_DIR = './stats/training/layer2/'

# Create output directories up front so a round cannot finish training and
# then fail only because a save directory is missing.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(STATS_DIR, exist_ok=True)


def _as_numpy(tensor):
    """Return a torch tensor as a NumPy array (LightGBM's documented input)."""
    return tensor.detach().cpu().numpy()


for round_idx in range(NUM_ROUNDS):
    progress = f"{round_idx + 1}/{NUM_ROUNDS}"
    print(f"Starting Round {progress}...")

    print("Selecting Data...")
    user_ids = range(round_idx, NUM_USERS, NUM_ROUNDS)
    trainFrame = [
        trainData[TRAIN_ROWS_PER_USER * uid:TRAIN_ROWS_PER_USER * uid + NOTES_TRAIN]
        for uid in user_ids
    ]
    validateFrame = [
        validateData[VALIDATE_ROWS_PER_USER * uid:VALIDATE_ROWS_PER_USER * uid + NOTES_VALIDATE]
        for uid in user_ids
    ]

    print("Processing Data...")
    trainX, trainY = getClassifyData(torch.cat(trainFrame))
    validateX, validateY = getClassifyData(torch.cat(validateFrame))
    # Cast both label vectors to int64 so train and eval labels share one
    # dtype, and pass plain NumPy arrays to LightGBM.
    train_X, train_y = _as_numpy(trainX), _as_numpy(trainY.long())
    valid_X, valid_y = _as_numpy(validateX), _as_numpy(validateY.long())

    print(f"Training Model {progress}...")
    clf = LGBMClassifier(
        boosting_type='goss',
        colsample_bytree=0.6933333333333332,
        learning_rate=0.1,
        max_bin=63,
        max_depth=-1,
        min_child_weight=7,
        # sklearn-API name of LightGBM's min_data_in_leaf; using the alias
        # next to the default min_child_samples only triggered a warning.
        min_child_samples=20,
        min_split_gain=0.9473684210526315,
        n_estimators=NUM_TREES,
        num_leaves=33,
        reg_alpha=0.7894736842105263,
        reg_lambda=0.894736842105263,
        subsample=1,
        n_jobs=16,
        objective='multiclass',
        device_type='gpu',
    )
    start = time.perf_counter()
    clf.fit(
        train_X,
        train_y,
        eval_set=[(valid_X, valid_y)],
        eval_metric='multi_error',
        callbacks=[log_evaluation()],
    )
    elapsed = time.perf_counter() - start
    print("Training Finished in %s Minutes" % (elapsed / 60))

    print(f"Saving Model {progress}...")
    dump(clf, f"{MODEL_DIR}model{round_idx}.pkl")
    # `with` guarantees the stats file is closed even if the write fails.
    with open(f"{STATS_DIR}{round_idx}.txt", "w") as stats_file:
        stats_file.write(str(elapsed))
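# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。 (web-page residue, not code: "Content here may be unsuitable for display and has been hidden by the page. You can review and edit it with the editing tools.")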
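# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。 (web-page residue, not code: "If you confirm the content contains no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or content violating relevant national laws and regulations, click Submit to appeal and we will handle it as soon as possible.")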