
How to solve tensorflow Conv2DBackpropFilter error?

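I have been trying to use tf-gpu without success. I installed CUDA from Anaconda (I'm not sure whether that is the problem or whether it's a problem in my code). The code runs fine when I don't use the GPU and haven't installed CUDA, but after installing tensorflow-gpu I get this error: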

tensorflow.python.framework.errors_impl.InvalidArgumentError:  Conv2DBackpropFilter: input depth must be evenly divisible by filter depth
     [[{{node gradient_tape/sequential/conv2d_10/Conv2D/Conv2DBackpropFilter4}}]] [Op:__inference_train_function_1390]
Function call stack:
train_function
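The message states a shape rule for Conv2D: the depth (channel count) of the input must be an exact multiple of the filter's input depth, and the quotient is the number of convolution groups (1 for an ordinary convolution). The sketch below restates that rule in plain Python, assuming TensorFlow's default NHWC input layout and HWIO filter layout; `conv2d_groups` is a hypothetical helper for illustration, not TensorFlow's own validation code.

```python
def conv2d_groups(input_shape, filter_shape):
    """Return the number of convolution groups implied by the two shapes.

    Assumes input_shape is NHWC (batch, height, width, channels) and
    filter_shape is HWIO (kh, kw, in_channels_per_group, out_channels),
    which are TensorFlow's defaults for Conv2D.
    """
    input_depth = input_shape[-1]
    filter_depth = filter_shape[2]
    if filter_depth <= 0 or input_depth % filter_depth != 0:
        raise ValueError(
            "input depth must be evenly divisible by filter depth: "
            f"{input_depth} vs {filter_depth}"
        )
    return input_depth // filter_depth


# First layer of the model in the question: RGB input, 3x3x3x32 kernel.
print(conv2d_groups((1, 200, 200, 3), (3, 3, 3, 32)))  # 1 (ordinary convolution)

# A mismatched pair fails the same rule that the traceback reports.
try:
    conv2d_groups((1, 200, 200, 3), (3, 3, 2, 32))
except ValueError as err:
    print(err)
```

A Keras Conv2D layer with the default `groups=1` builds its kernel from the incoming channel count, so for a model defined this way the rule holds by construction at every layer; that is consistent with the same code training without error on the CPU.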

My code:

import tensorflow as tf
import matplotlib.pyplot as plt
import os
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.optimizers import SGD,Adam
from tensorflow.keras.models import load_model

physical_devices = tf.config.list_physical_devices("GPU")
if physical_devices:
    # Allocate GPU memory on demand instead of reserving it all up front
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
train = ImageDataGenerator(rescale=1 / 255, rotation_range=20,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      vertical_flip=True,
      fill_mode='nearest',
      brightness_range=(0.1,0.9))
validation = ImageDataGenerator(rescale=1 / 255)
test = ImageDataGenerator(rescale=1 / 255)

train_dataset = train.flow_from_directory('/raw-img/training', target_size=(200,200), batch_size=1,
                                          class_mode='categorical')

validation_dataset = validation.flow_from_directory('/raw-img/validation', target_size=(200,200), batch_size=1,
                                               class_mode='categorical')
test_dataset = test.flow_from_directory('/raw-img/testing', target_size=(200,200), batch_size=1,
                                               class_mode='categorical')
tf.config.experimental.enable_mlir_graph_optimization()

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), input_shape=(200, 200, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Dropout(rate=.2),
    #
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Dropout(rate=.25),
    #
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Dropout(rate=.25),
    #
    tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Dropout(rate=.25),
    #
    tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Dropout(rate=.25),
    #
    tf.keras.layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.Conv2D(filters=256, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), padding='same'),
    tf.keras.layers.Dropout(rate=.2),
    #
    tf.keras.layers.Flatten(),
    #
    tf.keras.layers.Dense(units=1024),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.Dense(units=512),
    tf.keras.layers.LeakyReLU(alpha=0.3),
    tf.keras.layers.Dense(units=10, activation='softmax'),
])
model.summary()  # prints the table itself; it returns None, so no print() is needed
rlronp=tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',patience=3,verbose=1, factor=0.7)
es=tf.keras.callbacks.EarlyStopping(monitor="val_loss",patience=15,verbose=1,
                                    restore_best_weights=True)
model.compile(loss='categorical_crossentropy', optimizer=SGD(learning_rate=0.003), metrics=['accuracy'])
model_fit = model.fit(train_dataset, epochs=100, batch_size=32, validation_data=validation_dataset,steps_per_epoch=len(train_dataset),validation_steps=len(validation_dataset),callbacks=[rlronp,es])
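To check by hand that the architecture above is internally consistent, the shapes can be traced layer by layer. The sketch below is plain Python, not a Keras call; it assumes `padding='same'` everywhere (so each 3x3 convolution keeps height and width) and that a 2x2 `MaxPool2D` with `padding='same'` and its default stride of 2 produces ceil(n / 2). The `LAYERS` list is transcribed from the model definition, with the LeakyReLU and Dropout layers omitted because they do not change shapes.

```python
import math

# (kind, filters) transcribed from the Sequential model; LeakyReLU and
# Dropout are left out because they do not change the tensor shape.
LAYERS = [
    ("conv", 32), ("pool", None),
    ("conv", 32), ("conv", 32), ("pool", None),
    ("conv", 64), ("conv", 64), ("pool", None),
    ("conv", 128), ("conv", 128), ("pool", None),
    ("conv", 128), ("conv", 128), ("pool", None),
    ("conv", 256), ("conv", 256), ("pool", None),
]


def trace_shapes(height, width, depth, layers):
    """Yield (kind, input_depth, output_shape) for each layer in order."""
    for kind, filters in layers:
        if kind == "conv":
            # 'same' padding keeps H and W; the channel count becomes `filters`.
            yield kind, depth, (height, width, filters)
            depth = filters
        else:
            # 2x2 pooling, stride 2, 'same' padding: ceil(n / 2).
            height, width = math.ceil(height / 2), math.ceil(width / 2)
            yield kind, depth, (height, width, depth)


steps = list(trace_shapes(200, 200, 3, LAYERS))
for kind, in_depth, shape in steps:
    print(kind, in_depth, shape)

h, w, d = steps[-1][2]
print("Flatten ->", h * w * d)  # Flatten -> 4096
```

Under these assumptions the six pooling layers take 200x200 down to 4x4x256, so Flatten feeds 4096 features into the first Dense layer; `model.summary()` should report the same shapes if the assumptions hold.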

These are the packages in my conda environment:

_tflow_select             2.1.0                       gpu
abseil-cpp                20210324.2           hd77b12b_0
absl-py                   0.13.0           py39haa95532_0
aiohttp                   3.7.4            py39h2bbff1b_1
argon2-cffi               20.1.0           py39h2bbff1b_1
astor                     0.8.1            py39haa95532_0
astunparse                1.6.3                      py_0
async-timeout             3.0.1            py39haa95532_0
async_generator           1.10               pyhd3eb1b0_0
attrs                     21.2.0             pyhd3eb1b0_0
backcall                  0.2.0              pyhd3eb1b0_0
blas                      1.0                         mkl
bleach                    3.3.1              pyhd3eb1b0_0
blinker                   1.4              py39haa95532_0
brotlipy                  0.7.0           py39h2bbff1b_1003
ca-certificates           2021.7.5             haa95532_1
cached-property           1.5.2                      py_0
cachetools                4.2.2              pyhd3eb1b0_0
certifi                   2021.5.30        py39haa95532_0
cffi                      1.14.6           py39h2bbff1b_0
chardet                   3.0.4           py39haa95532_1003
click                     8.0.1              pyhd3eb1b0_0
colorama                  0.4.4              pyhd3eb1b0_0
coverage                  5.5              py39h2bbff1b_2
cryptography              3.4.7            py39h71e12ea_0
cudatoolkit               11.3.1               h59b6b97_2
cudnn                     8.2.1                cuda11.3_0
cycler                    0.10.0           py39haa95532_0
cython                    0.29.24          py39hd77b12b_0
decorator                 5.0.9              pyhd3eb1b0_0
defusedxml                0.7.1              pyhd3eb1b0_0
entrypoints               0.3              py39haa95532_0
flatbuffers               2.0.0                h6c2663c_0
freetype                  2.10.4               hd328e21_0
gast                      0.4.0                      py_0
giflib                    5.2.1                h62dcd97_0
google-auth               1.33.0             pyhd3eb1b0_0
google-auth-oauthlib      0.4.1                      py_2
google-pasta              0.2.0                      py_0
grpcio                    1.36.1           py39hc60d5dd_1
h5py                      3.2.1            py39h3de5c98_0
hdf5                      1.10.6               h7ebc959_0
icc_rt                    2019.0.0             h0cc432a_1
icu                       68.1                 h6c2663c_0
idna                      2.10               pyhd3eb1b0_0
importlib-metadata        3.10.0           py39haa95532_0
importlib_metadata        3.10.0               hd3eb1b0_0
intel-openmp              2021.3.0          haa95532_3372
ipykernel                 5.3.4            py39h7b7c402_0
ipython                   7.25.0           py39h832f523_1    conda-forge
ipython_genutils          0.2.0              pyhd3eb1b0_1
jedi                      0.18.0           py39haa95532_1
jinja2                    3.0.1              pyhd3eb1b0_0
jpeg                      9d                   h8ffe710_0    conda-forge
jsonschema                3.2.0                      py_2
jupyter_client            6.1.12             pyhd3eb1b0_0
jupyter_core              4.7.1            py39haa95532_0
jupyterlab_pygments       0.1.2                      py_0
keras-preprocessing       1.1.2              pyhd3eb1b0_0
kiwisolver                1.3.1            py39hd77b12b_0
krb5                      1.18.2               hc04afaa_0
libclang                  11.1.0          default_h5c34c98_1    conda-forge
libcurl                   7.71.1               h2a8f88b_1
libpng                    1.6.37               h2a8f88b_0
libprotobuf               3.14.0               h23ce68f_0
libsodium                 1.0.18               h62dcd97_0
libssh2                   1.9.0                h7a1dbc1_1
libtiff                   4.2.0                hd0e1b90_0
lz4-c                     1.9.3                h2bbff1b_0
m2w64-gcc-libgfortran     5.3.0                         6
m2w64-gcc-libs            5.3.0                         7
m2w64-gcc-libs-core       5.3.0                         7
m2w64-gmp                 6.1.0                         2
m2w64-libwinpthread-git   5.0.0.4634.697f757               2
markdown                  3.3.4            py39haa95532_0
markupsafe                2.0.1            py39h2bbff1b_0
matplotlib                3.3.4            py39haa95532_0
matplotlib-base           3.3.4            py39h49ac443_0
matplotlib-inline         0.1.2              pyhd8ed1ab_2    conda-forge
mistune                   0.8.4           py39h2bbff1b_1000
mkl                       2021.3.0           haa95532_524
mkl-service               2.4.0            py39h2bbff1b_0
mkl_fft                   1.3.0            py39h277e83a_2
mkl_random                1.2.2            py39hf11a4ad_0
msys2-conda-epoch         20160418                      1
multidict                 5.1.0            py39h2bbff1b_2
nbclient                  0.5.3              pyhd3eb1b0_0
nbconvert                 6.1.0            py39haa95532_0
nbformat                  5.1.3              pyhd3eb1b0_0
nest-asyncio              1.5.1              pyhd3eb1b0_0
notebook                  6.4.0            py39haa95532_0
numpy                     1.20.3           py39ha4e8547_0
numpy-base                1.20.3           py39hc2deb75_0
oauthlib                  3.1.1              pyhd3eb1b0_0
olefile                   0.46                       py_0
openssl                   1.1.1k               h2bbff1b_0
opt_einsum                3.3.0              pyhd3eb1b0_1
packaging                 21.0               pyhd3eb1b0_0
pandocfilters             1.4.3            py39haa95532_1
parso                     0.8.2              pyhd3eb1b0_0
pickleshare               0.7.5           pyhd3eb1b0_1003
pillow                    8.3.1            py39h4fa10fc_0
pip                       21.1.3           py39haa95532_0
powershell_shortcut       0.0.1                         3
prometheus_client         0.11.0             pyhd3eb1b0_0
prompt-toolkit            3.0.17             pyh06a4308_0
protobuf                  3.14.0           py39hd77b12b_1
pyasn1                    0.4.8                      py_0
pyasn1-modules            0.2.8                      py_0
pycparser                 2.20                       py_2
pygments                  2.9.0              pyhd3eb1b0_0
pyjwt                     2.1.0            py39haa95532_0
pyopenssl                 20.0.1             pyhd3eb1b0_1
pyparsing                 2.4.7              pyhd3eb1b0_0
pyqt                      5.12.3           py39hcbf5309_7    conda-forge
pyqt-impl                 5.12.3           py39h415ef7b_7    conda-forge
pyqt5-sip                 4.19.18          py39h415ef7b_7    conda-forge
pyqtchart                 5.12             py39h415ef7b_7    conda-forge
pyqtwebengine             5.12.1           py39h415ef7b_7    conda-forge
pyreadline                2.1              py39haa95532_1
pyrsistent                0.18.0           py39h2bbff1b_0
pysocks                   1.7.1            py39haa95532_0
python                    3.9.5                h6244533_3
python-dateutil           2.8.2              pyhd3eb1b0_0
python-flatbuffers        1.12               pyhd3eb1b0_0
python_abi                3.9                      2_cp39    conda-forge
pywin32                   228              py39he774522_0
pywinpty                  0.5.7            py39haa95532_0
pyzmq                     20.0.0           py39hd77b12b_1
qt                        5.12.9               h5909a2a_4    conda-forge
requests                  2.25.1             pyhd3eb1b0_0
requests-oauthlib         1.3.0                      py_0
rsa                       4.7.2              pyhd3eb1b0_1
scipy                     1.6.2            py39h66253e8_1
send2trash                1.5.0              pyhd3eb1b0_1
setuptools                52.0.0           py39haa95532_0
six                       1.16.0             pyhd3eb1b0_0
snappy                    1.1.8                h33f27b4_0
sqlite                    3.36.0               h2bbff1b_0
tensorboard               2.5.0                      py_0
tensorboard-plugin-wit    1.6.0                      py_0
tensorflow                2.5.0           gpu_py39h7dc34a2_0
tensorflow-base           2.5.0           gpu_py39hb3da07e_0
tensorflow-estimator      2.5.0              pyh7b7c402_0
tensorflow-gpu            2.5.0                h17022bd_0
termcolor                 1.1.0            py39haa95532_1
terminado                 0.9.4            py39haa95532_0
testpath                  0.5.0              pyhd3eb1b0_0
tk                        8.6.10               he774522_0
tornado                   6.1              py39h2bbff1b_0
traitlets                 5.0.5              pyhd3eb1b0_0
typing-extensions         3.10.0.0             hd3eb1b0_0
typing_extensions         3.10.0.0           pyh06a4308_0
tzdata                    2021a                h52ac0ba_0
urllib3                   1.26.6             pyhd3eb1b0_1
vc                        14.2                 h21ff451_1
vs2015_runtime            14.27.29016          h5e58377_2
wcwidth                   0.2.5                      py_0
webencodings              0.5.1            py39haa95532_1
werkzeug                  1.0.1              pyhd3eb1b0_0
wheel                     0.35.1             pyhd3eb1b0_0
win_inet_pton             1.1.0            py39haa95532_0
wincertstore              0.2              py39h2bbff1b_0
winpty                    0.4.3                         4
wrapt                     1.12.1           py39h196d8e1_1
xz                        5.2.5                h62dcd97_0
yarl                      1.6.3            py39h2bbff1b_0
zeromq                    4.3.3                ha925a31_3
zipp                      3.5.0              pyhd3eb1b0_0
zlib                      1.2.11               h62dcd97_4
zstd                      1.4.9                h19a0ad4_0
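To pull the GPU-related entries (TensorFlow, CUDA toolkit, cuDNN, Python) out of a listing like this, for example when reporting versions in a bug report, a small parser is enough. This is a sketch assuming the whitespace-separated `name version build [channel]` layout shown above; `gpu_relevant` and the `GPU_PACKAGES` selection are illustrative choices, not part of conda.

```python
# Package names whose versions matter most for a TensorFlow GPU setup
# (an illustrative selection).
GPU_PACKAGES = {"tensorflow", "tensorflow-base", "tensorflow-gpu",
                "cudatoolkit", "cudnn", "python"}


def gpu_relevant(conda_list_text, names=GPU_PACKAGES):
    """Map package name -> (version, build) for the selected packages.

    Assumes each non-comment line looks like `name version build [channel]`,
    as in the output of `conda list`.
    """
    found = {}
    for line in conda_list_text.splitlines():
        parts = line.split()
        if len(parts) < 3 or parts[0].startswith("#"):
            continue  # skip blank lines, headers and comments
        name, version, build = parts[:3]
        if name in names:
            found[name] = (version, build)
    return found


listing = """
cudatoolkit               11.3.1               h59b6b97_2
cudnn                     8.2.1                cuda11.3_0
python                    3.9.5                h6244533_3
tensorflow                2.5.0           gpu_py39h7dc34a2_0
tensorflow-gpu            2.5.0                h17022bd_0
zlib                      1.2.11               h62dcd97_4
"""

for name, (version, build) in sorted(gpu_relevant(listing).items()):
    print(f"{name:<16}{version:<10}{build}")
```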

From the TensorFlow documentation ( https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_mlir_graph_optimization?hl=ru ):

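"NOTE: MLIR-Based TensorFlow Compiler is under active development and has missing features, please refrain from using. This API exists for development and testing only."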

I'm not sure what the problem is, but you should try removing this line:

tf.config.experimental.enable_mlir_graph_optimization()

and see how it runs.


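Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you republish, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.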

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM