簡體   English   中英

如何更新 Tensorflow 中二維張量的子集?

[英]How to update a subset of 2D tensor in Tensorflow?

我想用值 0 更新 2D 張量中的索引。因此數據是一個 2D 張量,其第 2 行第 2 列索引值將替換為 0。但是,我收到類型錯誤。 任何人都可以幫我嗎?

TypeError: 'ScatterUpdate' Op 的輸入 'ref' 需要左值輸入

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
data2 = tf.reshape(data, [-1])
sparse_update = tf.scatter_update(data2, tf.constant([7]), tf.constant([0]))
#data = tf.reshape(data, [N,S])
init_op = tf.initialize_all_variables()

sess = tf.Session()
sess.run([init_op])
print "Values before:", sess.run([data])
#sess.run([updated_data_subset])
print "Values after:", sess.run([sparse_update])

散點更新僅適用於變量。 而是嘗試這種模式。

Tensorflow 版本 < 1.0: a = tf.concat(0, [a[:i], [updated_value], a[i+1:]])

Tensorflow 版本 >= 1.0: a = tf.concat(axis=0, values=[a[:i], [updated_value], a[i+1:]])

tf.scatter_update只能應用於Variable類型。 代碼中的dataVariable ,而data2不是,因為tf.reshape的返回類型是Tensor

解決方案:

對於 v1.0 之后的 tensorflow

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat([row[:2], tf.constant([0]), row[3:]], axis=0)
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

對於 v1.0 之前的 tensorflow

data = tf.Variable([[1,2,3,4,5], [6,7,8,9,0], [1,2,3,4,5]])
row = tf.gather(data, 2)
new_row = tf.concat(0, [row[:2], tf.constant([0]), row[3:]])
sparse_update = tf.scatter_update(data, tf.constant(2), new_row)

這是我用來修改 Tensorflow 2 中二維張量的子集(行/列)的 function:

#note if updatedValue isVector, updatedValue should be provided in 2D format
def modifyTensorRowColumn(a, isRow, index, updatedValue, isVector):
    
    if(not isRow):
        a = tf.transpose(a)
        if(isVector):
            updatedValue = tf.transpose(updatedValue)
    
    if(index == 0):
        if(isVector):
            values = [updatedValue, a[index+1:]]
        else:
            values = [[updatedValue], a[index+1:]]
    elif(index == a.shape[0]-1):
        if(isVector):
            values = [a[:index], updatedValue]
        else:
            values = [a[:index], [updatedValue]]
    else:
        if(isVector):
            values = [a[:index], updatedValue, a[index+1:]]
        else:
            values = [a[:index], [updatedValue], a[index+1:]]
            
    a = tf.concat(axis=0, values=values)
            
    if(not isRow):
        a = tf.transpose(a)
        
    return a

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM