[英]Problems with coding Markov Decision Process
我正在嘗試編寫Markov-Decision Process(MDP),我遇到了一些問題。 你可以檢查我的代碼,找出它不起作用的原因
我試圖用一些小數據來做它並且它起作用並給我必要的結果,我覺得這是正確的。 但我的問題是這個代碼的推廣。 是的,我知道MDP庫,但我需要編寫這個代碼。 這段代碼有效,我想在課堂上得到同樣的結果:
import pandas as pd
data = [['3 0', 'UP', 0.6, '3 1', 5, 'YES'], ['3 0', 'UP', 0.4, '3 2', -10, 'YES'], \
['3 0', 'RIGHT', 1, '3 3', 10, 'YES'], ['3 1', 'RIGHT', 1, '3 3', 4, 'NO'], \
['3 2', 'DOWN', 0.6, '3 3', 3, 'NO'], ['3 2', 'DOWN', 0.4, '3 1', 5, 'NO'], \
['3 3', 'RIGHT', 1, 'EXIT', 7, 'NO'], ['EXIT', 'NO', 1, 'EXIT', 0, 'NO']]
df = pd.DataFrame(data, columns = ['Start', 'Action', 'Probability', 'End', 'Reward', 'Policy'], \
dtype = float) #initial matrix
point_3_0, point_3_1, point_3_2, point_3_3, point_EXIT = 0, 0, 0, 0, 0
gamma = 0.9 #it is a discount factor
for i in range(100):
point_3_0 = gamma * max(0.6 * (point_3_1 + 5) + 0.4 * (point_3_2 - 10), point_3_3 + 10)
point_3_1 = gamma * (point_3_3 + 4)
point_3_2 = gamma * (0.6 * (point_3_3 + 3) + 0.4 * (point_3_1 + 5))
point_3_3 = gamma * (point_EXIT + 7)
print(point_3_0, point_3_1, point_3_2, point_3_3, point_EXIT)
但是在這里我有一個錯誤,看起來太復雜了? 你能幫我解決這個問題嗎?!
gamma = 0.9
class MDP:
def __init__(self, gamma, table):
self.gamma = gamma
self.table = table
def Action(self, state):
return self.table[self.table.Start == state].Action.values
def Probability(self, state):
return self.table[self.table.Start == state].Probability.values
def End(self, state):
return self.table[self.table.Start == state].End.values
def Reward(self, state):
return self.table[self.table.Start == state].Reward.values
def Policy(self, state):
return self.table[self.table.Start == state].Policy.values
mdp = MDP(gamma = gamma, table = df)
def value_iteration():
states = mdp.table.Start.values
actions = mdp.Action
probabilities = mdp.Probability
ends = mdp.End
rewards = mdp.Reward
policies = mdp.Policy
V1 = {s: 0 for s in states}
for i in range(100):
V = V1.copy()
for s in states:
if policies(s) == 'YES':
V1[s] = gamma * max(rewards(s) + [sum([p * V[s1] for (p, s1) \
in zip(probabilities(s), ends(s))][actions(s)==a]) for a in set(actions(s))])
else:
sum(probabilities[s] * ends(s))
return V
value_iteration()
我希望每個點都有值,但得到:ValueError:具有多個元素的數組的真值是不明確的。 使用a.any()或a.all()
您收到錯誤,因為policy(s)= ['YES''YES''YES'],因此它包含'YES'三次。 如果要檢查,如果策略中的所有元素都為“是”,則只需將policies(s) == 'YES'
替換為all(x=='YES' for x in policies(s))
如果您只想檢查第一個元素,請更改為policies(s)[0] == 'YES'
請參閱“ 檢查 ”, 檢查列表中的所有元素是否對於不同方法是相同的。
對於描述的第二個問題(假設(policies(s) == YES).any()
修復了第一個問題)注意你用這個表達式初始化一個常規的python列表
[sum([p * V[s1] for (p, s1) in zip(probabilities(s), ends(s))]
然后你嘗試使用索引訪問[actions(s)==a]
python列表不支持多索引,這會導致你遇到的TypeError
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.