简体   繁体   English

如何使用动态规划解决这个问题?

[英]How can I solve this problem using dynamic programming?

Given a list of numbers, say [4 5 2 3], I need to maximize the sum obtained according to the following set of rules:给定一个数字列表,比如 [4 5 2 3],我需要根据以下规则集最大化获得的总和:

  1. I need to select a number from the list and that number will be removed.我需要 select 列表中的一个数字,该数字将被删除。 Eg.例如。 selecting 2 will have the list as [4 5 3].选择 2 将使列表为 [4 5 3]。
  2. If the number to be removed has two neighbours then I should get the result of this selection as the product of the currently selected number with one of its neighbours and this product summed up with the other neighbour.如果要删除的数字有两个邻居,那么我应该得到这个选择的结果,作为当前选择的数字与其一个邻居的乘积,这个乘积与另一个邻居相加。 eg.: if I select 2 then I can have the result of this selction as 2 * 5 + 3.例如:如果我 select 2 那么我可以得到这个选择的结果为 2 * 5 + 3。
  3. If I select a number with only one neighbour then the result is the product of the selected number with its neighbour.如果我 select 一个数字只有一个邻居,那么结果是所选数字与其邻居的乘积。
  4. When their is only one number left then it is just added to the result till now.当他们只剩下一个数字时,它只是添加到现在的结果中。

Following these rules, I need to select the numbers in such an order that the result is maximized.遵循这些规则,我需要 select 以结果最大化的顺序排列数字。

For the above list, if the order of selction is 4->2->3->5 then the sum obtained is 53 which is the maximum.对于上面的列表,如果选择的顺序是 4->2->3->5,那么得到的总和是 53,这是最大值。

I am including a program which lets you pass as input the set of elements and gives all possible sums and also indicates the max sum.我包含一个程序,它可以让您将元素集作为输入传递,并给出所有可能的总和,并指示最大总和。

Here's a link .这是一个链接

import itertools

l = [int(i) for i in input().split()]
p = itertools.permutations(l) 

c, cs = 1, -1
mm = -1
for i in p:
    var, s = l[:], 0
    print(c, ':', i)
    c += 1
    
    for j in i:
        print(' removing: ', j)
        pos = var.index(j)
        if pos == 0 or pos == len(var) - 1:
            if pos == 0 and len(var) != 1:
                s += var[pos] * var[pos + 1]
                var.remove(j)
            elif pos == 0 and len(var) == 1:
                s += var[pos]
                var.remove(j)
            if pos == len(var) - 1 and pos != 0:
                s += var[pos] * var[pos - 1]
                var.remove(j)
        else:
            mx = max(var[pos - 1], var[pos + 1])
            mn = min(var[pos - 1], var[pos + 1])
            
            s += var[pos] * mx + mn
            var.remove(j)
        
        if s > mm:
            mm = s
            cs = c - 1
        print(' modified list: ', var, '\n  sum:', s)

print('MAX SUM was', mm, ' at', cs)

Consider 4 variants of the problem: those where every element gets consumed, and those where either the left, the right, or both the right and left elements are not consumed.考虑该问题的 4 个变体:每个元素都被消耗的那些,以及左侧、右侧或左右两个元素都未被消耗的那些。

In each case, you can consider the last element to be removed, and this breaks the problem down into 1 or 2 subproblems.在每种情况下,您都可以考虑要删除的最后一个元素,这会将问题分解为 1 或 2 个子问题。

This solves the problem in O(n^3) time.这在 O(n^3) 时间内解决了问题。 Here's a python program that solves the problem.这是一个解决问题的 python 程序。 The 4 variants of solve_ correspond to none, one or the other, or both of the endpoints being fixed. solve_的 4 个变体对应于无、一个或另一个,或两个端点都被固定。 No doubt this program can be reduced (there's a lot of duplication).毫无疑问,这个程序可以减少(有很多重复)。

def solve_00(seq, n, m, cache):
    key = ('00', n, m)
    if key in cache:
        return cache[key]
    assert m >= n
    if n == m:
        return seq[n]
    best = -1e9
    for i in range(n, m+1):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[i])
    cache[key] = best
    return best


def solve_01(seq, n, m, cache):
    key = ('01', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n, m):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[m])
    cache[key] = best
    return best

def solve_10(seq, n, m, cache):
    key = ('10', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n+1, m+1):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[n] * seq[i])
    cache[key] = best
    return best

def solve_11(seq, n, m, cache):
    key = ('11', n, m)
    if key in cache:
        return cache[key]   
    assert m >= n + 2
    if m == n + 2:
        return max(seq[n] * seq[n+1] + seq[n+2], seq[n] + seq[n+1] * seq[n+2])
    best = -1e9
    for i in range(n + 1, m):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[n] + seq[m], left + right + seq[i] * seq[m] + seq[n])
    cache[key] = best
    return best

for c in [[1, 1, 1], [4, 2, 3, 5], [1, 2], [1, 2, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]:
    print(c, solve_00(c, 0, len(c)-1, dict()))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM