[英]N-Queens II using backtracking is slow
n皇后謎題是在nxn棋盤上放置n個皇后以使沒有兩個皇后相互攻擊的問題。
給定一個 integer n,返回 n-queens 謎題的不同解的數量。
https://leetcode.com/problems/n-queens-ii/
我的解決方案:
class Solution:
def totalNQueens(self, n: int) -> int:
def genRestricted(restricted, r, c):
restricted = set(restricted)
for row in range(n): restricted.add((row, c))
for col in range(n): restricted.add((r, col))
movements = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
for movement in movements:
row, col = r, c
while 0 <= row < n and 0 <= col < n:
restricted.add((row, col))
row += movement[0]
col += movement[1]
return restricted
def gen(row, col, curCount, restricted):
count, total_count = curCount, 0
for r in range(row, n):
for c in range(col, n):
if (r, c) not in restricted:
count += 1
if count == n: total_count += 1
total_count += gen(row + 1, 0, count, genRestricted(restricted, r, c))
count -= 1
return total_count
return gen(0, 0, 0, set())
它在 n=8 時失敗。 我不知道為什么,以及如何減少迭代。 看來我已經在做盡可能少的迭代了。
restricted
集似乎在時間和空間上都是浪費的。 在成功遞歸結束時, n
級深它增長到n^2
大小,這將總復雜度驅動到O(n^3)
。 而且它並不是真正需要的。 通過查看已放置的皇后更容易檢查方格的可用性(請原諒國際象棋術語; file
代表垂直, rank
代表水平):
def square_is_safe(file, rank, queens_placed):
for queen_rank, queen_file in enumerate(queens_placed):
if queen_file == file: # vertical attack
return false
if queen_file - file == queen_rank - rank: # diagonal attack
return false
if queen_file - file == rank - queen_rank: # anti-diagonal attack
return false
return true
用於
def place_queen_at_rank(queens_placed, rank):
if rank == n:
total_count += 1
return
for file in range(0, n):
if square_is_safe(file, rank, queens_placed):
queens_placed.append(file)
place_queen_at_rank(queens_placed, rank + 1)
queens_placed.pop()
而且還有很大的優化空間。 例如,您可能希望對第一個等級進行特殊處理:由於對稱性,您只需要檢查其中的一半(將執行時間縮短 2 倍)。
對於 n ≤ 9(鏈接謎題中的界限),枚舉車的所有有效位置並驗證沒有攻擊對角線移動就足夠了。
import itertools
def is_valid(ranks):
return not any(
abs(f1 - f2) == abs(r1 - r2)
for f1, r1 in enumerate(ranks)
for f2, r2 in enumerate(ranks[:f1])
)
def count_valid(n):
return sum(map(is_valid, itertools.permutations(range(n))))
print(*(count_valid(i) for i in range(1, 10)), sep=",")
在這種問題中,您必須首先關注算法,而不是代碼。
下面我將重點介紹算法,僅以C++為例進行說明。
一個主要問題是能夠快速檢測給定的 position 是否已由現有皇后控制。
一種簡單的可能性是索引對角線(對於 0 到 2N-1),如果相應的對角線、對角線或列已經被控制,則在數組中跟蹤。 任何索引對角線或對角線的方法都可以完成這項工作。 對於給定的(row, column)
點,我使用:
diagonal index = row + column
antidiagonal index = n-1 + col - row
另外,我使用了一個簡單的對稱性:只需要計算從0 to n/2-1
的行索引的可能性數量(如果n
為奇數,則為n/2
)。
當然可以通過使用其他對稱性來加快速度。 但是,實際上,對於小於或等於 9 的n
值,它看起來已經足夠快了。
結果:
2 : 0 time : 0.001 ms
3 : 0 time : 0.001 ms
4 : 2 time : 0.001 ms
5 : 10 time : 0.002 ms
6 : 4 time : 0.004 ms
7 : 40 time : 0.015 ms
8 : 92 time : 0.05 ms
9 : 352 time : 0.241 ms
10 : 724 time : 0.988 ms
11 : 2680 time : 5.55 ms
12 : 14200 time : 31.397 ms
13 : 73712 time : 188.12 ms
14 : 365596 time : 1046.43 ms
這是 C++ 中的代碼。 由於代碼非常簡單,您應該可以輕松地將其轉換為 Python。
#include <iostream>
#include <chrono>
constexpr int N_MAX = 14;
constexpr int N_DIAG = 2*N_MAX + 1;
class Solution {
public:
int n;
int Col[N_MAX] = {0};
int Diag[N_DIAG] = {0};
int AntiDiag[N_DIAG] = {0};
int totalNQueens(int n1) {
n = n1;
if (n <= 1) return n;
int count = 0;
for (int col = 0; col < n/2; ++col) {
count += sum_from (0, col);
}
count *= 2;
if (n%2) count += sum_from (0, n/2);
return count;
}
int sum_from (int row, int col) {
if (Col[col]) return 0;
int diag = row + col;
if (Diag[diag]) return 0;
int antidiag = n-1 + col - row;
if(AntiDiag[antidiag]) return 0;
if (row == n-1) return 1;
int count = 0;
Col[col] = 1;
Diag[diag] = 1;
AntiDiag[antidiag] = 1;
for (int k = 0; k < n; ++k) {
count += sum_from (row+1, k);
}
Col[col] = 0;
Diag[diag] = 0;
AntiDiag[antidiag] = 0;
return count;
}
};
int main () {
int n = 1;
while (n++ < N_MAX) {
auto start = std::chrono::high_resolution_clock::now();
Solution Sol;
std::cout << n << " : " << Sol.totalNQueens (n) << " time : ";
auto diff = std::chrono::high_resolution_clock::now() - start;
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(diff).count();
std::cout << double(duration)/1000 << " ms" << std::endl;
}
return 0;
}
您可以通過每行僅放置一個皇后來避免檢查水平沖突。 這還允許您通過僅標記后續行來減小對角沖突矩陣的大小。 對列沖突使用簡單的 boolean 標志列表也可以節省時間(與在矩陣中標記多個條目相反)
這是一個作為解決方案生成器的示例:
def genNQueens(size=8):
# setup queen coverage from each position {position:set of positions}
reach = { (r,c):[] for r in range(size) for c in range(0,size) }
for R in range(size):
for C in range(size):
for h in (1,-1): # diagonals on next rows
reach[R,C].extend((R+i,C+h*i) for i in range(1,size))
reach[R,C] = [P for P in reach[R,C] if P in reach]
reach.update({(r,-1):[] for r in range(size)}) # for unplaced rows
# place 1 queen on each row, with backtracking
cols = [-1]*size # column of each queen (start unplaced)
usedCols = [False]*(size+1) # column conflict detection
usedDiag = [[0]*(size+1) for _ in range(size+1)] # for diagonal conflicts
r = 0
while r >= 0:
usedCols[cols[r]] = False
for ur,uc in reach[r,cols[r]]: usedDiag[ur][uc] -= 1
cols[r] = next((c for c in range(cols[r]+1,size)
if not usedCols[c] and not usedDiag[r][c]),-1)
usedCols[cols[r]] = True
for ur,uc in reach[r,cols[r]]: usedDiag[ur][uc] += 1
r += 1 if cols[r]>=0 else -1 # progress or backtrack
if r<size : continue # continue until all rows placed
yield [*enumerate(cols)] # return result
r -= 1 # backtrack to find more
output:
from timeit import timeit
for n in range(3,13):
t = timeit(lambda:sum(1 for _ in genNQueens(n)), number=1)
c = sum(1 for _ in genNQueens(n))
print(f"solutions for {n}x{n}:", c, "time:",f"{t:.4g}")
solutions for 3x3: 0 time: 0.000108
solutions for 4x4: 2 time: 0.0002044
solutions for 5x5: 10 time: 0.0004365
solutions for 6x6: 4 time: 0.0008741
solutions for 7x7: 40 time: 0.003386
solutions for 8x8: 92 time: 0.009881
solutions for 9x9: 352 time: 0.03402
solutions for 10x10: 724 time: 0.1228
solutions for 11x11: 2680 time: 0.5707
solutions for 12x12: 14200 time: 2.77
只需一項更改(刪除gen
中的 r 循環)即可使您的解決方案成為交流電。
主要原因是您的gen
有參數row
,它將使用row + 1
調用自身,因此無需for r in range(row, n):
。 這是不必要的。 只是刪除它,您的解決方案是完全可以接受的。(我們需要在嵌套調用之前添加else
)
結果如下: 更改前:
1 1 1.8358230590820312e-05
2 0 5.7697296142578125e-05
3 0 0.00036835670471191406
4 2 0.0021448135375976562
5 10 0.02212214469909668
6 4 0.23602914810180664
7 40 3.0731561183929443
更改后:
1 1 1.6450881958007812e-05
2 0 3.1948089599609375e-05
3 0 0.0001366138458251953
4 2 0.0002281665802001953
5 10 0.0008234977722167969
6 4 0.0028502941131591797
7 40 0.01242375373840332
8 92 0.05443763732910156
9 352 0.2279810905456543
對於 n = 7 的情況,它只使用原始版本的0.4%時間,並且 n = 8 絕對可以工作。
class Solution:
def totalNQueens(self, n: int) -> int:
def genRestricted(restricted, r, c):
restricted = set(restricted)
for row in range(n): restricted.add((row, c))
for col in range(n): restricted.add((r, col))
movements = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
for movement in movements:
row, col = r, c
while 0 <= row < n and 0 <= col < n:
restricted.add((row, col))
row += movement[0]
col += movement[1]
return restricted
def gen(row, col, curCount, restricted):
count, total_count = curCount, 0
for c in range(col, n):
if (row, c) not in restricted:
count += 1
if count == n: total_count += 1
else: total_count += gen(row + 1, 0, count, genRestricted(restricted, row, c))
count -= 1
return total_count
return gen(0, 0, 0, set())
if __name__ == '__main__':
import time
s = Solution()
for i in range(1, 8):
t0 = time.time()
print(i, s.totalNQueens(i), '\t', time.time() - t0)
當然,還可以進行其他增強。 但這是最大的一個。
例如,您在添加每個點后更新並創建了一個新的限制/禁止點。 順便說一句,我不同意 @user58697 的restricted
,根據您的解決方案,這是必要的,因為您需要克隆和更新以獲得新的,以避免在遞歸調用循環中恢復它。
順便說一句,以下是我的解決方案,僅供您參考:
class Solution:
def solveNQueens_n(self, n): #: int) -> List[List[str]]:
cols = [-1] * n # index means row index
self.res = 0
usedCols = set() # this and cols can avoid vertical and horizontal conflict
def dfs(r): # current row to fill in
def valid(c):
for r0 in range(r):
# (r0, c0), (r1, c1) in the (back-)diagonal, |r1 - r0| = |c1 - c0|
if abs(c - cols[r0]) == abs(r - r0):
return False
return True
if r == n: # valid answer
self.res += 1
return
for c in range(n):
if c not in usedCols and valid(c):
usedCols.add(c)
cols[r] = c
dfs(r + 1)
usedCols.remove(c)
cols[r] = -1
dfs(0)
return self.res
好吧,我錯過的一件事是每一行都必須有一個女王。 非常重要的觀察。 gen 方法必須像這樣修改:
def gen(row, col, curCount, restricted):
if row == n: return 0
count, total_count = curCount, 0
for c in range(col, n):
if (row, c) not in restricted:
if count + 1 == n: total_count += 1
total_count += gen(row + 1, 0, count + 1, genRestricted(restricted, row, c))
return total_count
它只擊敗了約 20% 的提交,所以它一點也不完美。 離得很遠。
僅供參考,這是我遇到的該平台上最快的(在Python
類別中,它擊敗了 97% 的提交)。 但是,它需要一些深入的潛水......
def totalNQueens(self, n: int) -> int:
def dfs(n, row, col, primary, secondary):
nonlocal ans
if row == n:
ans += 1
return
available = ~(col | primary | secondary) & ((1 << n) - 1)
while available:
p = available & -available
available ^= p
dfs(n, row+1, col|p, (primary|p) >> 1, (secondary|p) << 1)
ans = 0
dfs(n, 0, 0, 0, 0)
return ans
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.