#!/usr/bin/env python
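"""Roll out a trained RLlib agent from a checkpoint.

This script restores a DDPG agent on the custom "ROBOTIC_ASSEMBLY" environment
(registered from envs_launcher) and runs it for a given number of steps and
episodes, optionally pickling the collected transitions to --out.
"""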
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import json
import os
import pickle
import gym
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.episode import _flatten_action
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.tune.util import merge_dicts
from ray.tune.registry import register_env
import envs_launcher as el
EXAMPLE_USAGE = """
Example Usage via RLlib CLI:
rllib rollout /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
--env CartPole-v0 --steps 1000000 --out rollouts.pkl
Example Usage via executable:
./rollout.py /tmp/ray/checkpoint_dir/checkpoint-0 --run DQN
--env CartPole-v0 --steps 1000000 --out rollouts.pkl
"""
# Note: if you use any custom models or envs, register them here first, e.g.:
#
# ModelCatalog.register_custom_model("pa_model", ParametricActionsModel)
# register_env("pa_cartpole", lambda _: ParametricActionCartpole(10))
def create_parser(parser_creator=None):
parser_creator = parser_creator or argparse.ArgumentParser
parser = parser_creator(
formatter_class=argparse.RawDescriptionHelpFormatter,
description="Roll out a reinforcement learning agent "
"given a checkpoint.",
epilog=EXAMPLE_USAGE)
parser.add_argument(
"checkpoint", type=str, help="Checkpoint from which to roll out.")
required_named = parser.add_argument_group("required named arguments")
# required_named.add_argument(
# "--run",
# type=str,
# required=True,
# help="The algorithm or model to train. This may refer to the name "
# "of a built-on algorithm (e.g. RLLib's DQN or PPO), or a "
# "user-defined trainable function or class registered in the "
# "tune registry.")
required_named.add_argument(
"--env", type=str, help="The gym environment to use.")
parser.add_argument(
"--no-render",
default=False,
action="store_const",
const=True,
help="Surpress rendering of the environment.")
    parser.add_argument(
        "--steps", default=10000, type=int, help="Number of steps to roll out.")
parser.add_argument(
"--episodes",
default=1,
type=int,
help="Number of episodes to roll out.")
parser.add_argument("--out", default=None, help="Output filename.")
parser.add_argument(
"--config",
default="{}",
type=json.loads,
help="Algorithm-specific configuration (e.g. env, hyperparams). "
"Surpresses loading of configuration from checkpoint.")
return parser
def run(args, parser):
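    """Rebuild the agent from the checkpoint's saved config and roll it out.

    The training config is read from params.pkl next to the checkpoint (or one
    directory above it), stripped of training-only options, merged with any
    --config overrides, and used to restore the agent before calling rollout().
    """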
config = {}
# Load configuration from file
config_dir = os.path.dirname(args.checkpoint)
config_path = os.path.join(config_dir, "params.pkl")
if not os.path.exists(config_path):
config_path = os.path.join(config_dir, "../params.pkl")
if not os.path.exists(config_path):
if not args.config:
raise ValueError(
"Could not find params.pkl in either the checkpoint dir or "
"its parent directory.")
else:
with open(config_path, "rb") as f:
config = pickle.load(f)
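    # Cap the worker count from the training config; it is removed entirely
    # below before the agent is rebuilt for rollout.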
if "num_workers" in config:
config["num_workers"] = min(2, config["num_workers"])
config = merge_dicts(config, args.config)
if not args.env:
if not config.get("env"):
parser.error("the following arguments are required: --env")
args.env = config.get("env")
    # Remove training-only parameters that the rollout agent does not need.
    config.pop("num_workers", None)
    optimizer_config = config.get("optimizer", {})
    for key in ("human_data_dir", "human_demonstration", "multiple_human_data",
                "num_replay_buffer_shards", "demonstration_zone_percentage",
                "dynamic_experience_replay", "robot_demo_path"):
        optimizer_config.pop(key, None)
ray.init()
# cls = get_agent_class(args.run)
# agent = cls(env=args.env, config=config)
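    # The agent class and environment are fixed for the assembly task, so the
    # generic --run lookup above is left commented out.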
cls = get_agent_class("DDPG")
agent = cls(env="ROBOTIC_ASSEMBLY", config=config)
agent.restore(args.checkpoint)
num_steps = int(args.steps)
num_episodes = int(args.episodes)
rollout(agent, args.env, num_steps, num_episodes, args.out)
class DefaultMapping(collections.defaultdict):
"""default_factory now takes as an argument the missing key."""
def __missing__(self, key):
self[key] = value = self.default_factory(key)
return value
def default_policy_agent_mapping(unused_agent_id):
return DEFAULT_POLICY_ID
def rollout(agent, env_name, num_steps, num_episodes, out=None):
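    """Step the restored agent through its environment.

    Stops after num_steps environment steps or num_episodes episodes, whichever
    comes first. When out is given, each transition
    [obs, action, next_obs, reward, done] is recorded and pickled to that file.
    """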
policy_agent_mapping = default_policy_agent_mapping
if hasattr(agent, "workers"):
env = agent.workers.local_worker().env
multiagent = isinstance(env, MultiAgentEnv)
if agent.workers.local_worker().multiagent:
policy_agent_mapping = agent.config["multiagent"][
"policy_mapping_fn"]
policy_map = agent.workers.local_worker().policy_map
state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
action_init = {
p: _flatten_action(m.action_space.sample())
for p, m in policy_map.items()
}
    else:
        env = gym.make(env_name)
        multiagent = False
        # No policy map without workers: fall back to a single default policy.
        state_init = {DEFAULT_POLICY_ID: []}
        use_lstm = {DEFAULT_POLICY_ID: False}
        action_init = {DEFAULT_POLICY_ID: _flatten_action(env.action_space.sample())}
if out is not None:
rollouts = []
steps = 0
episodes = 0
while steps < (num_steps or steps + 1) and (episodes < num_episodes):
mapping_cache = {} # in case policy_agent_mapping is stochastic
if out is not None:
rollout = []
obs = env.reset()
agent_states = DefaultMapping(
lambda agent_id: state_init[mapping_cache[agent_id]])
prev_actions = DefaultMapping(
lambda agent_id: action_init[mapping_cache[agent_id]])
prev_rewards = collections.defaultdict(lambda: 0.)
done = False
reward_total = 0.0
while not done and steps < (num_steps or steps + 1):
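            # Wrap single-agent observations as a one-agent dict so the same
            # per-agent loop below handles both cases.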
multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
action_dict = {}
for agent_id, a_obs in multi_obs.items():
if a_obs is not None:
policy_id = mapping_cache.setdefault(
agent_id, policy_agent_mapping(agent_id))
p_use_lstm = use_lstm[policy_id]
if p_use_lstm:
a_action, p_state, _ = agent.compute_action(
a_obs,
state=agent_states[agent_id],
prev_action=prev_actions[agent_id],
prev_reward=prev_rewards[agent_id],
policy_id=policy_id)
agent_states[agent_id] = p_state
else:
a_action = agent.compute_action(
a_obs,
prev_action=prev_actions[agent_id],
prev_reward=prev_rewards[agent_id],
policy_id=policy_id)
a_action = _flatten_action(a_action) # tuple actions
action_dict[agent_id] = a_action
prev_actions[agent_id] = a_action
action = action_dict
action = action if multiagent else action[_DUMMY_AGENT_ID]
next_obs, reward, done, _ = env.step(action)
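            # Record rewards so they can be fed back as prev_reward on the
            # next compute_action() call, and accumulate the episode return.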
if multiagent:
for agent_id, r in reward.items():
prev_rewards[agent_id] = r
else:
prev_rewards[_DUMMY_AGENT_ID] = reward
if multiagent:
done = done["__all__"]
reward_total += sum(reward.values())
else:
reward_total += reward
if out is not None:
rollout.append([obs, action, next_obs, reward, done])
steps += 1
obs = next_obs
if out is not None:
rollouts.append(rollout)
print("Episode reward", reward_total)
episodes += 1
    if out is not None:
        with open(out, "wb") as f:
            pickle.dump(rollouts, f)
if __name__ == "__main__":
register_env("ROBOTIC_ASSEMBLY", el.env_creator)
parser = create_parser()
args = parser.parse_args()
run(args, parser)