代码拉取完成,页面将自动刷新
import numpy as np
from grabscreen import grab_screen
import cv2
import time
import os
import pandas as pd
from tqdm import tqdm
from collections import deque
from models import inception_v3 as googlenet
from random import shuffle
# ---- Run configuration ---------------------------------------------------
FILE_I_END = 1860   # training files are numbered 1..FILE_I_END
WIDTH = 480         # first spatial size passed to the network
HEIGHT = 270        # second spatial size passed to the network
LR = 1e-3           # learning rate
EPOCHS = 30         # full passes over every training file
MODEL_NAME = ''     # TODO: fill in; used as model_name, save path and run_id
PREV_MODEL = ''     # TODO: fill in; checkpoint loaded when LOAD_MODEL is True
LOAD_MODEL = True

# Counters, all starting at zero (unused in this script; presumably
# per-action tallies matching the vectors below).
wl = sl = al = dl = 0
wal = wdl = sal = sdl = 0
nkl = 0

# One-hot target vectors of length 9; the k-th name has a 1 at index k.
# Order: w, s, a, d, wa, wd, sa, sd, nk (presumably key combinations,
# nk = no key).
w, s, a, d, wa, wd, sa, sd, nk = (
    [int(col == row) for col in range(9)] for row in range(9)
)
# Build the network; its 9 outputs line up with the 9 one-hot action vectors.
model = googlenet(
    WIDTH,
    HEIGHT,
    3,              # third input dimension, matching the trailing 3 used when reshaping frames
    LR,
    output=9,
    model_name=MODEL_NAME,
)

# Optionally resume from an earlier checkpoint before training starts.
if LOAD_MODEL:
    model.load(PREV_MODEL)
    print('We have loaded a previous model!!!!')
# Location of the per-file training data; files are numbered 1..FILE_I_END.
TRAINING_FILE_TEMPLATE = 'J:/phase10-random-padded/training_data-{}.npy'
# Number of samples held back from the end of every file for validation.
VALIDATION_SIZE = 50
# Checkpoint the model after every SAVE_EVERY-th file within an epoch.
SAVE_EVERY = 10


def _frames_and_choices(samples):
    """Split ``[frame, choice]`` samples into network inputs and targets.

    Returns ``(frames, choices)``: ``frames`` is a numpy array of every
    sample's first element reshaped to ``(-1, WIDTH, HEIGHT, 3)``, and
    ``choices`` is a list of every sample's second element.
    """
    # NOTE(review): reshape reinterprets memory; it does not transpose. If the
    # saved frames are laid out (HEIGHT, WIDTH, 3), this scrambles pixels
    # instead of rotating the image. Confirm the layout the model expects.
    frames = np.array([sample[0] for sample in samples]).reshape(-1, WIDTH, HEIGHT, 3)
    choices = [sample[1] for sample in samples]
    return frames, choices


# Train for EPOCHS passes. Each pass visits the training files in a fresh
# random order and fits one epoch per file.
for epoch in range(EPOCHS):
    data_order = list(range(1, FILE_I_END + 1))
    shuffle(data_order)
    for count, file_index in enumerate(data_order):
        file_name = TRAINING_FILE_TEMPLATE.format(file_index)
        try:
            # Each entry is a [frame, choice] pair, presumably saved as an
            # object array. numpy >= 1.16.3 refuses to load object arrays
            # unless allow_pickle=True; without it every load raised and
            # training silently did nothing. Pickle can execute code, so only
            # load files this project produced.
            train_data = np.load(file_name, allow_pickle=True)
            print('training_data-{}.npy'.format(file_index), len(train_data))

            # At least one training sample must remain after the validation split.
            if len(train_data) <= VALIDATION_SIZE:
                print('Skipping {}: only {} samples'.format(file_name, len(train_data)))
                continue

            # The tail of the file is the validation set; the rest is training data.
            X, Y = _frames_and_choices(train_data[:-VALIDATION_SIZE])
            test_x, test_y = _frames_and_choices(train_data[-VALIDATION_SIZE:])

            model.fit({'input': X}, {'targets': Y}, n_epoch=1,
                      validation_set=({'input': test_x}, {'targets': test_y}),
                      snapshot_step=2500, show_metric=True, run_id=MODEL_NAME)

            if count % SAVE_EVERY == 0:
                print('SAVING MODEL!')
                model.save(MODEL_NAME)
        except Exception as err:
            # Best effort: a missing or corrupt file, or a failed fit, skips
            # this file instead of aborting the run. Report which file failed.
            # The exception is named `err` so it does not rebind (and then
            # delete) the epoch loop variable.
            print('Error on {}: {!r}'.format(file_name, err))

# Persist the final weights. The periodic checkpoint above may have last run
# up to SAVE_EVERY - 1 files before the end.
print('SAVING MODEL!')
model.save(MODEL_NAME)
# To monitor training progress, run:
#   tensorboard --logdir=foo:J:/phase10-code/log
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。