簡體   English   中英

在張量流,高級索引中更新矩陣變量的值

[英]Update values of a matrix variable in tensorflow, advanced indexing

我想創建一個函數,對於給定數據X的每一行,僅將softmax函數應用於K個總類中的某些采樣類,比方說2個。 在簡單的python中,代碼看起來像這樣:

def softy(X,W, num_samples):
    N = X.shape[0]
    K = W.shape[0]
    S = np.zeros((N,K)) 
    ar_to_sof = np.zeros(num_samples)
    sampled_ind = np.zeros(num_samples, dtype = int)
    for line in range(N):        
        for samp in range(num_samples):
            sampled_ind[samp] = randint(0,K-1)
            ar_to_sof[samp] = np.dot(X[line],np.transpose(W[sampled_ind[samp]])) 
        ar_to_sof = softmax(ar_to_sof)
        S[line][sampled_ind] = ar_to_sof 

    return S

S最終將在數組“ samped_ind”為每一行定義的索引中包含零和非零值。 我想使用Tensorflow來實現這一點。 問題是它包含“高級”索引,而我找不到使用此庫來創建該索引的方法。

我正在嘗試使用此代碼:

S = tf.Variable(tf.zeros((N,K)))
tfx = tf.placeholder(tf.float32,shape=(None,D))
wsampled = tf.placeholder(tf.float32, shape = (None,D))
ar_to_sof = tf.matmul(tfx,wsampled,transpose_b=True)
softy = tf.nn.softmax(ar_to_sof)
r = tf.random_uniform(shape=(), minval=0,maxval=K, dtype=tf.int32)
...
for line in range(N):
    sampled_ind = tf.constant(value=[sess.run(r),sess.run(r)],dtype= tf.int32)
    Wsampled = sess.run(tf.gather(W,sampled_ind))
    sess.run(softy,feed_dict={tfx:X[line:line+1], wsampled:Wsampled})

一切工作到這里為止,但是我找不到在python S中用python代碼“ S [line] [sampled_ind] = ar_to_sof”進行更新的方法。

我該如何進行這項工作?

對該問題的解決方案的評論中找到了我的問題的答案。 建議將矩陣S重塑為1d向量。這樣,代碼即可正常工作,看起來像:

S = tf.Variable(tf.zeros(shape=(N*K)))
W = tf.Variable(tf.random_uniform((K,D)))
tfx = tf.placeholder(tf.float32,shape=(None,D))
sampled_ind = tf.random_uniform(dtype=tf.int32, minval=0, maxval=K-1, shape=[num_samps])
ar_to_sof = tf.matmul(tfx,tf.gather(W,sampled_ind),transpose_b=True)
updates = tf.reshape(tf.nn.softmax(ar_to_sof),shape=(num_samps,))
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for line in range(N):
    inds_new = sampled_ind + line*K
    sess.run(tf.scatter_update(S,inds_new,updates), feed_dict={tfx: X[line:line+1]})

S = tf.reshape(S,shape=(N,K))

那返回了我期望的結果。 現在的問題是此實現速度太慢。 比numpy版本慢得多。 也許是for循環。 有什么建議么?

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM