
How to handle a 0 batch size in tf.layers.max_pooling2d in TensorFlow?

I am using TensorFlow Fold to write a model, and TensorFlow Fold often produces zero-size batches (see the related issue reported on GitHub). This breaks certain TensorFlow operations, which fail with an error like: F tensorflow/stream_executor/cuda/cuda_dnn.cc:466] could not set cudnn tensor descriptor: CUDNN_STATUS_BAD_PARAM

However, for convolutions it is possible to work around this by registering a custom gradient operation, as shown below.

import tensorflow as tf

# TensorFlow Fold can generate a zero-size batch for a conv layer,
# which will crash cuDNN on the backward pass. Use this wrapper
# for any convolution inside Fold modules to avoid the crash.
def _conv_safe(inputs, filters, kernel_size, strides, activation):
    g = tf.get_default_graph()
    with g.gradient_override_map({'Conv2D': 'Conv2D_handle_empty_batch'}):
        return tf.layers.conv2d(inputs=inputs, filters=filters,
                                kernel_size=kernel_size, strides=strides,
                                activation=activation)

@tf.RegisterGradient('Conv2D_handle_empty_batch')
def _Conv2DGrad(op, grad):
    # Compute the conv gradients on the CPU so a zero-size batch never reaches cuDNN.
    with tf.device('/cpu:0'):
        return [tf.nn.conv2d_backprop_input(
                tf.shape(op.inputs[0]), op.inputs[1], grad, op.get_attr('strides'),
                op.get_attr('padding'), op.get_attr('use_cudnn_on_gpu'),
                op.get_attr('data_format')),
                tf.nn.conv2d_backprop_filter(op.inputs[0],
                                             tf.shape(op.inputs[1]), grad,
                                             op.get_attr('strides'),
                                             op.get_attr('padding'),
                                             op.get_attr('use_cudnn_on_gpu'),
                                             op.get_attr('data_format'))]
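
For reference, here is a minimal usage sketch of the wrapper above; the placeholder shape and layer parameters are illustrative and not part of the original post, and it assumes the _conv_safe definition above is in scope:

# Hypothetical usage: leave the batch dimension as None so Fold can feed
# zero-size batches without crashing the backward pass.
images = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
features = _conv_safe(images, filters=32, kernel_size=3, strides=1,
                      activation=tf.nn.relu)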

I am now wondering how I can do something similar to avoid this crash when using the tf.layers.max_pooling2d operation, or any other form of max pooling. As the tf.layers.conv2d example shows, we can get around it by implementing a custom gradient that handles the 0 batch size. How can I do the same for tf.layers.max_pooling2d?

Note: I am using TensorFlow 1.0, since that is what TensorFlow Fold supports.

Thanks

I think we can do it like this:

import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

def max_pooling_zero_batch(inputs, pool_size, strides, name):
    g = tf.get_default_graph()
    with g.gradient_override_map({'MaxPool': 'MaxPool_handle_empty_batch'}):
        return tf.layers.max_pooling2d(inputs=inputs, pool_size=pool_size,
                                       strides=strides, name=name)

@tf.RegisterGradient("MaxPool_handle_empty_batch")
def _MaxPoolGrad(op, grad):
    # Run the pooling gradient on the CPU so a zero-size batch never reaches cuDNN.
    with tf.device('/cpu:0'):
        return gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], grad,
                                         op.get_attr("ksize"),
                                         op.get_attr("strides"),
                                         padding=op.get_attr("padding"),
                                         data_format=op.get_attr("data_format"))

It seems to work with zero-size batches.
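
For what it's worth, here is a minimal sketch of exercising the patched layer with an empty batch; the shapes and names are illustrative, not from the original answer, and it assumes the definitions above are in scope:

import numpy as np

# A placeholder with an unspecified batch dimension, as Fold would use.
x = tf.placeholder(tf.float32, shape=[None, 28, 28, 8])
pooled = max_pooling_zero_batch(x, pool_size=2, strides=2, name='pool_safe')
loss = tf.reduce_sum(pooled)
grads = tf.gradients(loss, [x])

with tf.Session() as sess:
    # A batch with zero rows is exactly what Fold can generate.
    empty = np.zeros((0, 28, 28, 8), dtype=np.float32)
    sess.run([pooled, grads], feed_dict={x: empty})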
