I am using TensorFlow Fold to write a model, and TensorFlow Fold often produces zero-size batches (see the issue on GitHub). This crashes certain TensorFlow operations with an error like:

F tensorflow/stream_executor/cuda/cuda_dnn.cc:466] could not set cudnn tensor descriptor: CUDNN_STATUS_BAD_PARAM

However, it is possible to work around this by registering a custom gradient operation, as described here and shown below.
import tensorflow as tf

# TensorFlow Fold can generate a zero-size batch for a conv layer,
# which will crash cuDNN on the backward pass. So use this wrapper
# for any convolution in modules to avoid the crash.
def _conv_safe(inputs, filters, kernel_size, strides, activation):
    g = tf.get_default_graph()
    with g.gradient_override_map({'Conv2D': 'Conv2D_handle_empty_batch'}):
        return tf.layers.conv2d(inputs=inputs, filters=filters,
                                kernel_size=kernel_size, strides=strides,
                                activation=activation)
@tf.RegisterGradient('Conv2D_handle_empty_batch')
def _Conv2DGrad(op, grad):
    with tf.device('/cpu:0'):
        return [tf.nn.conv2d_backprop_input(tf.shape(op.inputs[0]),
                                            op.inputs[1], grad,
                                            op.get_attr('strides'),
                                            op.get_attr('padding'),
                                            op.get_attr('use_cudnn_on_gpu'),
                                            op.get_attr('data_format')),
                tf.nn.conv2d_backprop_filter(op.inputs[0],
                                             tf.shape(op.inputs[1]), grad,
                                             op.get_attr('strides'),
                                             op.get_attr('padding'),
                                             op.get_attr('use_cudnn_on_gpu'),
                                             op.get_attr('data_format'))]
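As an aside on why this override works: tf.RegisterGradient adds an entry to a global registry keyed by op type, and gradient_override_map rewrites the lookup key within its scope, so the graph builder picks up the safe gradient instead of the default one. A minimal plain-Python sketch of that registry pattern (all names below are illustrative, not TensorFlow internals):

```python
# Sketch of the registry-plus-override pattern behind
# tf.RegisterGradient / gradient_override_map. Illustrative only.
_gradient_registry = {}

def register_gradient(op_type):
    """Decorator that records a gradient function under an op-type name."""
    def decorator(fn):
        _gradient_registry[op_type] = fn
        return fn
    return decorator

def lookup_gradient(op_type, override_map=None):
    """Resolve an op type to its gradient fn, honoring an override map."""
    key = (override_map or {}).get(op_type, op_type)
    return _gradient_registry[key]

@register_gradient('Conv2D')
def _default_grad(op, grad):
    return 'default'

@register_gradient('Conv2D_handle_empty_batch')
def _safe_grad(op, grad):
    return 'safe'
```

With no override map, 'Conv2D' resolves to the default gradient; inside the override scope it resolves to the safe one.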
I am now wondering how I can do a similar thing to avoid this crash when using the tf.layers.max_pooling2d operation, or any other form of max pooling. As the tf.layers.conv2d example shows, we can get around it by implementing a custom gradient that handles the zero batch size. How can I do the same for tf.layers.max_pooling2d?
Note: I am using TensorFlow 1.0, since that is what TensorFlow Fold supports.
Thanks
I think we can do it like this:
from tensorflow.python.ops import gen_nn_ops

def max_pooling_zero_batch(inputs, pool_size, strides, name):
    g = tf.get_default_graph()
    with g.gradient_override_map({'MaxPool': 'MaxPool_handle_empty_batch'}):
        return tf.layers.max_pooling2d(inputs=inputs, pool_size=pool_size,
                                       strides=strides, name=name)
@tf.RegisterGradient("MaxPool_handle_empty_batch")
def _MaxPoolGrad(op, grad):
    with tf.device('/cpu:0'):
        return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad,
                                         op.get_attr("ksize"),
                                         op.get_attr("strides"),
                                         padding=op.get_attr("padding"),
                                         data_format=op.get_attr("data_format"))
It seems to work with 0 batch size.
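That matches expectations: the pooling math itself is perfectly well-defined for a zero-size batch, and the crash comes only from cuDNN's tensor-descriptor check, which is why falling back to CPU works. A small NumPy sketch (not part of the original code) showing that 2x2 max pooling handles an empty batch without issue:

```python
import numpy as np

def max_pool_2x2(x):
    """2x2, stride-2 max pooling over NHWC input with even H and W.

    Works even when the batch dimension is 0: the reduction axes have
    size 2, so no reduction over an empty axis ever occurs.
    """
    b, h, w, c = x.shape
    return x.reshape(b, h // 2, 2, w // 2, 2, c).max(axis=(2, 4))

# A zero-size batch simply yields a zero-size pooled output.
empty = np.zeros((0, 4, 4, 1), dtype=np.float32)
pooled_empty = max_pool_2x2(empty)

# And a normal batch is pooled as usual.
x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
pooled = max_pool_2x2(x)
```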