簡體   English   中英

如何使用動態規划解決這個問題?

[英]How can I solve this problem using dynamic programming?

給定一個數字列表,比如 [4 5 2 3],我需要根據以下規則集最大化獲得的總和:

  1. 我需要 select 列表中的一個數字,該數字將被刪除。 例如。 選擇 2 將使列表為 [4 5 3]。
  2. 如果要刪除的數字有兩個鄰居,那么我應該得到這個選擇的結果,作為當前選擇的數字與其一個鄰居的乘積,這個乘積與另一個鄰居相加。 例如:如果我 select 2 那么我可以得到這個選擇的結果為 2 * 5 + 3。
  3. 如果我 select 一個數字只有一個鄰居,那么結果是所選數字與其鄰居的乘積。
  4. 當他們只剩下一個數字時,它只是添加到現在的結果中。

遵循這些規則,我需要 select 以結果最大化的順序排列數字。

對於上面的列表,如果選擇的順序是 4->2->3->5,那么得到的總和是 53,這是最大值。

我包含一個程序,它可以讓您將元素集作為輸入傳遞,並給出所有可能的總和,並指示最大總和。

這是一個鏈接

import itertools

l = [int(i) for i in input().split()]
p = itertools.permutations(l) 

c, cs = 1, -1
mm = -1
for i in p:
    var, s = l[:], 0
    print(c, ':', i)
    c += 1
    
    for j in i:
        print(' removing: ', j)
        pos = var.index(j)
        if pos == 0 or pos == len(var) - 1:
            if pos == 0 and len(var) != 1:
                s += var[pos] * var[pos + 1]
                var.remove(j)
            elif pos == 0 and len(var) == 1:
                s += var[pos]
                var.remove(j)
            if pos == len(var) - 1 and pos != 0:
                s += var[pos] * var[pos - 1]
                var.remove(j)
        else:
            mx = max(var[pos - 1], var[pos + 1])
            mn = min(var[pos - 1], var[pos + 1])
            
            s += var[pos] * mx + mn
            var.remove(j)
        
        if s > mm:
            mm = s
            cs = c - 1
        print(' modified list: ', var, '\n  sum:', s)

print('MAX SUM was', mm, ' at', cs)

考慮該問題的 4 個變體:每個元素都被消耗的那些,以及左側、右側或左右兩個元素都未被消耗的那些。

在每種情況下,您都可以考慮要刪除的最后一個元素,這會將問題分解為 1 或 2 個子問題。

這在 O(n^3) 時間內解決了問題。 這是一個解決問題的 python 程序。 solve_的 4 個變體對應於無、一個或另一個,或兩個端點都被固定。 毫無疑問,這個程序可以減少(有很多重復)。

def solve_00(seq, n, m, cache):
    key = ('00', n, m)
    if key in cache:
        return cache[key]
    assert m >= n
    if n == m:
        return seq[n]
    best = -1e9
    for i in range(n, m+1):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[i])
    cache[key] = best
    return best


def solve_01(seq, n, m, cache):
    key = ('01', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n, m):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[m])
    cache[key] = best
    return best

def solve_10(seq, n, m, cache):
    key = ('10', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n+1, m+1):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[n] * seq[i])
    cache[key] = best
    return best

def solve_11(seq, n, m, cache):
    key = ('11', n, m)
    if key in cache:
        return cache[key]   
    assert m >= n + 2
    if m == n + 2:
        return max(seq[n] * seq[n+1] + seq[n+2], seq[n] + seq[n+1] * seq[n+2])
    best = -1e9
    for i in range(n + 1, m):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[n] + seq[m], left + right + seq[i] * seq[m] + seq[n])
    cache[key] = best
    return best

for c in [[1, 1, 1], [4, 2, 3, 5], [1, 2], [1, 2, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]:
    print(c, solve_00(c, 0, len(c)-1, dict()))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM