1 Star 1 Fork 0

coshpr/emotion_classification

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
svm.py 3.37 KB
一键复制 编辑 原始数据 按行查看 历史
coshpr 提交于 2022-04-18 09:33 . add model: svm bert and lstm
import pickle
import os
import pandas as pd
import numpy as np
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from word2vec import Word2Vec
from log import Logger
# Module-level logger shared by every function in this file; it is built by the
# project-local Logger helper, which is given "./svm.log" as its filename.
logger = Logger(filename="./svm.log").get_logger()
def load_data(root, data_type) -> pd.DataFrame:
    """
    Read one labeled split of the raw data set into a DataFrame.

    The file name is derived from ``data_type``:
    ``load_data(root="./weibo", data_type="train")`` reads
    ``./weibo/usual_train_labeled.csv``.

    :param root: directory that contains the csv files
    :param data_type: split name, e.g. "train", "eval" or "test"
    :return: pd.DataFrame with the content of the csv file
    """
    csv_name = "usual_" + data_type + "_labeled.csv"
    csv_path = os.path.join(root, csv_name)
    logger.info(f"load data from : {csv_path}")
    return pd.read_csv(csv_path, encoding='utf-8')
def encode_label(data_list=None):
    """
    Encode class labels as integer ids with a freshly fitted LabelEncoder.

    Example : [a,b,c,a] => [0,1,2,0]

    :param data_list: iterable of labels (list, tuple, numpy array, ...)
    :return: tuple ``(label_index, encoder)`` - the encoded labels and the
        fitted ``LabelEncoder`` (its ``classes_`` holds the original labels)
    :raises ValueError: if ``data_list`` is None
    """
    # Validate first so that bad input fails before any encoder is built.
    if data_list is None:
        raise ValueError("Error : data_list is null")

    encoder = LabelEncoder()
    # Always fit. The previous version only fitted when the input was NOT a
    # list, so passing a plain list returned (None, <unfitted encoder>).
    label_index = encoder.fit_transform(list(data_list))
    return label_index, encoder
class SVMModel:
    """
    SVM classifier trained on embedded text.

    The estimator is a scikit-learn pipeline (``StandardScaler`` followed by
    ``SVC(gamma='auto')``); texts are turned into feature vectors with the
    project's ``Word2Vec.embedding``.
    """

    def __init__(self, save_pth=None):
        """
        :param save_pth: file path the trained pipeline is pickled to
        :raises ValueError: if ``save_pth`` is None
        """
        if save_pth is None:
            raise ValueError("model saved pth failed")
        self.save_pth = save_pth
        self.clf = make_pipeline(StandardScaler(), SVC(gamma='auto'))
        self.word2vec = Word2Vec()
        # Encoder fitted on the training labels; reused for the other splits
        # so that every split shares one label -> id mapping.
        self.label_encoder = None

    def get_X_y(self, root, data_type):
        """
        Load one split and turn it into features and integer labels.

        :param root: data dir passed on to ``load_data``
        :param data_type: split name ("train", "eval", "test")
        :return: tuple ``(X, y)`` - list of embeddings of the '文本' (text)
            column and the encoded '情绪标签' (emotion label) column
        :raises ValueError: raised by ``LabelEncoder.transform`` when a
            non-train split contains a label the train split did not have
        """
        _df = load_data(root=root, data_type=data_type)
        labels = _df['情绪标签'].values

        # Fitting a new encoder per split (the old behavior) gives different
        # ids for the same label whenever the splits do not contain exactly
        # the same label set, which silently corrupts the eval score.
        fitted_encoder = getattr(self, "label_encoder", None)
        if data_type == "train" or fitted_encoder is None:
            y, self.label_encoder = encode_label(labels)
        else:
            y = fitted_encoder.transform(labels)

        X = list(_df['文本'].astype(str).map(self.word2vec.embedding))
        logger.info("finish load " + data_type + " data . ")
        return X, y

    def train(self, root="./weibo"):
        """
        Fit the pipeline on the train split, save it, and score the eval split.

        :param root: data dir holding the csv files (default "./weibo",
            the location that used to be hard-coded)
        :return: result of the pipeline's ``score`` on the eval split
        """
        logger.info("get x,y train data")
        train_x, train_y = self.get_X_y(root=root, data_type="train")
        # Was logged at warning level although nothing is wrong here.
        logger.info("get x,y eval data")
        eval_x, eval_y = self.get_X_y(root=root, data_type="eval")

        clf = self.clf.fit(train_x, train_y)
        logger.info("finish train . ")
        self.save_model(clf=clf)

        _score = self.clf.score(eval_x, eval_y)
        logger.info("finish eval . ")
        return _score

    def save_model(self, clf):
        """
        Pickle ``clf`` to ``self.save_pth``.

        :param clf: the model object to store
        :return: no return
        :raises ValueError: if ``self.save_pth`` is None
        """
        if self.save_pth is None:
            raise ValueError("save file path is null")
        with open(self.save_pth, 'wb') as f_writer:
            pickle.dump(clf, f_writer)
        print("save model in ", self.save_pth)

    def load_model(self, pth=None):
        """
        Unpickle a model from ``pth``.

        If ``pth`` is None the model is read from ``self.save_pth``.

        :param pth: model path
        :return: the unpickled model
        :raises ValueError: if both ``pth`` and ``self.save_pth`` are None
        """
        if pth is None:
            if self.save_pth is None:
                raise ValueError("save file path is null")
            pth = self.save_pth
        # SECURITY: pickle.load can execute arbitrary code stored in the file;
        # only load model files that come from a trusted source.
        with open(pth, 'rb') as f_reader:
            clf = pickle.load(f_reader)
        return clf
if __name__ == '__main__':
    # Script entry point: train the SVM, store it in ./svm.model and print
    # the score returned by train().
    svm_model = SVMModel(save_pth="./svm.model")
    print(svm_model.train())
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/coshpr/emotion_classification.git
git@gitee.com:coshpr/emotion_classification.git
coshpr
emotion_classification
emotion_classification
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385