簡體   English   中英

如何列出節點所依賴的所有Tensorflow變量?

[英]How can I list all Tensorflow variables a node depends on?

如何列出節點所依賴的所有Tensorflow變量/常量/占位符?

示例1(添加常量):

import tensorflow as tf

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))

我想有一個函數list_dependencies() ,如:

  • list_dependencies(d)返回['a', 'b']
  • list_dependencies(e)返回['a', 'b', 'c']

示例2(占位符和權重矩陣之間的矩陣乘法,然后添加偏差向量):

tf.set_random_seed(1)
input_size  = 5
output_size = 3
input       = tf.placeholder(tf.float32, shape=[1, input_size], name='input')
W           = tf.get_variable(
                "W",
                shape=[input_size, output_size],
                initializer=tf.contrib.layers.xavier_initializer())
b           = tf.get_variable(
                "b",
                shape=[output_size],
                initializer=tf.constant_initializer(2))
output      = tf.matmul(input, W, name="output")
output_bias = tf.nn.xw_plus_b(input, W, b, name="output_bias")

sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run([output,output_bias], feed_dict={input: [[2]*input_size]}))

我想有一個函數list_dependencies() ,如:

  • list_dependencies(output)返回['W', 'input']
  • list_dependencies(output_bias)返回['W', 'b', 'input']

以下是我用於此的實用程序(來自https://github.com/yaroslavvb/stuff/blob/master/linearize/linearize.py

# computation flows from parents to children

def parents(op):
  return set(input.op for input in op.inputs)

def children(op):
  return set(op for out in op.outputs for op in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""

  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}


def print_tf_graph(graph):
  """Prints tensorflow graph in dictionary form."""
  for node in graph:
    for child in graph[node]:
      print("%s -> %s" % (node.name, child.name))

這些功能適用於操作。 要獲得產生張量t的op,請使用t.op 要獲得op op生成的張量,請使用op.outputs

Yaroslav Bulatov的答案很棒,我只想添加一個使用Yaroslav的get_graph()children()方法的繪圖函數:

import matplotlib.pyplot as plt
import networkx as nx
def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

plot_graph(get_graph())

從問題中繪制示例1:

import matplotlib.pyplot as plt
import networkx as nx
import tensorflow as tf

def children(op):
  return set(op for out in op.outputs for op in out.consumers())

def get_graph():
  """Creates dictionary {node: {child1, child2, ..},..} for current
  TensorFlow graph. Result is compatible with networkx/toposort"""
  print('get_graph')
  ops = tf.get_default_graph().get_operations()
  return {op: children(op) for op in ops}

def plot_graph(G):
    '''Plot a DAG using NetworkX'''        
    def mapping(node):
        return node.name
    G = nx.DiGraph(G)
    nx.relabel_nodes(G, mapping, copy=False)
    nx.draw(G, cmap = plt.get_cmap('jet'), with_labels = True)
    plt.show()

a = tf.constant(1, name = 'a')
b = tf.constant(3, name = 'b')
c = tf.constant(9, name = 'c')
d = tf.add(a, b, name='d')
e = tf.add(d, c, name='e')

sess = tf.Session()
print(sess.run([d, e]))
plot_graph(get_graph())

輸出:

在此輸入圖像描述

從問題中繪制示例2:

在此輸入圖像描述

如果您使用Microsoft Windows,您可能會遇到此問題: Python錯誤(ValueError:_getfullpathname:嵌入的空字符) ,在這種情況下,您需要修補matplotlib,因為鏈接說明。

這些都是很好的答案,我將添加一個簡單的方法,以不易讀取的格式生成依賴項,但對於快速調試非常有用。

tf.get_default_graph().as_graph_def()

在圖表中生成操作的打印,如下所示的簡單字典。 每個OP都很容易通過名稱及其屬性和輸入來識別,從而允許您遵循依賴關系。

import tensorflow as tf

a = tf.placeholder(tf.float32, name='placeholder_1')
b = tf.placeholder(tf.float32, name='placeholder_2')
c = a + b

tf.get_default_graph().as_graph_def()

Out[14]: 
node {
  name: "placeholder_1"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "placeholder_2"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "shape"
    value {
      shape {
        unknown_rank: true
      }
    }
  }
}
node {
  name: "add"
  op: "Add"
  input: "placeholder_1"
  input: "placeholder_2"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
versions {
  producer: 27
}

在某些情況下,人們可能想要找到連接到“輸出”張量的所有“輸入”變量,例如圖的丟失。 為此目的,以下代碼剪切可能是有用的(受上面的代碼啟發):

def findVars(atensor):
    allinputs=atensor.op.inputs
    if len(allinputs)==0:
        if atensor.op.type == 'VariableV2' or atensor.op.type == 'Variable':
            return set([atensor.op])
    a=set()
    for t in allinputs:
        a=a | findVars(t)
    return a

這可以在調試中用於找出圖中的連接缺失的位置。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM