jacinth2006/MDP

policy_iteration.py
#%%
"""
3x3 grid world (states numbered 0-8, row by row):
S H H
F F F
H H G
S = start, F = non-terminal (free) cell, H = hole (terminal, reward -1), G = goal (terminal, reward +5)
"""
import numpy as np

# Actions: 0 = up, 1 = down, 2 = left, 3 = right
actions=np.arange(4)
status=np.arange(9)
gama=0.8  # discount factor

# Transition model: p[state][action] is a list of [probability, next_state, reward, done] entries.
p=[[0 for i in range(4)] for j in range(9)]
p[0][0]=[[0.7,0,0,False],[0.3,3,0,False]]
p[0][1]=[[0.3,0,0,False],[0.7,3,0,False]]
p[0][2]=[[0.7,0,0,False],[0.3,1,0,True]]
p[0][3]=[[0.3,0,0,False],[0.7,1,0,True]]
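# States 1 and 2 (holes in the top row) are absorbing: every action stays put with reward -1.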
p[1][0]=[[1,1,-1,True]]
p[1][1]=[[1,1,-1,True]]
p[1][2]=[[1,1,-1,True]]
p[1][3]=[[1,1,-1,True]]
p[2][0]=[[1,2,-1,True]]
p[2][1]=[[1,2,-1,True]]
p[2][2]=[[1,2,-1,True]]
p[2][3]=[[1,2,-1,True]]
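# States 3-5 (middle row): moves are stochastic, so the agent may stay put or slip into a
# neighbouring cell instead of moving in the intended direction.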
p[3][0]=[[0.5,0,0,False],[0.3,3,0,False],[0.2,6,-1,True]]
p[3][1]=[[0.5,6,-1,True],[0.3,3,0,False],[0.2,0,0,False]]
p[3][2]=[[0.7,3,0,False],[0.3,4,0,False]]
p[3][3]=[[0.3,3,0,False],[0.7,4,0,False]]
p[4][0]=[[0.5,1,-1,True],[0.3,4,0,False],[0.2,7,-1,True]]
p[4][1]=[[0.5,7,-1,True],[0.3,4,0,False],[0.2,1,-1,True]]
# The left/right probabilities for state 4 are assumed to be 0.5/0.3/0.2, matching the
# up/down entries above, so that each outcome list sums to 1.
p[4][2]=[[0.5,3,0,False],[0.3,4,0,False],[0.2,5,0,False]]
p[4][3]=[[0.5,5,0,False],[0.3,4,0,False],[0.2,3,0,False]]
p[5][0]=[[0.5,2,-1,True],[0.3,5,0,False],[0.2,8,5,True]]
p[5][1]=[[0.5,8,5,True],[0.3,5,0,False],[0.2,2,-1,True]]
p[5][2]=[[0.7,4,0,False],[0.3,5,0,False]]
p[5][3]=[[0.7,5,0,False],[0.3,4,0,False]]
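# States 6 and 7 are holes (absorbing, reward -1); state 8 is the goal (absorbing, reward +5).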
p[6][0]=[[1,6,-1,True]]
p[6][1]=[[1,6,-1,True]]
p[6][2]=[[1,6,-1,True]]
p[6][3]=[[1,6,-1,True]]
p[7][0]=[[1,7,-1,True]]
p[7][1]=[[1,7,-1,True]]
p[7][2]=[[1,7,-1,True]]
p[7][3]=[[1,7,-1,True]]
p[8][0]=[[1,8,5,True]]
p[8][1]=[[1,8,5,True]]
p[8][2]=[[1,8,5,True]]
p[8][3]=[[1,8,5,True]]
#old_policy=[2 for i in range(9)]
#old_policy=np.array(old_policy)
# Start from a uniformly random policy.
old_policy=np.random.randint(4,size=9)
print("init policy",old_policy)
n_evaluate_policy_iteration=1000  # sweeps per policy evaluation
n_policy_iteration=1000           # maximum number of policy-improvement steps
def evaluate_policy(policy):
    """Iterative policy evaluation: sweep the states repeatedly to estimate V(s) for the given policy."""
    value_table=np.zeros(9)
    for i in range(n_evaluate_policy_iteration):
        value_tmp_table=np.copy(value_table)
        for s in status:
            tmp=0
            for prob,next_st,reward,_ in p[s][policy[s]]:
                tmp+=prob*(reward+gama*value_tmp_table[next_st])
            value_table[s]=tmp
    return value_table
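# Each sweep above applies the Bellman expectation backup
#   V(s) <- sum over (prob, s', r) of prob * (r + gama * V_old(s'))
# under the fixed policy, so value_table converges toward V^pi.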
def extract_policy(value_table):
    """Greedy policy improvement: compute Q(s,a) from the value table and pick the argmax action per state."""
    every_status_action_tmp_q=np.zeros((9,4))
    for s in status:
        for a in actions:
            for prob,next_st,reward,_ in p[s][a]:
                every_status_action_tmp_q[s][a]+=prob*(reward+gama*value_table[next_st])
    policy=np.argmax(every_status_action_tmp_q,axis=1)
    return policy
# Policy iteration: alternate evaluation and greedy improvement until the policy stops changing.
for i in range(n_policy_iteration):
    value_table=evaluate_policy(old_policy)
    new_policy=extract_policy(value_table)
    com_policy=(new_policy==old_policy)
    if com_policy.all():
        break
    old_policy=new_policy
print("final policy",new_policy)
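# Optional sanity check: view the converged state values for the final policy as a 3x3 grid.
# print(evaluate_policy(new_policy).reshape(3,3))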