Briefly/rldemo_paper_code

This repository declares no open-source license file (LICENSE); check the project description and its upstream code dependencies before use.

WaveConTrain.m (4.13 KB)
Briefly, committed 2023-09-10 17:56: "Add all files"
Clone: https://gitee.com/briefly/rldemo_paper_code.git
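The script below is WaveConTrain.m in full: it resumes training a TD3 agent (MATLAB Reinforcement Learning Toolbox) on the custom WaveEnvConAction environment, then simulates the trained policy and plots the reward, intensity, and angle histories.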
clear;
close all;
rng('default')
%% SET UP ENVIRONMENT
% Speedup options (declared here but not referenced later in this script)
useFastRestart = true;
useGPU = false;
useParallel = false;
% Fix the random seed for reproducibility (overrides rng('default') above)
% rng(0)
rng(30)
disp('Continuing training')
%% Environment setup (doesn't this feel like Stable Baselines?)
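% WaveEnvConAction is a custom environment class defined elsewhere in this
% repo; rpos appears to parameterize a position range and the boolean flag
% presumably toggles visualization (assumptions based on usage below).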
rpos = [-0.5,0.5];
env = WaveEnvConAction(rpos,false);
validateEnvironment(env);
%% Observation and action configuration
observationInfo = env.getObservationInfo;
numObservations = observationInfo.Dimension(1);
actionInfo = env.getActionInfo;
numActions = actionInfo.Dimension(1);
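% TD3 critic: the state and the action each pass through their own fully
% connected path; the two 64-unit features are summed by an additionLayer
% and mapped down to a scalar Q-value. Two critics built from the same
% graph give the "twin" Q-networks TD3 uses to curb overestimation.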
statePath = [featureInputLayer(numObservations,'Normalization','none','Name','State')
    fullyConnectedLayer(64,'Name','fc1')];
actionPath = [featureInputLayer(numActions,'Normalization','none','Name','Action')
    fullyConnectedLayer(64,'Name','fc2')];
commonPath = [additionLayer(2,'Name','add')
    reluLayer('Name','relu2')
    fullyConnectedLayer(32,'Name','fc3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(16,'Name','fc4')
    fullyConnectedLayer(1,'Name','CriticOutput')];
criticNetwork = layerGraph();
criticNetwork = addLayers(criticNetwork,statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
criticNetwork = connectLayers(criticNetwork,'fc1','add/in1');
criticNetwork = connectLayers(criticNetwork,'fc2','add/in2');
criticOptions = rlRepresentationOptions('LearnRate',1e-3,'GradientThreshold',1);
critic1 = rlQValueRepresentation(criticNetwork,observationInfo,actionInfo, ...
    'Observation',{'State'},'Action',{'Action'},criticOptions);
critic2 = rlQValueRepresentation(criticNetwork,observationInfo,actionInfo, ...
    'Observation',{'State'},'Action',{'Action'},criticOptions);
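% TD3 actor: a deterministic policy network. The final tanh bounds each
% action to [-1, 1]; the environment is presumably responsible for scaling
% this to its physical actuator range (assumption, not shown here).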
actorNetwork = [featureInputLayer(numObservations,'Normalization','none','Name','State')
    fullyConnectedLayer(64,'Name','actorFC1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(32,'Name','actorFC2')
    reluLayer('Name','relu2')
    fullyConnectedLayer(numActions,'Name','Action')
    tanhLayer('Name','tanh1')];
actorOptions = rlRepresentationOptions('LearnRate',2e-4,'GradientThreshold',1,'L2RegularizationFactor',0.001);
actor = rlDeterministicActorRepresentation(actorNetwork,observationInfo,actionInfo, ...
    'Observation',{'State'},'Action',{'tanh1'},actorOptions);
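% Sample time and episode horizon: 0.1 s steps over 20 s gives 200 steps
% per episode (see maxsteps below).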
Ts = 0.1;
Tf = 20;
Ts_agent = Ts;
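% TD3 agent options: a 2e6-sample replay buffer, minibatches of 64,
% one-step returns, and soft target updates (tau = 0.005) applied every
% 10 steps. Exploration noise and target-policy smoothing noise both
% start at variance 0.1 and decay during training.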
agentOptions = rlTD3AgentOptions("SampleTime",Ts_agent, ...
    "DiscountFactor",0.998, ...
    "ExperienceBufferLength",2e6, ...
    "MiniBatchSize",64, ...
    "NumStepsToLookAhead",1, ...
    "TargetSmoothFactor",0.005, ...
    "TargetUpdateFrequency",10);
agentOptions.ExplorationModel.Variance = 0.1;
agentOptions.ExplorationModel.VarianceDecayRate = 2e-4;
agentOptions.ExplorationModel.VarianceMin = 0.001;
agentOptions.TargetPolicySmoothModel.Variance = 0.1;
agentOptions.TargetPolicySmoothModel.VarianceDecayRate = 1e-4;
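% Assemble the TD3 agent from the deterministic actor and the two critics.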
agent = rlTD3Agent(actor,[critic1,critic2],agentOptions);
maxsteps = ceil(Tf/Ts_agent);
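% Training options: up to 500 episodes, stopping early once a single
% episode's reward reaches 100. Experiences are collected asynchronously
% from parallel workers, shipped every 32 steps.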
trainingOpts = rlTrainingOptions( ...
    'MaxEpisodes',500, ...
    'MaxStepsPerEpisode',maxsteps, ...
    'ScoreAveragingWindowLength',200, ...
    'Plots','training-progress', ...
    'StopTrainingCriteria','EpisodeReward', ...
    'UseParallel',true, ...
    'StopTrainingValue',100);
trainingOpts.ParallelizationOptions.Mode = 'async';
trainingOpts.ParallelizationOptions.StepsUntilDataIsSent = 32;
trainingOpts.ParallelizationOptions.DataToSendFromWorkers = 'Experiences';
% trainingOpts = rlTrainingOptions(...
% 'MaxEpisodes',maxepisodes, ...
% 'MaxStepsPerEpisode',maxsteps, ...
% 'StopTrainingCriteria','AverageReward',...
% 'StopTrainingValue',-190,...
% 'ScoreAveragingWindowLength',100);
doTraining = true;
if doTraining
    trainingStats = train(agent,env,trainingOpts);
else
    load('waveconagent.mat','agent')
end
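% To make the load branch above usable on a later run, save the trained
% agent once training finishes (a suggested step, not part of the
% original script):
% save('waveconagent.mat','agent')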
%% SIMULATE TRAINED AGENT
% load('waveconagent.mat')
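% Re-create the environment with its second flag set to true, which
% presumably enables visualization during simulation, and run one episode.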
Ts = 0.1;
Tf = 20;
rpos = [-0.5,0.5];
env = WaveEnvConAction(rpos,true);
simOptions = rlSimulationOptions('MaxSteps',ceil(Tf/Ts));
experience = sim(env,agent,simOptions);
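% Plot the per-step reward, the intensity history recorded by the
% environment, and the logged yaw/pitch angles.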
figure
plot(experience.Reward)
figure
% plot(experience.Observation.observation)
% 1.4648e-07
plot(env.HisIntensities(1,:));
figure
plot(env.HisAngle(1,:))
hold on
plot(env.HisAngle(2,:))
legend("yaw","pitch")