% (removed non-MATLAB artifact: hosting-site banner "code pull complete, page will auto-refresh")
clear;
close all;
rng('default')
%% SET UP ENVIRONMENT
% Speedup options (declared here but not read in this script; presumably
% consumed by a variant of the training section — TODO confirm)
useFastRestart = true;
useGPU = false;
useParallel = false;
% rng(0)
rng(30)   % NOTE: overrides rng('default') above; 30 is the effective seed
disp('continue train begin')
%% Environment configuration (stable-baselines-style gym interface)
rpos = [-0.5,0.5];                   % reference position passed to the custom environment
env = WaveEnvConAction(rpos,false);  % project-defined env; second arg false here, true in the sim section — TODO confirm its meaning (likely a visualization/eval flag)
validateEnvironment(env);
%% Observation and action configuration
observationInfo = env.getObservationInfo;
numObservations = observationInfo.Dimension(1);  % observation vector length
actionInfo = env.getActionInfo;
numActions = actionInfo.Dimension(1);            % action vector length
% --- Critic network: Q(s,a) with separate state/action input branches ---
% State branch: raw observation -> 64-unit FC layer.
statePath = [featureInputLayer(numObservations,'Normalization','none','Name','State')
fullyConnectedLayer(64,'Name','fc1')];
% Action branch: raw action -> 64-unit FC layer.
actionPath = [featureInputLayer(numActions, 'Normalization', 'none', 'Name','Action')
fullyConnectedLayer(64, 'Name','fc2')];
% Common trunk: element-wise sum of the two branches, then 32 -> 16 -> 1
% (scalar Q-value output).
commonPath = [additionLayer(2,'Name','add')
reluLayer('Name','relu2')
fullyConnectedLayer(32, 'Name','fc3')
reluLayer('Name','relu3')
fullyConnectedLayer(16, 'Name','fc4')
fullyConnectedLayer(1, 'Name','CriticOutput')];
criticNetwork = layerGraph();
criticNetwork = addLayers(criticNetwork,statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
% Wire both branches into the addition layer.
criticNetwork = connectLayers(criticNetwork,'fc1','add/in1');
criticNetwork = connectLayers(criticNetwork,'fc2','add/in2');
criticOptions = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1);
% TD3 uses twin critics; both start from the same architecture (the toolbox
% initializes each representation's weights independently).
critic1 = rlQValueRepresentation(criticNetwork,observationInfo,actionInfo,...
'Observation',{'State'},'Action',{'Action'},criticOptions);
critic2 = rlQValueRepresentation(criticNetwork,observationInfo,actionInfo,...
'Observation',{'State'},'Action',{'Action'},criticOptions);
% --- Actor network: deterministic policy, tanh-bounded output in [-1,1] ---
actorNetwork = [featureInputLayer(numObservations,'Normalization','none','Name','State')
fullyConnectedLayer(64, 'Name','actorFC1')
reluLayer('Name','relu1')
fullyConnectedLayer(32, 'Name','actorFC2')
reluLayer('Name','relu2')
fullyConnectedLayer(numActions,'Name','Action')
tanhLayer('Name','tanh1')];
actorOptions = rlRepresentationOptions('LearnRate',2e-4,'GradientThreshold',1,'L2RegularizationFactor',0.001);
actor = rlDeterministicActorRepresentation(actorNetwork,observationInfo,actionInfo,...
'Observation',{'State'},'Action',{'tanh1'},actorOptions);
Ts = 0.1;        % environment/agent sample time [s]
Tf = 20;         % episode length [s]
Ts_agent = Ts;   % agent acts every environment step
agentOptions = rlTD3AgentOptions("SampleTime",Ts_agent, ...
"DiscountFactor", 0.998, ...
"ExperienceBufferLength",2e6, ...
"MiniBatchSize",64, ...
"NumStepsToLookAhead",1, ...
"TargetSmoothFactor",0.005, ...
"TargetUpdateFrequency",10);
% Gaussian exploration noise on the actor's actions (decays over training).
agentOptions.ExplorationModel.Variance = 0.1;
agentOptions.ExplorationModel.VarianceDecayRate = 2e-4;
agentOptions.ExplorationModel.VarianceMin = 0.001;
% Target-policy smoothing noise (TD3's clipped noise on target actions).
agentOptions.TargetPolicySmoothModel.Variance = 0.1;
agentOptions.TargetPolicySmoothModel.VarianceDecayRate = 1e-4;
agent = rlTD3Agent(actor,[critic1,critic2],agentOptions);
%% Training options
% Steps per episode from episode length and agent sample time
% (Ts_agent == Ts, so this equals the original ceil(Tf/Ts)).
maxsteps = ceil(Tf/Ts_agent);
trainingOpts = rlTrainingOptions(...
'MaxEpisodes',500,...
'MaxStepsPerEpisode',maxsteps,...
'ScoreAveragingWindowLength',200,...
'Plots','training-progress',...
'StopTrainingCriteria','EpisodeReward',...
'UseParallel',true,...
'StopTrainingValue',100);
% BUG FIX: the original assigned these parallelization fields to a typo'd
% variable 'trainOpts', so async mode and worker settings were silently
% ignored by train(). Apply them to the options object actually used.
trainingOpts.ParallelizationOptions.Mode = 'async';
trainingOpts.ParallelizationOptions.StepsUntilDataIsSent = 32;
trainingOpts.ParallelizationOptions.DataToSendFromWorkers = 'Experiences';
% trainingOpts = rlTrainingOptions(...
% 'MaxEpisodes',maxepisodes, ...
% 'MaxStepsPerEpisode',maxsteps, ...
% 'StopTrainingCriteria','AverageReward',...
% 'StopTrainingValue',-190,...
% 'ScoreAveragingWindowLength',100);
doTraining = true;
if doTraining
trainingStats = train(agent,env,trainingOpts);
else
% BUG FIX: load() selects variables by quoted name; the original passed
% the variable 'agent' itself (undefined-by-name selection / error).
load('waveconagent.mat','agent')
end
%% Simulate the (trained) agent and visualize results
% load('waveconagent.mat')
Ts = 0.1;
Tf = 20;
rpos = [-0.5,0.5];
% Re-create the environment with the second flag true (differs from the
% training construction above — presumably enables rendering/eval mode;
% TODO confirm against WaveEnvConAction).
env = WaveEnvConAction(rpos,true);
simOptions = rlSimulationOptions('MaxSteps',ceil(Tf/Ts));
experience = sim(env,agent,simOptions);
% Per-step reward trace of the evaluation episode.
figure
plot(experience.Reward)
figure
% plot(experience.Observation.observation)
% 1.4648e-07
% History buffers recorded by the environment during sim().
plot(env.HisIntensities(1,:));
figure
plot(env.HisAngle(1,:))
hold on
plot(env.HisAngle(2,:))
legend("yaw","pitch")
% (removed non-MATLAB artifact: hosting-site content-moderation boilerplate appended to the file)