代码拉取完成,页面将自动刷新
import pickle
import os
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from word2vec import Word2Vec
from log import Logger
# Module-wide logger shared by every function and method in this file; it is
# obtained from the project's Logger helper, created with "./svm.log" as its filename.
logger = Logger(filename="./svm.log").get_logger()
def load_data(root, data_type) -> pd.DataFrame:
    """
    Read one labeled split of the raw data set into a DataFrame.

    The file that is read is "<root>/usual_<data_type>_labeled.csv", e.g.
    load_data(root="./weibo", data_type="train") reads
    "./weibo/usual_train_labeled.csv".

    :param root: directory that holds the csv files
    :param data_type: split name used in the file name ("train", "eval", "test")
    :return: pd.DataFrame parsed from the csv file (read as utf-8)
    """
    file_name = "usual_" + data_type + "_labeled.csv"
    csv_path = os.path.join(root, file_name)
    # Log the resolved path before reading so a missing file is easy to trace.
    logger.info("load data from : " + str(csv_path))
    return pd.read_csv(csv_path, encoding='utf-8')
def encode_label(data_list=None):
    """
    Encode categorical labels as integer ids with a freshly fitted LabelEncoder.

    Example : [a,b,c,a] => [0,1,2,0]

    :param data_list: iterable of labels (plain list, numpy array, ...)
    :return: tuple ``(label_index, encoder)`` - the encoded labels returned by
        ``LabelEncoder.fit_transform`` and the fitted encoder itself
    :raises ValueError: if ``data_list`` is None
    """
    # Validate first so bad input fails before any encoder is built.
    if data_list is None:
        raise ValueError("Error : data_list is null")
    encoder = LabelEncoder()
    # Always encode: previously only non-list inputs were transformed, so a
    # plain list silently came back as (None, encoder).
    label_index = encoder.fit_transform(list(data_list))
    return label_index, encoder
class SVMModel:
    """
    SVM classifier over embedded text.

    Wraps a ``StandardScaler`` + ``SVC(gamma='auto')`` pipeline, embeds the
    text column with ``Word2Vec.embedding`` and pickles the fitted pipeline to
    ``save_pth``.
    """

    def __init__(self, save_pth=None):
        """
        :param save_pth: file path the trained model is pickled to
        :raises ValueError: if ``save_pth`` is None
        """
        if save_pth is None:
            raise ValueError("model saved pth failed")
        self.save_pth = save_pth
        self.clf = make_pipeline(StandardScaler(), SVC(gamma='auto'))
        self.word2vec = Word2Vec()
        # Encoder fitted on the training labels; reused for the other splits
        # so every split shares one label -> id mapping.
        self.label_encoder = None

    def get_X_y(self, root, data_type):
        """
        Load one data split and turn it into features and encoded labels.

        :param root: data dir passed on to ``load_data``
        :param data_type: split name passed on to ``load_data``
        :return: tuple ``(X, y)`` - list of embeddings of the '文本' (text)
            column and the encoded '情绪标签' (emotion label) column
        """
        _df = load_data(root=root, data_type=data_type)
        labels = _df['情绪标签'].values
        if data_type == "train" or self.label_encoder is None:
            # (Re)fit on the training split, or on the first split seen when
            # no encoder exists yet.
            y, self.label_encoder = encode_label(labels)
        else:
            # Reuse the existing mapping: fitting a new encoder per split
            # would assign different ids whenever the label sets differ.
            y = self.label_encoder.transform(labels)
        X = list(_df['文本'].astype(str).map(self.word2vec.embedding))
        logger.info("finish load " + data_type + " data . ")
        return X, y

    def train(self, root="./weibo"):
        """
        Fit the pipeline on the train split, save it, and score it on the eval split.

        :param root: data dir containing the train and eval csv files
        :return: value returned by the pipeline's ``score`` on the eval split
        """
        logger.info("get x,y train data")
        train_x, train_y = self.get_X_y(root=root, data_type="train")
        logger.info("get x,y eval data")
        eval_x, eval_y = self.get_X_y(root=root, data_type="eval")
        clf = self.clf.fit(train_x, train_y)
        logger.info("finish train . ")
        # Persist the fitted pipeline before evaluating it.
        self.save_model(clf=clf)
        _score = self.clf.score(eval_x, eval_y)
        logger.info("finish eval . ")
        return _score

    def save_model(self, clf):
        """
        Pickle ``clf`` to ``self.save_pth``.

        :param clf: the model object to save
        :return: no return
        :raises ValueError: if ``self.save_pth`` is None
        """
        if self.save_pth is None:
            raise ValueError("save file path is null")
        with open(self.save_pth, 'wb') as f_writer:
            pickle.dump(clf, f_writer)
        print("save model in ", self.save_pth)

    def load_model(self, pth=None):
        """
        Load a pickled model from ``pth``.

        If ``pth`` is None the model is loaded from ``self.save_pth``.

        :param pth: model path
        :return: the unpickled model
        :raises ValueError: if neither ``pth`` nor ``self.save_pth`` is set
        """
        if pth is None:
            pth = self.save_pth
        if pth is None:
            raise ValueError("save file path is null")
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(pth, 'rb') as f_reader:
            return pickle.load(f_reader)
if __name__ == '__main__':
    # Script entry point: train the SVM, pickling it to ./svm.model, and
    # print the score obtained on the eval split.
    svm = SVMModel(save_pth="./svm.model")
    print(svm.train())
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。