1 Star 0 Fork 0

ZENGWatermelon/TrackPrediction

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
trafficClassifier.py 5.08 KB
一键复制 编辑 原始数据 按行查看 历史
sunzhaoc 提交于 2020-11-30 21:37 . master
import sys
import os
import itertools
import random
# import Image # PIL
from PIL import Image
from libsvm.svm import svm_problem, svm_parameter
from libsvm.svmutil import * # libSVM
# from sklearn.externals import joblib
import joblib
# Image data constants
# DIMENSION = 32
# ROOT_DIR = "../images/"
# DAL = "dalmatian"
# DOLLAR = "dollar_bill"
# PIZZA = "pizza"
# BALL = "soccer_ball"
# FLOWER = "sunflower"
# CLASSES = [DAL, DOLLAR, PIZZA, BALL, FLOWER]
# One-vs-rest class labels: the string keys "0" through "19" that the loaded
# train/test dicts are indexed by.
CLASSES = [str(label) for label in range(20)]
# libsvm kernel_type codes (0 = linear, 2 = radial basis function).
LINEAR = 0
RBF = 2
# Run-mode switches.
USE_LINEAR = False  # False -> getModels trains with the RBF kernel
IS_TUNING = False
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 when denominator is 0."""
    return float(numerator) / denominator if denominator else 0.0


def main(train_path="./data/12train100.pkl", test_path="./data/12test100.pkl"):
    """Train the one-vs-rest SVMs, classify the test set and print accuracy.

    Both pickles are loaded with joblib and must map every label in CLASSES
    to a sequence of feature vectors (getModels and classify index them by
    those labels).

    Args:
        train_path: joblib pickle holding the training samples.
        test_path: joblib pickle holding the evaluation samples.

    Returns:
        0 on success; 5 if anything fails (the traceback is printed) or if
        tuning mode is requested, because no tuning split is available.
    """
    import traceback  # only used on the error path

    if IS_TUNING:
        print("!!! TUNING MODE !!!")
        # The tuning split used to come from getData(), which is disabled, so
        # there is nothing for classify() to run on.
        print("No tuning data available; nothing to evaluate.")
        return 5

    try:
        train = joblib.load(train_path)
        test = joblib.load(test_path)
        models = getModels(train)
        results = classify(models, test)

        # Report per class: label, correct, total, accuracy.
        totalCount = 0
        totalCorrect = 0
        for clazz in CLASSES:
            count, correct = results[clazz]
            totalCount += count
            totalCorrect += correct
            # _safe_ratio keeps a class with no test samples from aborting the report.
            print("%s %d %d %f" % (clazz, correct, count, _safe_ratio(correct, count)))
        print("%s %d %d %f" % ("Overall", totalCorrect, totalCount,
                               _safe_ratio(totalCorrect, totalCount)))
    except Exception:
        # Top-level boundary: keep the "report and return 5" contract, but show
        # the full traceback rather than just the exception message.
        traceback.print_exc()
        return 5
    return 0
def classify(models, dataSet):
    """Predict every sample of ``dataSet`` and tally accuracy per true class.

    Prints one ``true,predicted,probability`` line per sample.

    Args:
        models: class label -> trained model mapping, passed through to predict.
        dataSet: mapping from each label in CLASSES to its samples.

    Returns:
        dict mapping each label in CLASSES to a ``(count, correct)`` tuple.
    """
    def tally(trueClazz):
        # Count the samples of one class and how many were predicted correctly.
        seen, hits = 0, 0
        for sample in dataSet[trueClazz]:
            guess, prob = predict(models, sample)
            print("%s,%s,%f" % (trueClazz, guess, prob))
            seen += 1
            if str(guess) == trueClazz:
                hits += 1
        return (seen, hits)

    return {trueClazz: tally(trueClazz) for trueClazz in CLASSES}
def predict(models, item):
    """Pick the most likely class for ``item`` among the one-vs-rest models.

    Each model's positive-class probability (from predictSingle) is
    normalised so the scores sum to 1. The class with the highest normalised
    score wins, and ties go to the first model in ``models`` order.

    Args:
        models: mapping of class label -> trained libsvm model (see getModels).
        item: one feature vector in the form accepted by svm_predict.

    Returns:
        ``(bestClass, maxProb)``: the winning class label (a key of
        ``models``) and its normalised score. Returns ``("", 0.0)`` when
        ``models`` is empty.
    """
    if not models:
        return ("", 0.0)

    labels = list(models)
    probs = [predictSingle(models[label], item) for label in labels]

    total = sum(probs)
    if total > 0:
        probs = [p / total for p in probs]
    else:
        # Every model gave probability 0: score uniformly instead of dividing by zero.
        probs = [1.0 / len(probs)] * len(probs)

    best = probs.index(max(probs))
    # Return the model's own label rather than its position, so the result
    # does not depend on the dict happening to be ordered "0", "1", ...
    return (labels[best], probs[best])
def predictSingle(model, item):
    """Return libsvm's probability that ``item`` belongs to the +1 class of ``model``.

    Runs svm_predict quietly (-q) with probability estimates (-b 1). The true
    label passed in ([0]) is a placeholder; it only feeds libsvm's accuracy
    statistics, which are discarded here.
    """
    _, _, prob_estimates = svm_predict([0], [item], model, "-q -b 1")
    # libsvm orders each probability row by model.get_labels(), so find the
    # column of label +1 instead of assuming it is the first one.
    labels = list(model.get_labels())
    column = labels.index(1) if 1 in labels else 0
    return prob_estimates[0][column]
def getModels(trainingData):
    """Train one binary (one-vs-rest) SVM for every label in CLASSES.

    Args:
        trainingData: mapping from each label in CLASSES to its samples.

    Returns:
        dict mapping each class label to its trained libsvm model.
    """
    shared_param = getParam(USE_LINEAR)
    return {
        clazz: svm_train(svm_problem(*getTrainingData(trainingData, clazz)), shared_param)
        for clazz in CLASSES
    }
def getTrainingData(trainingData, clazz):
    """Build a shuffled one-vs-rest training set for ``clazz``.

    Samples of ``clazz`` get label +1 and samples of every other label in
    CLASSES get label -1. The combined list is shuffled with the ``random``
    module before being split.

    Returns:
        ``(labels, data)``: parallel lists of labels and samples.
    """
    pairs = getLabeledDataVector(trainingData, clazz, 1)
    for other in CLASSES:
        if other != clazz:
            pairs.extend(getLabeledDataVector(trainingData, other, -1))
    random.shuffle(pairs)
    # Transpose [(label, sample), ...] into [labels, samples].
    columns = [list(column) for column in zip(*pairs)]
    return (columns[0], columns[1])
def getParam(linear=True, C=.01, gamma=.00000001):
    """Build the libsvm training parameters shared by every one-vs-rest model.

    Probability estimates are always enabled (``probability = 1``) because
    predictions are later requested with ``-b 1``.

    Args:
        linear: use the linear kernel when True, the RBF kernel otherwise.
        C: soft-margin penalty (previously hard-coded to 0.01).
        gamma: RBF kernel coefficient (previously hard-coded to 1e-8);
            ignored for the linear kernel.

    Returns:
        A configured ``svm_parameter``.
    """
    param = svm_parameter("-q")
    param.probability = 1
    param.kernel_type = LINEAR if linear else RBF
    param.C = C
    if not linear:
        param.gamma = gamma
    return param
def getLabeledDataVector(dataset, clazz, label):
    """Pair every sample of ``clazz`` in ``dataset`` with the same ``label``.

    Returns:
        list of ``(label, sample)`` tuples, in the samples' original order.
    """
    return [(label, sample) for sample in dataset[clazz]]
# def getData(generateTuningData):
# trainingData = {}
# tuneData = {}
# testData = {}
#
# for clazz in CLASSES:
# (train, tune, test) = buildTrainTestVectors(buildImageList(ROOT_DIR + clazz + "/"), generateTuningData)
# trainingData[clazz] = train
# tuneData[clazz] = tune
# testData[clazz] = test
#
# return (trainingData, tuneData, testData)
#
# def buildImageList(dirName):
# imgs = [Image.open(dirName + fileName).resize((DIMENSION, DIMENSION)) for fileName in os.listdir(dirName)]
# imgs = [list(itertools.chain.from_iterable(img.getdata())) for img in imgs]
# return imgs
#
# def buildTrainTestVectors(imgs, generateTuningData): #划分数据集
# # 70% for training, 30% for test.
# testSplit = int(.7 * len(imgs))
# baseTraining = imgs[:testSplit]
# test = imgs[testSplit:]
#
# training = None
# tuning = None
# if generateTuningData:
# # 50% of training for true training, 50% for tuning.
# tuneSplit = int(.5 * len(baseTraining))
# training = baseTraining[:tuneSplit]
# tuning = baseTraining[tuneSplit:]
# else:
# training = baseTraining
#
# return (training, tuning, test)
if __name__ == "__main__":
    # Propagate main()'s status (5 on failure) as the process exit code.
    sys.exit(main())
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/watermelonTT/TrackPrediction.git
git@gitee.com:watermelonTT/TrackPrediction.git
watermelonTT
TrackPrediction
TrackPrediction
master

搜索帮助