
How to use tf.where with convolutions?

I want to create a graph that splits into several other graphs after a certain point, according to classification results. I thought tf.cond or tf.where might be the right tool, but I'm not sure how to use them.

It is impossible to copy all of my code here, but I created a small segment that illustrates the issue.

import os
import sys
import tensorflow as tf
GPU_INDEX = 2

net_class = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1],[0.2, 0.4, 0.3, 0.1], [0.3, 0.2, 0.4, 0.1],[0.1, 0.3, 0.3, 0.4]]) # 3,0,1,2,3
classes = tf.argmax(net_class, axis=1)
cls_0_idx = tf.squeeze(tf.where(tf.equal(classes, 0)))
cls_3_idx = tf.squeeze(tf.where(tf.equal(classes, 3)))

cls_0 = tf.gather(params=net_class, indices=cls_0_idx)
cls_3 = tf.gather(params=net_class, indices=cls_3_idx)

params_0 = tf.constant([1.0,1,1,1])
params_3 = tf.constant([3.0,3,3,3])


output = tf.stack([tf.nn.conv1d(cls_0, params_0, 1,  padding='VALID'), tf.nn.conv1d(cls_3, params_3, 1,  padding='VALID')])

sess = tf.Session()
cls_0_idx_val = sess.run(output)

print(output)

Here I tried to extract the indices of the inputs that are classified as 0 or 3 and multiply them by different variables to produce the output (the weights are shared within each class, which is why I use a convolution).

I get the following error:

ValueError: Shape must be rank 4 but is rank 2 for 'conv1d/Conv2D' (op: 'Conv2D') with input shapes: ?, [1,4].

I understand why I get the error (because tf.where doesn't "know" its output size), but the question is how do I fix it? (The classes are not equal in size and may even be empty in my "real" problem.)

I think you should:

  1. set the axis argument of tf.squeeze to 1 (or -1, which is the same axis for the [N, 1] output of tf.where)

  2. replace tf.nn.conv1d with a simple multiplication

  3. replace tf.stack with tf.concat

Then you will have something like this:

import tensorflow as tf

net_class = tf.constant([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1],[0.2, 0.4, 0.3, 0.1], [0.3, 0.2, 0.4, 0.1],[0.1, 0.3, 0.3, 0.4]]) # 3,0,1,2,3
classes = tf.argmax(net_class, axis=1)
# tf.where returns shape [N, 1]; squeeze only the last axis so the indices stay rank 1
cls_0_idx = tf.squeeze(tf.where(tf.equal(classes, 0)), -1)
cls_3_idx = tf.squeeze(tf.where(tf.equal(classes, 3)), -1)

cls_0 = tf.gather(params=net_class, indices=cls_0_idx)
cls_3 = tf.gather(params=net_class, indices=cls_3_idx)

params_0 = tf.constant([1.0,1,1,1])
params_3 = tf.constant([3.0,3,3,3])
# Broadcasting multiplies every gathered row by its class's weights; concat joins the row groups
output = tf.concat([cls_0 * params_0, cls_3 * params_3], axis = 0)
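
To check how this behaves when a class has exactly one member or none at all (the case raised in the question), here is a minimal sketch. It assumes TensorFlow 2.x with eager execution, whereas the question uses the TF 1.x Session API. The helper name rows_of_class and the class id 7 are made up for illustration and are not part of the original code.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled

net_class = tf.constant([
    [0.1, 0.2, 0.3, 0.4],
    [0.4, 0.3, 0.2, 0.1],
    [0.2, 0.4, 0.3, 0.1],
    [0.3, 0.2, 0.4, 0.1],
    [0.1, 0.3, 0.3, 0.4],
])  # argmax per row: 3, 0, 1, 2, 3
classes = tf.argmax(net_class, axis=1)


def rows_of_class(scores, labels, class_id):
    """Hypothetical helper: gather the rows of `scores` labelled `class_id`."""
    # tf.where returns shape [N, 1]. Squeezing only the last axis keeps the
    # index vector rank 1 whether N is 0, 1, or larger.
    idx = tf.squeeze(tf.where(tf.equal(labels, class_id)), axis=-1)
    return tf.gather(scores, idx)


params_0 = tf.constant([1.0, 1.0, 1.0, 1.0])
params_3 = tf.constant([3.0, 3.0, 3.0, 3.0])

cls_0 = rows_of_class(net_class, classes, 0)  # one matching row
cls_3 = rows_of_class(net_class, classes, 3)  # two matching rows
cls_7 = rows_of_class(net_class, classes, 7)  # no matching rows (made-up class id)

# Without an explicit axis, squeeze also drops the row axis when there is
# exactly one match, so gather returns a single vector instead of a 1-row matrix.
idx_no_axis = tf.squeeze(tf.where(tf.equal(classes, 0)))
single_row = tf.gather(net_class, idx_no_axis)

# The empty group contributes zero rows, so concat still works.
output = tf.concat(
    [cls_0 * params_0, cls_3 * params_3, cls_7 * params_3], axis=0
)

print(cls_0.shape, cls_3.shape, cls_7.shape)
print(single_row.shape)
print(output.numpy())
```

With the explicit axis, every group keeps the shape [N, 4], so the per-class multiplication and tf.concat along axis 0 do not need a special case for small or empty classes.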
