简体   繁体   English

将张量输入张量流图

[英]Input Tensor into tensorflow graph

I am following tutorial on simple audio recognition and currently editing the label_wav.py . 我正在学习有关简单音频识别的教程,目前正在编辑label_wav.py In original case we input wave file and the graph predicts the label (in between it calculates spectrum, mfcc's inside the graph). 在原始情况下,我们输入wave文件,然后图形预测标签(在它们之间计算频谱,即图形内部的mfcc)。 Now i am looking to input mfcc's directly rather than inputting the wave file. 现在,我希望直接输入mfcc而不是输入wave文件。 Run the graph by inputting the mfcc tensor. 通过输入mfcc张量运行图形

# mfccs:  Tensor("strided_slice:0", shape=(1, 98, 40), dtype=float32)
mfcc_input_layer_name = 'Reshape:0'
with tf.Session() as sess:
    predictions, = sess.run(softmax_tensor, {mfcc_input_layer_name: mfcc})

After a bit of googling, i found some discussion in git and created a session_handle . 经过一番谷歌搜索,我在git中找到了一些讨论,并创建了一个session_handle

# mfccs:  Tensor("strided_slice:0", shape=(1, 98, 40), dtype=float32)
mfcc_input_layer_name = 'Reshape:0'
with tf.Session() as sess:
      h = tf.get_session_handle(mfccs)
      h = sess.run(h)
      predictions, = sess.run(softmax_tensor, {mfcc_input_layer_name: h})

The code is working as expected but I am wondering if there could be a better way of dealing with the tensor rather than creating the handle and then passing it? 代码按预期工作,但是我想知道是否有更好的方法来处理张量而不是创建句柄然后传递它?

I suppose you want to replace an intermediate Tensor with a value by feed_dict. 我想您想用feed_dict用值替换中间张量。 If you have a Tensor object, you can replace it by feed_dict as the following 如果您有一个Tensor对象,则可以将其替换为feed_dict,如下所示

a = tf.constant(3, name="a")
b = tf.constant(4, name="b")
c = tf.add(a, b, name="c")
d = c * 3

with tf.Session() as sess:
    print(sess.run(d))    
    print(sess.run(d, feed_dict={c: 2}))

Even though you don't have the Tensor object, you can get it by get_tensor_by_name 即使您没有Tensor对象,也可以通过get_tensor_by_name来获取它

a = tf.constant(3, name="a")
b = tf.constant(4, name="b")
c = tf.add(a, b, name="c")
d = c * 3

with tf.Session() as sess:
    c_tensor = tf.get_default_graph().get_tensor_by_name("c:0")
    print(sess.run(d, feed_dict={c_tensor: 2}))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM