代码拉取完成,页面将自动刷新
#!/usr/bin/env python3
# Author: Armit
# Create Time: 2023/02/20
# demo app (local)
from pathlib import Path
from argparse import ArgumentParser
from traceback import print_exc
import tkinter as tk
import tkinter.ttk as ttk
import tkinter.messagebox as tkmsg
import tkinter.filedialog as tkfdlg
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from modules.transform import inv_transforms
from modules.predictor import Env, load_env, predict_with_oracle, predict_with_prediction
from modules.utils.config import Config
from modules.typing import TaskType, Seq, Stats
# Caption shown in the main window's title bar.
WINDOW_TITLE: str = 'Inplace Inference Demo'
# Fixed main-window (width, height), used to build the Tk geometry string.
WINDOW_SIZE = 1000, 750
# Size passed to the matplotlib figure's set_size_inches().
HIST_FIG_SIZE = 8, 8
class App:
    """Tk GUI that loads a job file and plots ground truth against predictions.

    Constructing an instance builds the window and then blocks inside the Tk
    main loop until the window is closed.
    """

    def __init__(self, args):
        """Build the GUI and run the Tk event loop.

        :param args: parsed CLI namespace; ``args.rolling`` enables the
            rolling-prediction curve.
        """
        # Set state before any widget callback can possibly reach plot().
        self.args = args
        self.env: Env = None    # assigned by open() only after a job is fully loaded
        self.setup_gui()
        try:
            self.wnd.mainloop()
        except KeyboardInterrupt:
            self.wnd.destroy()
        except Exception:
            print_exc()

    def setup_gui(self):
        """Create the window: job-file picker and range sliders on top, plot canvas below."""
        # window: fixed size, centered on the screen
        wnd = tk.Tk()
        W, H = wnd.winfo_screenwidth(), wnd.winfo_screenheight()
        w, h = WINDOW_SIZE
        wnd.geometry(f'{w}x{h}+{(W-w)//2}+{(H-h)//2}')
        wnd.resizable(False, False)
        wnd.title(WINDOW_TITLE)
        # closing the window only stops mainloop(); __init__ then returns
        wnd.protocol('WM_DELETE_WINDOW', wnd.quit)
        self.wnd = wnd

        # Top: control
        frm1 = ttk.Frame(wnd)
        frm1.pack(side=tk.TOP, anchor=tk.N, expand=tk.YES, fill=tk.X)

        self.var_job_file = tk.StringVar(frm1)
        self.var_L = tk.IntVar(frm1, value=0)   # start index of the plotted range
        self.var_R = tk.IntVar(frm1, value=0)   # end index (exclusive) of the plotted range

        # job file picker
        frm11 = ttk.LabelFrame(frm1, text='Job File')
        frm11.pack(expand=tk.YES, fill=tk.X)
        ent = ttk.Entry(frm11, textvariable=self.var_job_file, state=tk.DISABLED)
        ent.pack(side=tk.LEFT, anchor=tk.W, expand=tk.YES, fill=tk.X)
        btn = ttk.Button(frm11, text='Open..', command=self.open)
        btn.pack(side=tk.RIGHT, anchor=tk.E)

        # range sliders: the 0..10 range is a placeholder until open() reconfigures it;
        # the pack closures are kept so open() can re-pack after configure()
        frm12 = ttk.LabelFrame(frm1, text='Predict Range')
        frm12.pack(expand=tk.YES, fill=tk.X)
        self.sc_L = tk.Scale(frm12, variable=self.var_L, command=lambda _: self.plot(), from_=0, to=10, orient=tk.HORIZONTAL)
        self.sc_L_pack = lambda: self.sc_L.pack(side=tk.LEFT, anchor=tk.W, expand=tk.YES, fill=tk.X)
        self.sc_L_pack()
        self.sc_R = tk.Scale(frm12, variable=self.var_R, command=lambda _: self.plot(), from_=0, to=10, orient=tk.HORIZONTAL)
        self.sc_R_pack = lambda: self.sc_R.pack(side=tk.RIGHT, anchor=tk.E, expand=tk.YES, fill=tk.X)
        self.sc_R_pack()

        # bottom: plot
        frm2 = ttk.Frame(wnd)
        frm2.pack(side=tk.BOTTOM, expand=tk.YES, fill=tk.BOTH)
        fig, ax = plt.subplots()
        fig.set_size_inches(HIST_FIG_SIZE)
        fig.tight_layout()
        cvs = FigureCanvasTkAgg(fig, frm2)
        cvs.get_tk_widget().pack(fill=tk.BOTH, expand=tk.YES)
        self.fig, self.ax, self.cvs = fig, ax, cvs

    @staticmethod
    def _pad_front(arr, length):
        """Left-pad ``arr`` along axis 0 to ``length`` rows by repeating its first row.

        Only axis 0 is padded; any trailing axes (e.g. channels) keep their size.
        Arrays that already have at least ``length`` rows are returned unchanged.
        """
        arr = np.asarray(arr)
        n_pad = length - len(arr)
        if n_pad <= 0:
            return arr
        # a bare (before, after) pair would be broadcast to *every* axis
        pad_width = [(n_pad, 0)] + [(0, 0)] * (arr.ndim - 1)
        return np.pad(arr, pad_width, mode='edge')

    @staticmethod
    def _tick_interval(seqlen, res):
        """Slider tick spacing: about a tenth of ``seqlen`` rounded to hundreds, never below ``res``."""
        return max(res, round(seqlen // 10 / 100) * 100)

    def open(self):
        """Ask for a job.yaml, load its environment and precompute full-length curves.

        On success the sliders are rescaled to the sequence and the plot is
        refreshed. If loading or prediction fails, the traceback is printed,
        an error dialog is shown and the previously loaded job stays active.
        """
        fp = tkfdlg.askopenfilename(title='Open a job.yaml file..', filetypes=[('yaml files', '*.yaml')])
        if not fp: return
        fp = Path(fp)
        if not fp.exists():
            tkmsg.showerror('Error', 'File not exists!')
            return

        # Compute everything into locals first so a failure cannot leave
        # self.env set while the prediction arrays are missing or stale.
        try:
            env: Env = load_env(fp)
            job: Config = env.job
            is_task_rgr = env.manager.TASK_TYPE == TaskType.RGR
            print(f' is_task_rgr: {is_task_rgr}')

            # precalc whole seq
            seq: Seq = env.seq      # transformed
            label: Seq = env.label
            stats: Stats = env.stats
            preds_o: Seq = self._pad_front(predict_with_oracle(env), len(seq))
            preds_r = self._pad_front(predict_with_prediction(env), len(seq)) if self.args.rolling else None
            truth = inv_transforms(seq, stats) if is_task_rgr else label
        except Exception:
            print_exc()
            tkmsg.showerror('Error', f'Failed to load job: {fp}')
            return

        self.var_job_file.set(str(fp))
        self.is_task_rgr = is_task_rgr
        self.preds_o = preds_o
        self.preds_r = preds_r
        self.truth = truth
        self.env = env

        # set view range
        print(f' seq.shape: {seq.shape}')
        seqlen = len(seq)
        inlen: int = job.get('dataset/inlen', 72)
        res = max(seqlen // 100, inlen)
        tick = self._tick_interval(seqlen, res)
        self.sc_L.configure(to=seqlen, resolution=res, tickinterval=tick)
        self.sc_L_pack()
        self.sc_R.configure(to=seqlen, resolution=res, tickinterval=tick)
        self.sc_R_pack()
        self.var_L.set(tick)
        self.var_R.set(min(tick * 2, seqlen))
        self.plot()

    def plot(self):
        """Redraw channel 0 over the slider range [L, R) and print a metric for the oracle prediction.

        Does nothing until a job is loaded or while L >= R.
        """
        if self.env is None: return
        L = self.var_L.get()
        R = self.var_R.get()
        if L >= R: return

        # select range & channel 0
        truth = self.truth[L:R, 0]      # [T'=R-L]
        preds_o = self.preds_o[L:R, 0]  # [T'=R-L]
        preds_r = self.preds_r[L:R, 0] if self.args.rolling else None  # [T'=R-L]

        # show metrics: MAE for regression, accuracy otherwise
        if self.is_task_rgr:
            mae = np.mean(np.abs(truth - preds_o))
            print(f'>> mae: {mae}')
        else:
            acc = np.mean(truth == preds_o)
            print(f'>> acc: {acc:.3%}')

        self.ax.cla()
        self.ax.plot(truth, 'b', label='truth')
        self.ax.plot(preds_o, 'r', label='pred (oracle)')
        if preds_r is not None:
            self.ax.plot(preds_r, 'g', label='pred (rolling)')
        if not self.is_task_rgr:
            self.ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Axes-level legend: ax.cla() removes it on the next redraw, whereas
        # fig.legend() would stack one more figure legend on every call.
        self.ax.legend()
        self.fig.tight_layout()
        self.cvs.draw()
if __name__ == '__main__':
    # CLI entry point: parse flags, then hand them to the GUI (blocks until the window closes).
    cli = ArgumentParser()
    cli.add_argument('--rolling', action='store_true', help='whether draw rolling prediction')
    App(cli.parse_args())
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。