簡體   English   中英

InvalidArgumentError: Conv2DCustomBackpropFilterOp 僅支持 NHWC - 修剪神經網絡

[英]InvalidArgumentError: Conv2DCustomBackpropFilterOp only supports NHWC - pruning neural network

我在使用庫 tensorflow_model_optimization 時遇到問題。 我正在開發一個代碼來修剪一個已經訓練好的神經網絡。 我從 h5 文件中導入了權重,因此我使用 tensorflor_model_optimization 來修剪我的神經網絡。 調用 fit 方法時出現此錯誤:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-26-ce9759e4dd53> in <module>
----> 1 model_for_pruning.fit_generator(base_treinamento, steps_per_epoch = 6000 /64, epochs = 5, validation_data = base_teste, validation_steps = 30, callbacks=callbacks)

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1859         use_multiprocessing=use_multiprocessing,
   1860         shuffle=shuffle,
-> 1861         initial_epoch=initial_epoch)
   1862 
   1863   def evaluate_generator(self,

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    853       # In this case we have created variables on the first call, so we run the
    854       # defunned version which is guaranteed to never create variables.
--> 855       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    856     elif self._stateful_fn is not None:
    857       # Release the lock early so that multiple threads can perform the call

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2941        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   2942     return graph_function._call_flat(
-> 2943         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2944 
   2945   @property

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1917       # No tape is watching; skip to running the function.
   1918       return self._build_call_outputs(self._inference_function.call(
-> 1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(
   1921         args,

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    558               inputs=args,
    559               attrs=attrs,
--> 560               ctx=ctx)
    561         else:
    562           outputs = execute.execute_with_cancellation(

~\anaconda3\envs\supernova\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Conv2DCustomBackpropFilterOp only supports NHWC.
     [[node gradient_tape/sequential_3/prune_low_magnitude_conv2d_18/Conv2D/Conv2DBackpropFilter (defined at <ipython-input-24-fc85f8818d30>:1) ]] [Op:__inference_train_function_10084]

Errors may have originated from an input operation.
Input Source operations connected to node gradient_tape/sequential_3/prune_low_magnitude_conv2d_18/Conv2D/Conv2DBackpropFilter:
 sequential_3/prune_low_magnitude_activation_18/Relu (defined at C:\Users\Pichau\anaconda3\envs\supernova\lib\site-packages\tensorflow_model_optimization\python\core\sparsity\keras\pruning_wrapper.py:270)

Function call stack:
train_function
enter code here

我的代碼:

from keras.models import model_from_json
from keras.preprocessing.image import ImageDataGenerator
import numpy as np

arquivo = open('model.json', 'r')
estrutura_rede = arquivo.read()
arquivo.close()
model = model_from_json(estrutura_rede)
model.load_weights('model.h5')

gerador_treinamento = ImageDataGenerator(rescale=None)
base_treinamento = gerador_treinamento.flow_from_directory('data/train', target_size = (51,51), batch_size = 64, class_mode = 'binary')

gerador_teste = ImageDataGenerator(rescale=None)
base_teste = gerador_teste.flow_from_directory('data/test', target_size = (51,51), batch_size = 64, class_mode = 'binary')

import tempfile
import tensorflow as tf
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

batch_size = 128
epochs = 2
validation_split = 0.1

num_images = int(len(base_treinamento) * (1 - validation_split))
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

model_for_pruning.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

model_for_pruning.summary()

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.fit_generator(base_treinamento, steps_per_epoch = 6000 /64, epochs = 5, validation_data = base_teste, validation_steps = 30, callbacks=callbacks)

python:3.6.12

tensorflow:2.2.0

張量流模型優化:0.5.0

有人能幫我嗎?

這似乎不是修剪問題,而是 model 的數據格式問題。 看起來一個Conv2D層應該使用NHWC格式( data_format="channels_last" )。

你能分享 model 的代碼嗎?

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM