
How do I plot a Keras/Tensorflow subclassing API model?

I made a model that runs correctly using the Keras Subclassing API. model.summary() also works correctly. But when I use tf.keras.utils.plot_model() to visualize the model's architecture, it only outputs this image:

[image]

This almost feels like a joke from the Keras development team. This is the full architecture:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model

X, y = load_diabetes(return_X_y=True)

data = tf.data.Dataset.from_tensor_slices((X, y)).\
    shuffle(len(X)).\
    map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))

training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))

class NeuralNetwork(Model):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x, *args, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.resha1(x)
        a = self.gru1(a)
        b = self.dense3(x)
        b = self.gauss1(b)
        x = self.conca1([a, b])
        x = self.dense4(x)
        x = self.dense5(x)
        return x


skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()

model = tf.keras.utils.plot_model(model=skynet,
         show_shapes=True, to_file='/home/nicolas/Desktop/model.png')

I've found a workaround for plotting a model built with the subclassing API. Unlike the Sequential or Functional API, the subclassing API does not support a layer-by-layer model.summary() or a useful visualization with plot_model out of the box. Here, I will demonstrate both.

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras import layers as L
from tensorflow.keras.applications import VGG16

class my_model(Model):
    def __init__(self, dim):
        super(my_model, self).__init__()
        self.dim   = dim  # keep the input shape so build_graph() does not rely on a global
        self.Base  = VGG16(input_shape=dim, include_top=False, weights='imagenet')
        self.GAP   = L.GlobalAveragePooling2D()
        self.BAT   = L.BatchNormalization()
        self.DROP  = L.Dropout(rate=0.1)
        self.DENS  = L.Dense(256, activation='relu', name='dense_A')
        self.OUT   = L.Dense(1, activation='sigmoid')

    def call(self, inputs):
        x  = self.Base(inputs)
        g  = self.GAP(x)
        b  = self.BAT(g)
        d  = self.DROP(b)
        d  = self.DENS(d)
        return self.OUT(d)

    # AFAIK: the most convenient way to get a model.summary()
    # like the one the Sequential or Functional API prints.
    def build_graph(self):
        # Wrap call() in a Functional model built on a symbolic input.
        x = Input(shape=self.dim)
        return Model(inputs=[x], outputs=self.call(x))

dim = (124, 124, 3)
model = my_model(dim)
model.build((None, *dim))
model.build_graph().summary()

It produces the following summary:

Layer (type)                 Output Shape              Param #   
=================================================================
input_67 (InputLayer)        [(None, 124, 124, 3)]     0         
_________________________________________________________________
vgg16 (Functional)           (None, 3, 3, 512)         14714688  
_________________________________________________________________
global_average_pooling2d_32  (None, 512)               0         
_________________________________________________________________
batch_normalization_7 (Batch (None, 512)               2048      
_________________________________________________________________
dropout_5 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_A (Dense)              (None, 256)               131328    
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 257       
=================================================================
Total params: 14,848,321
Trainable params: 14,847,297
Non-trainable params: 1,024

Now, by using the build_graph function, we can simply plot the whole architecture.

# Showing the commonly used arguments for newcomers.
tf.keras.utils.plot_model(
    model.build_graph(),                      # here is the trick (for now)
    to_file='model.png', dpi=96,              # where to save, and at what resolution
    show_shapes=True, show_layer_names=True,  # show shapes and layer names
    expand_nested=False                       # set to True to expand nested models (e.g. vgg16)
)

It produces the following plot:

[image]
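The same build_graph trick also works for a branched model, like the one in the question. The following is a minimal sketch, not the original poster's code: the class name TwoBranch, its layer names, and the n_features argument are made up for illustration, and the plotting call is guarded because plot_model needs pydot and Graphviz to be installed.

```python
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Concatenate, Dense


class TwoBranch(Model):
    """Hypothetical subclassed model with two parallel branches."""

    def __init__(self, n_features):
        super().__init__()
        self.n_features = n_features
        self.dense1 = Dense(16, activation='relu', name='Dense1')
        self.branch_a = Dense(8, activation='relu', name='BranchA')
        self.branch_b = Dense(8, activation='relu', name='BranchB')
        self.concat = Concatenate(name='Concat')
        self.out = Dense(1, name='Out')

    def call(self, x):
        x = self.dense1(x)
        a = self.branch_a(x)
        b = self.branch_b(x)
        return self.out(self.concat([a, b]))

    def build_graph(self):
        # Trace call() on a symbolic input and wrap the result
        # in a Functional model that summary()/plot_model understand.
        x = Input(shape=(self.n_features,))
        return Model(inputs=[x], outputs=self.call(x))


net = TwoBranch(n_features=10)
graph = net.build_graph()
layer_names = [layer.name for layer in graph.layers]
print(layer_names)

try:
    # Requires pydot and Graphviz; skipped if they are missing.
    tf.keras.utils.plot_model(graph, to_file='two_branch.png', show_shapes=True)
except ImportError as err:
    print('plot_model unavailable:', err)
```

The Functional wrapper shares its layers (and weights) with the subclassed model, so it can be built purely for inspection while training still goes through the subclassed model.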

Update (04-Jan-2021): It seems this is possible; see @M.Innat's answer.


It could not be done because model subclassing, as it is implemented in TensorFlow, is limited in features and capabilities compared to models created using the Functional/Sequential API (which are called graph networks in TF terminology). If you check the plot_model source code, you will see the following check in the model_to_dot function (which is called by plot_model):

if not model._is_graph_network:
  node = pydot.Node(str(id(model)), label=model.name)
  dot.add_node(node)
  return dot

As I mentioned, subclassed models are not graph networks, so only a single node containing the model name is plotted for them (i.e. the same thing you observed).

This has already been discussed in a GitHub issue, and one of the developers of TensorFlow confirmed this behavior with the following argument:

@omalleyt12 commented:

Yes in general we can't assume anything about the structure of a subclassed Model. If your Model can be thought of as blocks of Layers and you wish to visualize it like that, we recommend you view the Functional API
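To illustrate that recommendation, here is a rough sketch of the architecture from the question rewritten with the Functional API, which plot_model can render layer by layer. The function name build_functional_skynet, the n_features parameter, the input name, and the output file name are assumptions made for this example; the layers and their hyperparameters are copied from the question.

```python
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (
    Concatenate, Dense, GaussianDropout, GRU, Reshape)


def build_functional_skynet(n_features=10):
    """Functional-API version of the question's NeuralNetwork (illustrative)."""
    inputs = Input(shape=(n_features,), name='Features')
    x = Dense(16, activation='relu', name='Dense1')(inputs)
    x = Dense(32, activation='relu', name='Dense2')(x)

    # Branch a: treat the 32 features as a one-step sequence for the GRU.
    a = Reshape((1, 32))(x)
    a = GRU(16, activation='tanh', recurrent_dropout=1e-1)(a)

    # Branch b: dense layer followed by Gaussian dropout.
    b = Dense(64, activation='relu', name='Dense3')(x)
    b = GaussianDropout(5e-1)(b)

    x = Concatenate()([a, b])
    x = Dense(128, activation='relu', name='Dense4')(x)
    outputs = Dense(1, name='Dense5')(x)
    return Model(inputs=inputs, outputs=outputs, name='functional_skynet')


functional_skynet = build_functional_skynet()
functional_skynet.summary()

try:
    # Requires pydot and Graphviz; skipped if they are missing.
    tf.keras.utils.plot_model(
        functional_skynet, to_file='functional_skynet.png', show_shapes=True)
except ImportError as err:
    print('plot_model unavailable:', err)
```

Because the connectivity is declared up front, both branches and the Concatenate node should appear in the plot without any extra wrapper.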

Another workaround: convert the model from SavedModel format to ONNX using tf2onnx, then use Netron to view the model architecture.

Here is part of the model in Netron:

[image]
