4 Star 1 Fork 0

Green/encrypted-traffic-analysis-2021

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
exp_tools.py 4.56 KB
一键复制 编辑 原始数据 按行查看 历史
Green 提交于 2021-05-11 22:36 . conv1d 02
# %%
import copy
import itertools
import logging
import os
import random
import re
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any, Callable, Dict, Iterator, List, NamedTuple, Optional, Tuple, Type,
    TypeVar, Union,
)

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor, optim
from torch.autograd import Variable

from utils import Args, D, timeit
logger = logging.getLogger(__name__)
# Configure the root logger at import time: DEBUG level (numeric value 10), and
# every record carries its timestamp, source file and line number.
# (basicConfig has no effect if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
    level=logging.DEBUG,
)
def inf_loop() -> Iterator[None]:
    """
    Yield ``None`` forever.

    This is a generator, so the return type is an iterator of ``None`` rather
    than ``None``. Iteration never ends on its own; callers must break out.
    """
    while True:
        yield None
class Sample(NamedTuple):
    """
    One dataset sample: the records of a data file plus its label.

    ``tag`` records where the sample came from; e.g. sample ``49-15`` of the
    training set has the tag ``train-49-15``.
    """
    data: List[Tuple[float, int]]  # all transfer records of the data file, as [Time, UpOrDown]
    label: int  # class label, 0-49
    tag: str  # provenance tag

    def to_tensor_sample(self) -> Tuple[Tensor, int]:
        """
        Convert to ``(Tensor, int)`` with the data laid out as ``(Channel, L)``.

        The time values and the up/down flags each occupy one channel, so the
        data tensor has shape ``(2, L)`` with ``L == len(self.data)``. An empty
        ``data`` list yields a ``(2, 0)`` tensor.

        :return: ``(data_tensor, label)``
        """
        if not self.data:
            # Tensor([]) is 1-D with shape (0,) and .t() keeps it 1-D, which would
            # break the (Channel, L) layout; build the empty 2-channel tensor directly.
            return torch.empty(2, 0), self.label
        # Tensor(data) is (L, 2); transpose to put the two fields on the channel axis.
        return Tensor(self.data).t(), self.label
# Raw dataset: the training split and the test split, each a list of samples.
RawDataSet = NamedTuple(
    "RawDataSet",
    [
        ("train", List[Sample]),  # training samples
        ("test", List[Sample]),   # test samples
    ],
)
RawDataSet.__doc__ = """
    Raw dataset: ``train`` and ``test`` splits, each a list of ``Sample``.
    """
# Data file names look like "<label>-<index>" (e.g. "49-15"); only the label is used.
_FILE_NAME_PATTERN = re.compile(r"(\d+)-(\d+)")


def _parse_sample_file(file_path: str, file_name: str,
                       fn_tag: Callable[[str], str]) -> Optional[Sample]:
    """
    Parse a single data file into a ``Sample``.

    Every non-blank line must hold exactly two whitespace-separated fields:
    a float time and an int up/down flag.

    :param file_path: full path of the file to read
    :param file_name: bare file name; supplies the label and the tag
    :param fn_tag: builds the sample tag from ``file_name``
    :return: the parsed sample, or None when the file name does not match the
             naming rule or the file contains no records
    :raises ValueError: on a malformed line
    :raises OSError: if the file cannot be opened
    """
    match = _FILE_NAME_PATTERN.search(file_name)
    if match is None:
        # file name does not follow the naming rule: skip it
        return None
    label = int(match.group(1))
    data: List[Tuple[float, int]] = []
    with open(file_path, "r") as fr:
        for line in fr:
            # split() treats any run of whitespace as one separator and strips the ends
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            time_str, stream_type_str = fields
            data.append((float(time_str), int(stream_type_str)))
    if not data:
        logger.warning("no records in %s, skipping it", file_path)
        return None
    return Sample(data, label, fn_tag(file_name))


def _load_samples(_dir: str, files: List[str], fn_tag: Callable[[str], str],
                  num_samples: Optional[int], max_workers: int) -> List[Sample]:
    """
    Parse the first ``num_samples`` entries of ``files`` (all of them when None)
    on a thread pool. The result keeps the order of ``files`` and omits files
    that were skipped or failed to parse.
    """
    if num_samples is not None:
        files = files[:max(num_samples, 0)]  # a negative count selects no files
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        jobs = [
            (name, executor.submit(_parse_sample_file, os.path.join(_dir, name), name, fn_tag))
            for name in files
        ]
    # leaving the `with` block waits for every submitted job to finish
    samples: List[Sample] = []
    for name, job in jobs:
        err = job.exception()
        if err is not None:
            # best effort: drop the file, but say why instead of hiding the error
            logger.warning("failed to read %s: %r", os.path.join(_dir, name), err)
            continue
        sample = job.result()
        if sample is not None:
            samples.append(sample)
    return samples


def read_data(data_dir: str, max_workers: int = 12,
              num_train: Optional[int] = None, num_test: Optional[int] = None) -> RawDataSet:
    """
    Read the raw dataset and load it entirely into memory.

    Training samples are read from ``<data_dir>/defence`` and test samples from
    ``<data_dir>/undefence``. Files are processed in sorted name order so that
    ``num_train`` / ``num_test`` select a reproducible subset. Files whose name
    does not match ``<label>-<index>``, that contain no records, or that fail to
    parse are skipped (empty or failing files are logged).

    :param data_dir: dataset root directory
    :param max_workers: number of threads used to read files
    :param num_train: number of training files to read; None means all
    :param num_test: number of test files to read; None means all
    :return: the dataset
    """
    train_dir = os.path.join(data_dir, "defence")
    test_dir = os.path.join(data_dir, "undefence")
    # os.listdir order is arbitrary; sort so subsets and sample order are deterministic
    train_files = sorted(os.listdir(train_dir))
    test_files = sorted(os.listdir(test_dir))
    train_dataset = _load_samples(
        train_dir, train_files, lambda s: f"train-{s}", num_train, max_workers)
    test_dataset = _load_samples(
        test_dir, test_files, lambda s: f"test-{s}", num_test, max_workers)
    return RawDataSet(train_dataset, test_dataset)
def calc_accuracy(fn_predict: Callable[[Sample], int], sample_ls: List[Sample]) -> float:
    """
    Compute the top-1 accuracy of a predictor over a collection of samples.

    :param fn_predict: function under test; maps a sample to a predicted label
    :param sample_ls: evaluation samples (a list of Sample)
    :return: fraction of samples whose prediction equals ``sample.label``;
             0.0 when ``sample_ls`` is empty (instead of dividing by zero)
    """
    n = 0
    right = 0
    for sample in sample_ls:
        n += 1
        if fn_predict(sample) == sample.label:
            right += 1
    if n == 0:
        # no samples: accuracy is undefined, report 0.0 rather than raising ZeroDivisionError
        return 0.0
    return right / n
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/bitosky/encrypted-traffic-analysis-2021.git
git@gitee.com:bitosky/encrypted-traffic-analysis-2021.git
bitosky
encrypted-traffic-analysis-2021
encrypted-traffic-analysis-2021
master

搜索帮助