1 Star 0 Fork 0

shensanbai/time_series_experiment

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
forcasting.py 3.04 KB
一键复制 编辑 原始数据 按行查看 历史
import numpy as np
import time
from . import _eval_protocols as eval_protocols
def generate_pred_samples(features, data, pred_len, drop=0):
    """Pair each time step's feature vector with the next ``pred_len`` targets.

    Args:
        features: Array indexed as ``features[:, t]``; its last axis is kept
            as the per-sample feature vector.
        data: 3-axis array ``(instances, time, channels)`` from which the
            targets are sliced along axis 1.
        pred_len: Number of future steps that make up one label.
        drop: Number of leading time steps to discard from both outputs.

    Returns:
        A tuple ``(X, y)`` of 2-D arrays: ``X`` has one row per kept
        (instance, time step) pair, and ``y`` has the matching rows with the
        ``pred_len`` future steps and the channels flattened together.
    """
    seq_len = data.shape[1]
    # Every shifted view has this many steps; the view at offset k holds the
    # value k steps after each position.
    window_count = seq_len - pred_len + 1

    shifted_views = [
        data[:, offset:offset + window_count]
        for offset in range(pred_len)
    ]
    # Stacked shape: (instances, window_count, pred_len, channels).
    targets = np.stack(shifted_views, axis=2)
    # Skip the first window so the label for step t starts at step t + 1.
    targets = targets[:, 1:]

    # The last pred_len steps have no complete future window.
    inputs = features[:, :-pred_len]

    inputs = inputs[:, drop:]
    targets = targets[:, drop:]

    flat_inputs = inputs.reshape(-1, inputs.shape[-1])
    flat_targets = targets.reshape(-1, targets.shape[2] * targets.shape[3])
    return flat_inputs, flat_targets
def cal_metrics(pred, target):
    """Return the mean squared and mean absolute error of ``pred`` vs ``target``.

    Both metrics are averaged over every element of the difference.

    Returns:
        A dict with keys ``'MSE'`` and ``'MAE'``.
    """
    residual = pred - target
    mse = (residual ** 2).mean()
    mae = np.abs(residual).mean()
    return {'MSE': mse, 'MAE': mae}
def eval_forecasting(model, data, train_slice, valid_slice, test_slice, scaler, pred_lens, n_covariate_cols):
    """Score encoder representations on forecasting with a ridge regressor.

    The whole series is encoded once, then for every horizon in ``pred_lens``
    a ridge model is fitted on the train split (selected with the validation
    split via ``eval_protocols.fit_ridge``) and evaluated on the test split.

    Args:
        model: Object exposing ``encode(data, mode=..., casual=...,
            sliding_length=..., sliding_padding=..., batch_size=...)``.
        data: Array indexed as ``data[:, time_slice, channel]``.
            # NOTE(review): presumably (instances, time, channels) - confirm.
        train_slice, valid_slice, test_slice: Indexers applied to axis 1 of
            both ``data`` and the encoded representations.
        scaler: Object with ``inverse_transform`` used to map predictions and
            labels back to the raw scale.
        pred_lens: Iterable of forecasting horizons.
        n_covariate_cols: Number of leading channels excluded from the
            regression targets.

    Returns:
        A tuple ``(out_log, eval_res)``. ``out_log`` maps each horizon to the
        normalized/raw predictions and ground truth; ``eval_res`` holds the
        per-horizon metrics under ``'ours'`` plus the timing measurements.
    """
    # Used both as the encoder's sliding padding and as the number of leading
    # training samples discarded below.
    padding = 200

    encode_start = time.time()
    # The keyword is spelled `casual` on purpose here: it is passed through
    # unchanged to the encoder - verify its signature before renaming.
    all_repr = model.encode(
        data,
        mode='forecasting',
        casual=True,
        sliding_length=1,
        sliding_padding=padding,
        batch_size=256
    )

    splits = (train_slice, valid_slice, test_slice)
    train_repr, valid_repr, test_repr = (all_repr[:, part] for part in splits)
    train_data, valid_data, test_data = (
        data[:, part, n_covariate_cols:] for part in splits
    )
    # The timer deliberately also covers the slicing above.
    encoder_infer_time = time.time() - encode_start

    def to_raw_scale(values):
        # With several instances, the first and last axes are swapped around
        # the scaler call; with a single instance the scaler is applied as-is.
        if test_data.shape[0] > 1:
            return scaler.inverse_transform(values.swapaxes(0, 3)).swapaxes(0, 3)
        return scaler.inverse_transform(values)

    ours_result = {}
    lr_train_time = {}
    lr_infer_time = {}
    out_log = {}

    for pred_len in pred_lens:
        train_features, train_labels = generate_pred_samples(
            train_repr, train_data, pred_len, drop=padding
        )
        valid_features, valid_labels = generate_pred_samples(
            valid_repr, valid_data, pred_len
        )
        test_features, test_labels = generate_pred_samples(
            test_repr, test_data, pred_len
        )

        fit_start = time.time()
        lr = eval_protocols.fit_ridge(
            train_features, train_labels, valid_features, valid_labels
        )
        lr_train_time[pred_len] = time.time() - fit_start

        predict_start = time.time()
        test_pred = lr.predict(test_features)
        lr_infer_time[pred_len] = time.time() - predict_start

        # Undo the flattening: (instances, steps, pred_len, target channels).
        ori_shape = (test_data.shape[0], -1, pred_len, test_data.shape[2])
        test_pred = test_pred.reshape(ori_shape)
        test_labels = test_labels.reshape(ori_shape)

        test_pred_inv = to_raw_scale(test_pred)
        test_labels_inv = to_raw_scale(test_labels)

        out_log[pred_len] = {
            'norm': test_pred,
            'raw': test_pred_inv,
            'norm_gt': test_labels,
            'raw_gt': test_labels_inv
        }
        ours_result[pred_len] = {
            'norm': cal_metrics(test_pred, test_labels),
            'raw': cal_metrics(test_pred_inv, test_labels_inv)
        }

    eval_res = {
        'ours': ours_result,
        'encoder_infer_time': encoder_infer_time,
        'lr_train_time': lr_train_time,
        'lr_infer_time': lr_infer_time
    }
    return out_log, eval_res
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/shensanbai/time_series_experiment.git
git@gitee.com:shensanbai/time_series_experiment.git
shensanbai
time_series_experiment
time_series_experiment
master

搜索帮助