![](/img/trans.png)
[英][tensorflow]How do I wrap a sequence of code in to by subclassing the model class ( tf.keras.Model)
[英]How do I plot a Keras/Tensorflow subclassing API model?
我制作了一个 model,它使用 Keras 子类 API 正确运行。 model.summary()
也可以正常工作。 当尝试使用tf.keras.utils.plot_model()
来可视化我的模型的架构时,它只会 output 这个图像:
这几乎感觉像是 Keras 开发团队的玩笑。 这是完整的架构:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.datasets import load_diabetes
import tensorflow as tf
tf.keras.backend.set_floatx('float64')
from tensorflow.keras.layers import Dense, GaussianDropout, GRU, Concatenate, Reshape
from tensorflow.keras.models import Model
X, y = load_diabetes(return_X_y=True)
data = tf.data.Dataset.from_tensor_slices((X, y)).\
shuffle(len(X)).\
map(lambda x, y: (tf.divide(x, tf.reduce_max(x)), y))
training = data.take(400).batch(8)
testing = data.skip(400).map(lambda x, y: (tf.expand_dims(x, 0), y))
class NeuralNetwork(Model):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.dense1 = Dense(16, input_shape=(10,), activation='relu', name='Dense1')
self.dense2 = Dense(32, activation='relu', name='Dense2')
self.resha1 = Reshape((1, 32))
self.gru1 = GRU(16, activation='tanh', recurrent_dropout=1e-1)
self.dense3 = Dense(64, activation='relu', name='Dense3')
self.gauss1 = GaussianDropout(5e-1)
self.conca1 = Concatenate()
self.dense4 = Dense(128, activation='relu', name='Dense4')
self.dense5 = Dense(1, name='Dense5')
def call(self, x, *args, **kwargs):
x = self.dense1(x)
x = self.dense2(x)
a = self.resha1(x)
a = self.gru1(a)
b = self.dense3(x)
b = self.gauss1(b)
x = self.conca1([a, b])
x = self.dense4(x)
x = self.dense5(x)
return x
skynet = NeuralNetwork()
skynet.build(input_shape=(None, 10))
skynet.summary()
model = tf.keras.utils.plot_model(model=skynet,
show_shapes=True, to_file='/home/nicolas/Desktop/model.png')
我找到了 plot 和 model 子类 API 的一些解决方法。 出于显而易见的原因,子类API 不支持像model.summary()
这样的顺序或功能API 和使用plot_model
的漂亮可视化。 在这里,我将演示两者。
class my_model(Model):
def __init__(self, dim):
super(my_model, self).__init__()
self.Base = VGG16(input_shape=(dim), include_top = False, weights = 'imagenet')
self.GAP = L.GlobalAveragePooling2D()
self.BAT = L.BatchNormalization()
self.DROP = L.Dropout(rate=0.1)
self.DENS = L.Dense(256, activation='relu', name = 'dense_A')
self.OUT = L.Dense(1, activation='sigmoid')
def call(self, inputs):
x = self.Base(inputs)
g = self.GAP(x)
b = self.BAT(g)
d = self.DROP(b)
d = self.DENS(d)
return self.OUT(d)
# AFAIK: The most convenient method to print model.summary()
# similar to the sequential or functional API like.
def build_graph(self):
x = Input(shape=(dim))
return Model(inputs=[x], outputs=self.call(x))
dim = (124,124,3)
model = my_model((dim))
model.build((None, *dim))
model.build_graph().summary()
它将产生如下:
Layer (type) Output Shape Param #
=================================================================
input_67 (InputLayer) [(None, 124, 124, 3)] 0
_________________________________________________________________
vgg16 (Functional) (None, 3, 3, 512) 14714688
_________________________________________________________________
global_average_pooling2d_32 (None, 512) 0
_________________________________________________________________
batch_normalization_7 (Batch (None, 512) 2048
_________________________________________________________________
dropout_5 (Dropout) (None, 512) 0
_________________________________________________________________
dense_A (Dense) (None, 256) 402192
_________________________________________________________________
dense_7 (Dense) (None, 1) 785
=================================================================
Total params: 14,848,321
Trainable params: 14,847,297
Non-trainable params: 1,024
现在通过使用build_graph
function,我们可以简单地 plot 整个架构。
# Just showing all possible argument for newcomer.
tf.keras.utils.plot_model(
model.build_graph(), # here is the trick (for now)
to_file='model.png', dpi=96, # saving
show_shapes=True, show_layer_names=True, # show shapes and layer name
expand_nested=False # will show nested block
)
它将产生如下:-)
更新(2021 年 1 月 4 日):这似乎是可能的; 见@M.Innat 的回答。
它无法完成,因为基本上 model 子类化,因为它在 TensorFlow 中实现,与使用功能/顺序 API 网络创建的模型相比,其特性和功能受到限制(称为图形学) 如果您检查plot_model
源代码,您会在model_to_dot
function(由plot_model
调用)中看到以下检查:
if not model._is_graph_network:
node = pydot.Node(str(id(model)), label=model.name)
dot.add_node(node)
return dot
正如我所提到的,子分类模型不是图形网络,因此只会为这些模型绘制包含 model 名称的节点(即您观察到的相同内容)。
这已经在Github 问题中进行了讨论,TensorFlow 的开发人员之一通过给出以下论点证实了这种行为:
@omalleyt12 评论道:
是的,一般来说,我们不能假设任何关于子类 Model 的结构。 如果您的 Model 可以看作是层块并且您希望这样可视化它,我们建议您查看功能 API
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.