# NOTE(review): scraped-page banner ("code pull finished, page will auto-refresh") —
# not part of the source; converted to a comment so the file parses. Safe to delete.
# %%
import copy
import itertools
import logging
import os
import random
import re
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from typing import (
    Any, Callable, Dict, Iterator, List, NamedTuple,
    Optional, Tuple, Type, TypeVar, Union
)

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, Tensor, optim
from torch.autograd import Variable

from utils import Args, D, timeit
# Module-level logger; root config at DEBUG (numeric level 10) so everything shows.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s')
def inf_loop() -> Iterator[None]:
    """Yield ``None`` forever (e.g. to drive an endless epoch loop via ``zip``).

    Note: this is a generator, so the correct return annotation is
    ``Iterator[None]``, not ``None``.
    """
    while True:
        yield None
class Sample(NamedTuple):
    """
    One dataset sample: the raw trace records plus their label.

    ``tag`` records the sample's origin, e.g. sample 49-15 of the
    training set carries the tag ``train-49-15``.
    """
    data: List[Tuple[float, int]]  # all transfer records from the data file: [Time, UpOrDown]
    label: int  # class label, 0-49
    tag: str  # provenance marker
    def to_tensor_sample(self) -> Tuple[Tensor, int]:
        """
        Convert to ``(Tensor, int)`` where the data part is shaped
        (Channel, L): one channel for timestamps, one for the
        up/down-stream flags.
        :return:
        """
        records = torch.Tensor(self.data)
        return records.transpose(0, 1), self.label
class RawDataSet(NamedTuple):
    """
    The raw dataset: train and test splits as plain lists of samples.
    """
    train: List[Sample]  # samples read from the "defence" directory
    test: List[Sample]  # samples read from the "undefence" directory
def read_data(data_dir: str, max_workers: int = 12,
              num_train: Optional[int] = None, num_test: Optional[int] = None) -> RawDataSet:
    """
    Read the raw dataset, loading everything into memory.

    :param data_dir: dataset root directory
    :param max_workers: number of reader threads
    :param num_train: number of training files to read; None means all
    :param num_test: number of test files to read; None means all
    :return: the dataset
    """
    train_dir = os.path.join(data_dir, "defence")
    test_dir = os.path.join(data_dir, "undefence")
    train_files = os.listdir(train_dir)
    test_files = os.listdir(test_dir)
    # Greedy groups so "49-15" yields ("49", "15"); the previous lazy pattern
    # r"(\d+?)\-(\d+?)" truncated the second group to a single digit.
    file_name_pattern = re.compile(r"(\d+)-(\d+)")

    def fn_train_tag(s):
        return f"train-{s}"

    def fn_test_tag(s):
        return f"test-{s}"
    # print(len(train_files), len(test_files)) # 4500 4500

    def build_raw_dataset(_dir: str, files: List[str], fn_tag: Callable[[str], str],
                          num_samples: Optional[int] = None) -> List[Sample]:
        # Pre-sized so each worker writes only its own slot; failed slots stay None.
        samples: List[Optional[Sample]] = [None] * len(files)

        def f(file_name: str, idx: int) -> None:
            _res = file_name_pattern.findall(file_name)
            if len(_res) <= 0:
                # File name does not match the expected "<label>-<index>" form; skip it.
                return
            file_path = os.path.join(_dir, file_name)
            _label, _ = _res[0]
            _label = int(_label)
            _data: List[Tuple[float, int]] = []
            with open(file_path, "r") as fr:
                for line in fr:  # stream line by line
                    # str.split() collapses any run of whitespace; the previous
                    # re.split(r"\s+?", ...) produced empty fields on double
                    # spaces/tabs and broke the 2-tuple unpack below.
                    parts = line.split()
                    if not parts:
                        continue  # tolerate blank lines
                    _time_str, _stream_type_str = parts
                    _data.append((float(_time_str), int(_stream_type_str)))
            assert len(_data) > 0  # a sample must contain at least one record
            samples[idx] = Sample(_data, _label, fn_tag(file_name))

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            _futures = []
            for i, file_name in enumerate(files):
                if num_samples is not None and i >= num_samples:
                    break
                _futures.append(executor.submit(f, file_name, i))
            futures.wait(_futures, return_when=futures.ALL_COMPLETED)
            # futures.wait does not re-raise; surface worker failures instead of
            # dropping the affected files completely silently.
            for fut in _futures:
                exc = fut.exception()
                if exc is not None:
                    logger.warning("failed to read a sample file: %s", exc)
        # Drop slots whose read failed or whose file name did not match.
        return [ele for ele in samples if ele is not None]
    train_dataset = build_raw_dataset(
        train_dir, train_files, fn_train_tag, num_train)
    test_dataset = build_raw_dataset(
        test_dir, test_files, fn_test_tag, num_test)
    return RawDataSet(train_dataset, test_dataset)
def calc_accuracy(fn_predict: Callable[[Sample], int], sample_ls: List[Sample]) -> float:
    """
    Compute top-1 accuracy of a predictor over a list of samples.

    :param fn_predict: function under test; maps a sample to a predicted label
    :param sample_ls: evaluation samples
    :return: top-1 accuracy in [0, 1]; 0.0 for an empty sample list
    """
    if not sample_ls:
        # Previously divided by zero on an empty evaluation set.
        return 0.0
    right = sum(1 for sample in sample_ls if fn_predict(sample) == sample.label)
    return right / len(sample_ls)
# NOTE(review): the two lines below were content-moderation boilerplate scraped from a
# hosting page ("content may be withheld ... you may appeal"), not source code.
# Converted to a comment so the file parses. Safe to delete.