
Backward algorithm Hidden Markov Model, 0th index (termination step) yields wrong result

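I am implementing the backward HMM algorithm in PyTorch, using this link as a reference. The link contains the results of the numerical example I am using, and I am trying to implement it and compare my results against them. Page 3, section 2 (Backward probability) has a table with the computed results.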
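For reference, here is the standard backward recursion for N states and T observations. The notation (a_ij for transitions, b_j(O_t) for emissions, π_i for the initial distribution) is an assumption in the usual textbook style; the linked document's notation may differ.

```latex
\begin{aligned}
\beta_T(i) &= 1, && 1 \le i \le N \\
\beta_t(i) &= \sum_{j=1}^{N} a_{ij}\, b_j(O_{t+1})\, \beta_{t+1}(j), && t = T-1, \dots, 1 \\
P(O \mid \lambda) &= \sum_{i=1}^{N} \pi_i\, b_i(O_1)\, \beta_1(i)
\end{aligned}
```

The termination step produces one number for the whole sequence, P(O | λ), not one value per state. Reading the code below against these equations (T = 4): column 3 of backward holds β_3, the loop fills β_2 and β_1 into columns 2 and 1, β_4 = 1 is never stored, and column 0 receives the termination value.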

Here is my code:

import numpy as np
import torch

# Initial transition matrix, as shown on page 2 of the link above
A = np.array([[0.6, 0.4], [0.3, 0.7]])
A = torch.from_numpy(A)
# Initial State Probability (page 2)
pi = np.array([0.8, 0.2])
pi = torch.from_numpy(pi)
# Output probabilities (page 2)
emission_matrix = np.array([[0.3, 0.4, 0.3, 0.3], [0.4, 0.3, 0.3, 0.3]])
emission_matrix = torch.from_numpy(emission_matrix)
# Initialize a 2x4 matrix of zeros (same shape as the emission matrix)
backward = torch.zeros(emission_matrix.shape, dtype=torch.float64)

# Backward algorithm
def _backward(emission_matrix):
    # Initialization: β(T-1, i) = Σ_j A(i, j) * B(j, O_T) * β(T, j), where β(T, j) = 1
    backward[:, -1] = torch.matmul(A, emission_matrix[:, -1])
    # Drop the last column (already used in the initialization) and reverse the
    # remaining columns so that the loop starts from the end
    rev_emission_mat = torch.flip(emission_matrix[:, :-1], [1])
    # I transposed the reversed emission matrix such that each iterable in the for 
    # loop is the observation sequence probability
    T_rev_emission_mat = torch.transpose(rev_emission_mat, 1, 0)
    # Pair each row of the reversed matrix with a descending index, so that the
    # enumeration runs from the latest time step down to 0 rather than the other way round
    zipped_cols = list(zip(range(len(T_rev_emission_mat)-1, -1, -1), T_rev_emission_mat))

    for i, obs_prob in zipped_cols:
        # Induction: β(t, i) = Σ_j A(i, j) * B(j, O_t+1) * β(t+1, j)
        if i != 0:
            backward[:, i] = torch.matmul(A * obs_prob, backward[:, i+1])
    # Termination: Σ_i π(i) * B(i, O_1) * β(1, i)
    backward[:, 0] = torch.matmul(pi * obs_prob, backward[:, 1])

# run backward algorithm
_backward(emission_matrix)
# Check results: _backward fills the zero-initialized backward tensor in place
print(backward)
# Output:
# tensor([[0.0102, 0.0324, 0.0900, 0.3000],
#         [0.0102, 0.0297, 0.0900, 0.3000]], dtype=torch.float64)

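As you can see, the 0th index does not match the result on page 3 of the link above. What am I doing wrong? Let me know if I can clarify anything. Thanks in advance!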
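To see where the two identical 0.0102 entries come from, here is a small NumPy sketch. NumPy's @ follows the same rule as torch.matmul: with two 1-D operands it returns their dot product, a scalar. The numbers are π, the first emission column and the β_1 column printed in the question; the variable names are illustrative only.

```python
import numpy as np

# Numbers from the question: initial distribution, emission column for O_1,
# and the beta_1 column printed as backward[:, 1].
pi = np.array([0.8, 0.2])
b_o1 = np.array([0.3, 0.4])
beta_1 = np.array([0.0324, 0.0297])

# 1-D @ 1-D is an inner product: the two per-state terms are summed.
dot = (pi * b_o1) @ beta_1

# An element-wise product keeps one term per state.
per_state = pi * b_o1 * beta_1

# Writing a scalar into a column broadcasts it to every row, which is why
# both entries of backward[:, 0] show the same value.
col0 = np.zeros(2)
col0[:] = dot

print(dot, per_state, col0)
```

Summing per_state gives back the scalar, so the matmul version writes P(O) itself into both rows rather than a per-state column.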

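Answer: in the termination step, use an element-wise product instead of torch.matmul, so the per-state terms are kept instead of being summed into a single scalar:

backward[:, 0] = pi * obs_prob * backward[:, 1]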
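As a cross-check, here is a minimal pure-Python sketch (no PyTorch) that stores β_T = 1 explicitly, keeps the termination terms per state as in the answer, and compares their sum with the forward algorithm. Like the question's code, it assumes that column t of the emission matrix holds each state's probability of emitting the symbol observed at step t. The helper names backward_betas and forward_prob are hypothetical.

```python
# Parameters copied from the question. Assumption: emission[i][t] is the
# probability that state i emits the symbol actually observed at step t (0-based).
A = [[0.6, 0.4], [0.3, 0.7]]
pi = [0.8, 0.2]
emission = [[0.3, 0.4, 0.3, 0.3], [0.4, 0.3, 0.3, 0.3]]


def backward_betas(A, emission):
    """Hypothetical helper: betas[t][i] for t = 0..T-1, with betas[T-1][i] = 1."""
    n = len(A)
    T = len(emission[0])
    betas = [[1.0] * n for _ in range(T)]  # last row stays 1 (initialization)
    for t in range(T - 2, -1, -1):
        for i in range(n):
            # Induction: sum over next states j of a_ij * b_j(O_{t+1}) * beta_{t+1}(j)
            betas[t][i] = sum(
                A[i][j] * emission[j][t + 1] * betas[t + 1][j] for j in range(n)
            )
    return betas


def forward_prob(A, pi, emission):
    """Hypothetical helper: P(O) computed with the forward algorithm."""
    n = len(A)
    T = len(emission[0])
    alpha = [pi[i] * emission[i][0] for i in range(n)]
    for t in range(1, T):
        alpha = [
            sum(alpha[i] * A[i][j] for i in range(n)) * emission[j][t]
            for j in range(n)
        ]
    return sum(alpha)


betas = backward_betas(A, emission)
# Termination kept per state (element-wise), then summed to get P(O).
terms = [pi[i] * emission[i][0] * betas[0][i] for i in range(len(pi))]
p_backward = sum(terms)
p_forward = forward_prob(A, pi, emission)
print(terms, p_backward, p_forward)
```

Here betas[0] comes out as [0.0324, 0.0297], the same β_1 the question prints in column 1, and both probabilities come out at about 0.010152, the value that torch.matmul wrote into both rows of column 0.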


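Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.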

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM