簡體   English   中英

我應該如何使用 tf.scan 來實現這種遞推關系?

[英]How should I use tf.scan to implement this recurrence relation?

我想在 Tensorflow 中實現以下功能。

輸入是一個形狀為(n, 1)的張量x

輸出是形狀為(n, 1)的張量y

我固定了a形狀為(2, 1)的張量a

然后我想通過以下公式計算y

y[0] = x[0]
y[1] = x[1] - a[0] * x[0]
y[2] = x[2] - a[0] * x[1] - a[1] * x[0]
... 
y[i] = x[i] - a[0] * x[i - 1] - a[1] * x[i - 2]

我希望a的條目是可訓練的。

感謝您的任何幫助,您可以提供!

編輯:我正在嘗試一個可能如下所示的解決方案:

elems = np.array([1, 0, 1, 0, 1, 1, 1, 0, 1, 0])

init = (np.array([0, 0]), 0)

# a = initialize
# y = elems[0]

# test first step
# print((np.array([a[0][1], a[1]]), a[0][0] + a[0][1] + y))

a1 = tf.Variable(-1)
a2 = tf.Variable(-1)

def f(x, y):
    return (tf.stack([x[0][1], x[1]], 0), y - a1 * x[0][1] - a2 * x[0][1])

c = tf.scan(f, elems, initializer=init)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(c))

其中elems是輸入xc的最后一列應該是y 還不行,答案應該是

1 1 3 4 8 13 22 35 58 93

我能夠解決這個問題。

elems = np.array([1, 0, 1, 0, 1, 1, 1, 0, 1, 0])

init = (np.array([0, 0]), 0)

a1 = tf.Variable(-1)
a2 = tf.Variable(-1)

def f(x, y):
    return (tf.stack([x[0][1], x[1]], 0), y - a2 * x[0][1] - a1 * x[1])
    
c = tf.scan(f, elems, initializer=init)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(c))

為了解釋 tf.scan 在這種情況下的工作原理,您可以查看以下斐波那契數列的示例 MATLAB 代碼。

% tensor to feed in helper information at each step
elems = [1, 0, 0, 0, 0, 0]; 

% initialize memory
initialize = [0, 1]; 

% define the function f(current_memory, current_element) = next_memory
f = @(a, y) [a(2), a(1) + a(2)];  
% in this case, the element is not used, but in the previous example, it is

% memory storage tensor (Fib result is second column)
x = zeros(length(elems), length(initialize)); 

% run tf.scan
x(1, :) = f(initialize, elems(1));
for i = 2:length(elems)
    x(i, :) = f(x(i - 1, :), elems(i));
end

函數 f(a, y) 有兩個參數,a:“記憶”張量,y:elems 張量的當前條目。 在這個斐波那契案例中,f 忽略了 elems 張量,但在上面的例子中,它使用了它。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM