2 Star 0 Fork 0

unknow1216/sklearn-text-Classifier

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
split_data.py 3.43 KB
一键复制 编辑 原始数据 按行查看 历史
unknow1216 提交于 2020-12-23 15:30 . finally commit
import os
import csv
import random
import pandas as pd
import numpy as np
def readList(path):
    """Read a CSV file and return the values of its ``title`` column.

    :param path: path of the CSV file; it is decoded as UTF-8 and must
        contain a column named ``title``.
    :return: one-dimensional NumPy array with the ``title`` values in file
        order (the header row itself is not part of the result).
    """
    frame = pd.read_csv(path, encoding='utf-8')
    return np.array(frame["title"])
def getDataSet(list, proportion):
    """Randomly split samples into a training set and a test set.

    The input is copied first, so the caller's collection is never modified.

    :param list: iterable of samples. The name shadows the builtin ``list``;
        it is kept unchanged so existing keyword callers keep working.
    :param proportion: test set size / data set size. The test set receives
        ``int(len(samples) * proportion)`` samples (rounded down).
    :return: tuple ``(trainDataSet, testDataSet)`` of two new lists. The
        training list keeps the original relative order of the samples; the
        test list is in the order the samples were drawn.
    :raises ValueError: propagated from ``random.sample`` when the computed
        test size is negative or larger than the number of samples.
    """
    # ``list`` is the parameter here, so the builtin cannot be called; unpack
    # into a fresh list instead.
    samples = [*list]
    total = len(samples)
    test_count = int(total * proportion)

    # Draw positions instead of values: removing sampled values one by one
    # with list.remove() is quadratic and, when duplicates exist, deletes the
    # first equal element rather than the one that was actually drawn.
    test_positions = random.sample(range(total), test_count)
    drawn = set(test_positions)

    testDataSet = [samples[pos] for pos in test_positions]
    trainDataSet = [item for pos, item in enumerate(samples) if pos not in drawn]
    return trainDataSet, testDataSet
def _write_title_csv(file_path, titles):
    """Write *titles* to *file_path* as a one-column CSV with one header row."""
    # 'utf-8-sig' prepends a BOM; newline='' is what the csv module expects
    # so that it controls line endings itself.
    with open(file_path, 'w', encoding='utf-8-sig', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=['title'])
        writer.writeheader()
        writer.writerows({'title': title} for title in titles)


def segText(inputPath, resultPath):
    """Split every CSV under *inputPath* into train/test CSV files.

    *inputPath* is expected to contain one sub-directory per group, each
    holding CSV files with a ``title`` column. Every file is read with
    ``readList``, split with ``getDataSet(content, 0.1)`` and the two parts
    are written under the matching output root, keeping the sub-directory
    and file names.

    :param inputPath: root directory of the data set.
    :param resultPath: sequence of output root directories, indexed in the
        order returned by ``getDataSet``: ``resultPath[0]`` receives the
        training part and ``resultPath[1]`` the test part.
    :return: None; the results are the files written to disk.
    :raises IndexError: if *resultPath* has fewer than two entries.
    """
    for eachDir in os.listdir(inputPath):
        eachPath = os.path.join(inputPath, eachDir)
        if not os.path.isdir(eachPath):
            # Stray files in the root have no second level to walk.
            continue
        for eachFile in os.listdir(eachPath):
            content = readList(os.path.join(eachPath, eachFile))
            splits = getDataSet(content, 0.1)
            for index, items in enumerate(splits):
                targetDir = os.path.join(resultPath[index], eachDir)
                os.makedirs(targetDir, exist_ok=True)
                # The handle is closed inside the helper and the header is
                # written exactly once per output file.
                _write_title_csv(os.path.join(targetDir, eachFile), items)
def saveascsv(writer, title):
    """Emit one ``title`` record through *writer*.

    :param writer: object providing ``writeheader()`` and ``writerow(dict)``,
        e.g. a ``csv.DictWriter`` whose field names include ``title``.
    :param title: value stored under the ``title`` key of the written row.

    NOTE(review): the header is written on every call, so calling this once
    per row yields a header line before each data line. That behaviour is
    preserved here; confirm whether it is intended.
    """
    writer.writeheader()
    writer.writerow({'title': title})
if __name__ == '__main__':
    # Script entry point: read the data set under './total/' and write the
    # training part under './train/' and the test part under './test/'.
    segText('./total/', ['./train/', './test/'])
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/unknow1216/sklearn-text-classifier.git
git@gitee.com:unknow1216/sklearn-text-classifier.git
unknow1216
sklearn-text-classifier
sklearn-text-Classifier
master

搜索帮助