How do I plot a Keras/TensorFlow subclassing API model?
I made a model that runs correctly using the Keras Subclassing API. model.summary() also works correctly. When I try to use tf.keras.utils.plot_model() to visualize my model's architecture, it only outputs an image of a single node labeled with the model name.
This almost feels like a joke from the Keras development team. This is the full architecture:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model

X, y = load_diabetes(return_X_y=True)

data = tf.data.Dataset.from_tensor_slices((X, y)).\
    shuffle(len(X)).\
    map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))

training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))

class NeuralNetwork(Model):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
        self.dense2 = Dense(32, activation='relu', name='Dense2')
        self.resha1 = Reshape((1, 32))
        self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
        self.dense3 = Dense(64, activation='relu', name='Dense3')
        self.gauss1 = GaussianDropout(5e-1)
        self.conca1 = Concatenate()
        self.dense4 = Dense(128, activation='relu', name='Dense4')
        self.dense5 = Dense(1, name='Dense5')

    def call(self, x, *args, **kwargs):
        x = self.dense1(x)
        x = self.dense2(x)
        a = self.resha1(x)
        a = self.gru1(a)
        b = self.dense3(x)
        b = self.gauss1(b)
        x = self.conca1([a, b])
        x = self.dense4(x)
        x = self.dense5(x)
        return x

skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()

model = tf.keras.utils.plot_model(model=skynet,
         show_shapes=True, to_file='/home/nicolas/Desktop/model.png')
I've found a workaround for plotting with the model subclassing API. For obvious reasons, a subclassed model does not get the Sequential/Functional-style model.summary() output or the nice visualization from plot_model. Here, I will demonstrate both.
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras import layers as L
from tensorflow.keras.applications import VGG16

class my_model(Model):
    def __init__(self, dim):
        super(my_model, self).__init__()
        self.dim = dim  # keep the input shape so build_graph() does not rely on a global
        self.Base = VGG16(input_shape=dim, include_top=False, weights='imagenet')
        self.GAP = L.GlobalAveragePooling2D()
        self.BAT = L.BatchNormalization()
        self.DROP = L.Dropout(rate=0.1)
        self.DENS = L.Dense(256, activation='relu', name='dense_A')
        self.OUT = L.Dense(1, activation='sigmoid')

    def call(self, inputs):
        x = self.Base(inputs)
        g = self.GAP(x)
        b = self.BAT(g)
        d = self.DROP(b)
        d = self.DENS(d)
        return self.OUT(d)

    # AFAIK: the most convenient way to get a model.summary()
    # similar to the one from the Sequential or Functional API.
    def build_graph(self):
        x = Input(shape=self.dim)
        # Wrap call() in a Functional model so Keras can trace the layer graph.
        return Model(inputs=[x], outputs=self.call(x))

dim = (124, 124, 3)
model = my_model(dim)
model.build((None, *dim))
model.build_graph().summary()
It produces the following:
Layer (type)                 Output Shape              Param #
=================================================================
input_67 (InputLayer)        [(None, 124, 124, 3)]     0
_________________________________________________________________
vgg16 (Functional)           (None, 3, 3, 512)         14714688
_________________________________________________________________
global_average_pooling2d_32  (None, 512)               0
_________________________________________________________________
batch_normalization_7 (Batch (None, 512)               2048
_________________________________________________________________
dropout_5 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_A (Dense)              (None, 256)               131328
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 257
=================================================================
Total params: 14,848,321
Trainable params: 14,847,297
Non-trainable params: 1,024
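The same build_graph idea can be tried on a small model that needs no pretrained weights. The following is only a minimal sketch: the class name TinyBranchNet, its layer names, and the n_features argument are hypothetical and chosen for illustration. It assumes that wrapping call() in a Functional Model exposes the individual layers, which is what summary() and plot_model rely on.

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Concatenate, Dense


class TinyBranchNet(Model):  # hypothetical model, for illustration only
    def __init__(self, n_features):
        super().__init__()
        self.n_features = n_features
        self.shared = Dense(8, activation="relu", name="shared")
        self.branch_a = Dense(4, activation="relu", name="branch_a")
        self.branch_b = Dense(4, activation="tanh", name="branch_b")
        self.merge = Concatenate(name="merge")
        self.head = Dense(1, name="head")

    def call(self, inputs):
        x = self.shared(inputs)
        merged = self.merge([self.branch_a(x), self.branch_b(x)])
        return self.head(merged)

    def build_graph(self):
        # Run call() on a symbolic input and wrap the result in a
        # Functional model, so the layer graph can be traced.
        x = Input(shape=(self.n_features,))
        return Model(inputs=[x], outputs=self.call(x))


net = TinyBranchNet(n_features=10)
graph = net.build_graph()

# The Functional wrapper lists every layer used in call().
layer_names = [layer.name for layer in graph.layers]
print(layer_names)
```

Passing graph (rather than net) to summary() or plot_model should then show both branches and the concatenation, assuming pydot and Graphviz are installed for plotting.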
Now, by using the build_graph function, we can simply plot the whole architecture.
# Showing the commonly used arguments for newcomers.
tf.keras.utils.plot_model(
    model.build_graph(),                      # here is the trick (for now)
    to_file='model.png', dpi=96,              # saving
    show_shapes=True, show_layer_names=True,  # show shapes and layer names
    expand_nested=False                       # set to True to expand nested models (e.g. the VGG16 base)
)
It produces a plot of the whole architecture.
Update (04-Jan-2021): It seems this is possible; see @M.Innat's answer.
It could not be done because model subclassing, as it is implemented in TensorFlow, is limited in features and capabilities compared to models created using the Functional/Sequential API (which are called graph networks in TF terminology). If you check the plot_model source code, you will see the following check in the model_to_dot function (which is called by plot_model):
if not model._is_graph_network:
    node = pydot.Node(str(id(model)), label=model.name)
    dot.add_node(node)
    return dot
As I mentioned, subclassed models are not graph networks, and therefore only a single node containing the model name is plotted for them (i.e. the same thing you observed).
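To see the distinction directly, one can inspect the flag on both kinds of model. This is only a rough sketch: _is_graph_network is a private attribute taken from the source excerpt above, and it may not exist in every Keras release, so the code falls back to None when it is missing. The Subclassed class is a hypothetical example.

```python
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense


class Subclassed(Model):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.dense = Dense(1)

    def call(self, inputs):
        return self.dense(inputs)


inputs = Input(shape=(4,))
functional = Model(inputs=inputs, outputs=Dense(1)(inputs))
subclassed = Subclassed()

# Private attribute: assumed to exist only in some Keras versions,
# so default to None instead of raising AttributeError.
sub_flag = getattr(subclassed, "_is_graph_network", None)
fun_flag = getattr(functional, "_is_graph_network", None)

print("subclassed:", sub_flag)
print("functional:", fun_flag)
```

Where the attribute is present, the subclassed model is expected to report False and the Functional model True, which is why plot_model draws only one node for the former.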
This has already been discussed in a GitHub issue, and one of the TensorFlow developers confirmed this behavior with the following argument:
@omalleyt12 commented:
Yes in general we can't assume anything about the structure of a subclassed Model. If your Model can be thought of as blocks of Layers and you wish to visualize it like that, we recommend you view the Functional API