簡體   English   中英

如何在卷積中使用tf.where?

[英]How to use tf.where with convolutions?

我想創建一個根據分類結果在某個點之后拆分為其他幾個圖的圖。 我認為tf.condtf.where可能正確使用,但我不確定如何使用。

無法在此處復制我的所有代碼,但我創建了一小段說明此問題。

import os
import sys
import tensorflow as tf
GPU_INDEX = 2

net_class = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1],[0.2, 0.4, 0.3, 0.1], [0.3, 0.2, 0.4, 0.1],[0.1, 0.3, 0.3, 0.4]]) # 3,0,1,2,3
classes = tf.argmax(net_class, axis=1)
cls_0_idx = tf.squeeze(tf.where(tf.equal(classes, 0)))
cls_3_idx = tf.squeeze(tf.where(tf.equal(classes, 3)))

cls_0 = tf.gather(params=net_class, indices=cls_0_idx)
cls_3 = tf.gather(params=net_class, indices=cls_3_idx)

params_0 = tf.constant([1.0,1,1,1])
params_3 = tf.constant([3.0,3,3,3])


output = tf.stack([tf.nn.conv1d(cls_0, params_0, 1,  padding='VALID'), tf.nn.conv1d(cls_3, params_3, 1,  padding='VALID')])

sess = tf.Session()
cls_0_idx_val = sess.run(output)

print(output)

在這里,我嘗試提取分類為0或3的輸入索引,並使用不同的變量將它們乘以輸出(每個類的權重相同,這就是為什么要使用卷積的原因)。

我收到以下錯誤:

ValueError: Shape must be rank 4 but is rank 2 for 'conv1d/Conv2D' (op: 'Conv2D') with input shapes: ?, [1,4].

我知道為什么會收到錯誤消息(因為tf.where無法“知道”錯誤tf.where的大小),但問題是如何解決? (類不相等,在我的“實際”問題中甚至可能為空)

我想你應該

  1. tf.squeeze設置axis設置為1

  2. tf.nn.conv1d更改為簡單乘法

  3. tf.stack更改為tf.concat

那么您將獲得以下內容:

net_class = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1],[0.2, 0.4, 0.3, 0.1], [0.3, 0.2, 0.4, 0.1],[0.1, 0.3, 0.3, 0.4]]) # 3,0,1,2,3
classes = tf.argmax(net_class, axis=1)
cls_0_idx = tf.squeeze(tf.where(tf.equal(classes, 0)), -1)
cls_3_idx = tf.squeeze(tf.where(tf.equal(classes, 3)), -1)

cls_0 = tf.gather(params=net_class, indices=cls_0_idx)
cls_3 = tf.gather(params=net_class, indices=cls_3_idx)

params_0 = tf.constant([1.0,1,1,1])
params_3 = tf.constant([3.0,3,3,3])
output = tf.concat([cls_0 * params_0, cls_3 * params_3], axis = 0)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM