#!D:/Code/python
# -*- coding: utf-8 -*-
# @Time : 2021/5/7 0007 20:20
# @Author : xgf
# @File : method_DL_template.py
# @Software : PyCharm
import os
import random
import time
import copy
import re
import logging
import itertools
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import (
    TypeVar, Type, Union, Optional, Any,
    List, Dict, Tuple, Callable, NamedTuple
)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch import nn, Tensor, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose

from utils import Args, D, timeit
from log import Log
method_name = 'encrypted-traffic-analysis-2021'
mylog = Log('../encrypted-traffic-analysis-2021_log', method_name)
def get_args() -> Args:
    """
    Build the hyper-parameter and path arguments for this method.
    """
    default_raw_data_dir = os.path.join(os.path.dirname(
        __file__), "dataset/traffic")
    default_feature_data_dir = os.path.join(os.path.dirname(
        __file__), "dataset/traindata")
    return Args([
        D("batchSize", int, 32),
        D("learningRate", float, 1e-3),
        D("numEpochs", int, 1000),
        D("rawDataDir", str, default_raw_data_dir),
        D("dataDir", str, default_feature_data_dir),
        D("saveDir", str, None),
        D("nClass", int, 50),
        # train/test split ratio; declared as float to match the 0.8 default
        D("splitData", float, 0.8),
    ])
def readNpy(file_path):
    """
    Load a .npy file.
    @param file_path: path to the .npy file
    @return: the loaded array (its exact content layout is not checked here)
    """
    return np.load(file_path, allow_pickle=True)
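# Assumption (not verified by this script): the extracted feature files hold an object array
# whose rows look like [feature_vector, label, ...]; EncryptTrafficDataset below keeps only
# the first two columns, and ThreeLinearNetwork's 30-unit input layer suggests 30-dim features.
# Hypothetical example:
#   data = readNpy('./feature_extraction/undefence_90.npy')
#   data[0][0]  # -> 30-dim feature vector
#   data[0][1]  # -> integer class label in [0, nClass)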
class AverageMeter(object):
    """Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
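# Typical usage (a minimal sketch; train() below uses it to track the running epoch loss):
#   meter = AverageMeter()
#   meter.update(loss.item(), n=X.size(0))  # weight the running average by batch size
#   meter.avg                               # mean over everything seen so far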
class EncryptTrafficDataset(Dataset):
    """
    Dataset over the extracted traffic features.
    """
    def __init__(self, traffic_data, transform=None, target_transform=None):
        traffic_data = np.array(traffic_data)
        try:
            # keep only the first two columns: feature vector and label
            traffic_data = traffic_data[:, :2]
            self.traffic_features = traffic_data[:, 0]
            self.traffic_labels = traffic_data[:, 1]
        except (IndexError, TypeError):
            # slicing fails when the input does not have the expected [feature, label, ...] rows
            print("EncryptTrafficDataset: failed to initialise the dataset, the input data has an unexpected layout")
            raise
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.traffic_labels)

    def __getitem__(self, idx):
        feature = self.traffic_features[idx]
        label = self.traffic_labels[idx]
        if self.transform:
            feature = self.transform(feature)
        if self.target_transform:
            label = self.target_transform(label)
        feature = Tensor(feature)
        return feature, label
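# Minimal usage sketch (paths as used in __main__ below):
#   train_ds = EncryptTrafficDataset(readNpy('./feature_extraction/undefence_90.npy'))
#   feature, label = train_ds[0]   # feature is a float Tensor; label is returned as stored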
class ThreeLinearNetwork(nn.Module):
    """
    A simple three-layer fully connected baseline for testing:
    30-dim feature -> 512 -> 512 -> 50 class logits.
    """
    def __init__(self):
        super(ThreeLinearNetwork, self).__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(30, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            # no activation on the output layer: CrossEntropyLoss expects raw logits
            nn.Linear(512, 50),
        )

    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits
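# Shape sanity check (a minimal sketch): a (batch, 30) float input yields (batch, 50) logits.
#   with torch.no_grad():
#       logits = ThreeLinearNetwork()(torch.randn(4, 30))
#   logits.shape  # torch.Size([4, 50])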
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    loss_meter = AverageMeter()
    for batch, (X, y) in enumerate(dataloader):
        # CrossEntropyLoss needs integer class indices
        X, y = X.to(device), y.long().to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_meter.update(loss.item(), X.size(0))
        if batch % 100 == 0:
            current = batch * len(X)
            # print(f"loss: {loss_meter.val:>7f} [{current:>5d}/{size:>5d}]")
    return loss_meter.avg
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.long().to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct
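# test() returns (average per-batch loss, accuracy in [0, 1]); the training loop below runs it
# on both the training and the held-out loaders each epoch to fill the log lists.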
if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using {} device".format(device))

    args = get_args()
    batch_size = args.batchSize        # batch size
    learning_rate = args.learningRate  # learning rate
    num_epochs = args.numEpochs        # number of passes over the training set
    data_dir = args.dataDir
    save_dir = args.saveDir
    n_class = args.nClass
    split_data = args.splitData

    # Load the pre-split feature files.
    # Alternative: load a single feature file and split it here instead:
    #   data = readNpy('./feature_extraction/undefence_features.npy')
    #   traffic_data = EncryptTrafficDataset(data)
    #   train_size = int(split_data * len(traffic_data))
    #   test_size = len(traffic_data) - train_size
    #   train_data, test_data = torch.utils.data.random_split(traffic_data, [train_size, test_size])
    traindata = readNpy('./feature_extraction/undefence_90.npy')
    testdata = readNpy('./feature_extraction/undefence_10.npy')
    train_data, test_data = EncryptTrafficDataset(traindata), EncryptTrafficDataset(testdata)

    # Build the dataloaders
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    train_features, train_labels = next(iter(train_dataloader))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

    model = ThreeLinearNetwork().to(device)
    print(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    for t in range(num_epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        train_loss, train_acc = test(train_dataloader, model, loss_fn)
        test_loss, test_acc = test(test_dataloader, model, loss_fn)
        mylog.state_dict_update([('train_loss_list', train_loss),
                                 ('train_acc_list', train_acc),
                                 ('valid_loss_list', test_loss),
                                 ('valid_acc_list', test_acc),
                                 ])
    print("Done!")