TensorFlow unpooling after max_pool_with_argmax using indices

While trying to implement U-SegNet from the paper by Google, I ran into a problem implementing the unpooling operation using argmax indices.

The full code:

import tensorflow as tf


def unpool(pool, ind, ksize=[1, 2, 2, 1], name=None):
    with tf.variable_scope('name') as scope:
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        flat_input_size = tf.cumprod(input_shape)[-1]
        flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                                        shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, tf.stack(output_shape))

        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
        ret.set_shape(set_output_shape)
    return ret

with tf.Session() as sess:
    x = tf.random_normal([1, 4, 4, 1])
    y, ind = tf.nn.max_pool_with_argmax(
        x,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME'
    )

    z = unpool(y, ind)

    x_, y_, z_ = sess.run([x, y, z])

For batch size 1 it works OK, but for batch size > 1 it crashes with the following error:

2018-09-22 16:33:57.010504: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-09-22 16:33:57.082638: W tensorflow/core/framework/op_kernel.cc:1275] OP_REQUIRES failed at scatter_nd_op.cc:119 : Invalid argument: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
Traceback (most recent call last):
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1278, in _do_call
    return fn(*args)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1263, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
     [[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tst.py", line 39, in <module>
    x_, y_, z_ = sess.run([x, y, z])
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1100, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1272, in _do_run
    run_metadata)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1291, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Invalid indices: [2,0] = [1, 21] does not index into [4,16]
     [[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

Caused by op 'name/ScatterNd', defined at:
  File "tst.py", line 37, in <module>
    z = unpool(y, ind)
  File "tst.py", line 20, in unpool
    ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 6788, in scatter_nd
    "ScatterNd", indices=indices, updates=updates, shape=shape, name=name)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
    return func(*args, **kwargs)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3155, in create_op
    op_def=op_def)
  File "/home/vrudenko/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): Invalid indices: [2,0] = [1, 21] does not index into [4,16]
     [[Node: name/ScatterNd = ScatterNd[T=DT_FLOAT, Tindices=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](name/concat, name/Reshape, name/Cast_1)]]

Where could the problem be, and how can I fix it?

The unpooling function was taken from this issue on GitHub, but nothing is said there about unpooling with a batch.

My tf.__version__ is 1.10.

@Tofik.AI which TensorFlow version do you use? According to the latest documentation, that is incorrect. My implementation:

def unpool(pool, ind, ksize=[1, 2, 2, 1], name=None):
    with tf.variable_scope('name') as scope:
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        flat_input_size = tf.cumprod(input_shape)[-1]
        flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                                 shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        # Remove the batch offset that max_pool_with_argmax bakes into the
        # flattened indices, so each index is relative to its own image.
        ind_ = ind_ - b * tf.cast(flat_output_shape[1], tf.int64)
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, tf.stack(output_shape))

        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
        ret.set_shape(set_output_shape)
    return ret
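
The crucial change is the line that subtracts the batch offset. In TF 1.x the argmax indices are flattened as ((b * height + y) * width + x) * channels + c, so every image after the first carries an offset of b * height * width * channels, which is exactly b * flat_output_shape[1]; subtracting it turns the indices back into within-image indices. A minimal NumPy sketch of that arithmetic (the shapes are illustrative, chosen to match the 4x4 example and the [1, 21] index from the error message):

import numpy as np

# Unpooled (pre-pooling) image: 4x4, 1 channel.
height, width, channels = 4, 4, 1
flat_image_size = height * width * channels  # 16 == flat_output_shape[1]

def flat_argmax(b, y, x, c):
    # Flattened index formula documented for tf.nn.max_pool_with_argmax.
    return ((b * height + y) * width + x) * channels + c

raw = flat_argmax(b=1, y=1, x=1, c=0)
print(raw)                        # 21 -- the offending index from the error
print(raw - 1 * flat_image_size)  # 5 -- valid within-image index into [0, 16)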

tf.nn.max_pool_with_argmax performs max pooling on the input and outputs both the max values and their indices. The indices in argmax are flattened, so that a maximum value at position [b, y, x, c] becomes the flattened index ((b * height + y) * width + x) * channels + c. Unpooling a tensor that has been pooled with max_pool_with_argmax then comes down to broadcasting (as in NumPy) to build the [batch, index] pairs, tf.concat to join them, and tf.scatter_nd to place the pooled values back at those locations. See the max-pooling documentation and the original paper on unpooling for more background.
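
Since tf.scatter_nd is doing the actual placement, here is a tiny self-contained sketch of it (toy values of my own, not from the post), mirroring the [batch, flat_position] index pairs the unpool function builds:

import tensorflow as tf

# Each row of `indices` is a [batch, flat_position] pair, like ind_ above.
indices = tf.constant([[0, 1], [0, 6], [1, 3]])
updates = tf.constant([10.0, 20.0, 30.0])
scattered = tf.scatter_nd(indices, updates, shape=[2, 8])

with tf.Session() as sess:
    print(sess.run(scattered))
# [[ 0. 10.  0.  0.  0.  0. 20.  0.]
#  [ 0.  0.  0. 30.  0.  0.  0.  0.]]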

There is a repository that implements the unpool op in CUDA. The unpool_example.py file shows how to use the library. From initial testing it is about twice as fast as composing existing TensorFlow functions at inference (about four times faster during training).

Just use it like the following:

import unpool

# pool, inds = outputs of tf.nn.max_pool_with_argmax
unpool_layer = unpool.unpool(pool, inds,
                             output_size=[height, width],
                             name="unpool")

Full disclosure, I authored this repo.

Unpooling layer in tensorflow · Issue #632 · tensorflow/addons asks for an unpooling layer that uses the second output of tf.nn.max_pool_with_argmax (the indices of the max pool). The documentation quoted there explains the error above: the indices in argmax are flattened, so that a maximum value at position [b, y, x, c] becomes the flattened index (y * width + x) * channels + c if include_batch_in_index is False, or ((b * height + y) * width + x) * channels + c if include_batch_in_index is True. The indices returned are always in [0, height) x [0, width) before flattening.
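
On newer TensorFlow versions (1.14+, including 2.x under tf.nn) you can choose between those two index layouts directly via the include_batch_in_index argument; this sketch assumes such a version and is not available in 1.10:

import tensorflow as tf

x = tf.random_normal([2, 4, 4, 1])

# Default: per-image indices, (y * width + x) * channels + c.
_, ind_local = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',
    include_batch_in_index=False)

# Batch-inclusive indices, ((b * height + y) * width + x) * channels + c,
# which is what the batch-offset subtraction above undoes.
_, ind_global = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',
    include_batch_in_index=True)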

Your code is working fine:

import tensorflow as tf

def unpool(pool, ind, ksize=[1, 2, 2, 1], name=None):
    with tf.variable_scope('name') as scope:
        input_shape = tf.shape(pool)
        output_shape = [input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]]

        flat_input_size = tf.cumprod(input_shape)[-1]
        flat_output_shape = tf.stack([output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]])

        pool_ = tf.reshape(pool, tf.stack([flat_input_size]))
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
                                        shape=tf.stack([input_shape[0], 1, 1, 1]))
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, tf.stack([flat_input_size, 1]))
        ind_ = tf.reshape(ind, tf.stack([flat_input_size, 1]))
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
        ret = tf.reshape(ret, tf.stack(output_shape))

        set_input_shape = pool.get_shape()
        set_output_shape = [set_input_shape[0], set_input_shape[1] * ksize[1], set_input_shape[2] * ksize[2], set_input_shape[3]]
        ret.set_shape(set_output_shape)
    return ret


batch_size = 10
with tf.Session() as sess:
    x = tf.random_normal([batch_size, 16, 16, 1])
    y, ind = tf.nn.max_pool_with_argmax(
        x,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME'
    )

    z = unpool(y, ind)
    x_, y_, z_ = sess.run([x, y, z])


aa = x_[4, :, :, 0]
bb = y_[4, :, :, 0]
cc = z_[4, :, :, 0]
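
Beyond eyeballing those slices, a quick sanity check on the fetched arrays (a sketch; it relies on random normal inputs making the pooled maxima almost surely non-zero and distinct):

import numpy as np

# The unpooled tensor should hold exactly the pooled maxima, zeros elsewhere.
assert np.count_nonzero(z_) == y_.size
assert np.allclose(np.sort(y_.ravel()), np.sort(z_[z_ != 0]))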

You may need to update your TensorFlow; I am using tensorflow 1.12.0.


One caveat from the same discussion: there is a bug in the indices returned from tf.nn.max_pool_with_argmax when padding is applied. The indices are based on the shape supplied to max_pool_with_argmax, instead of the shape plus padding.



Comments
  • That works fine, thanks. But what was the problem? Can you post a link to the TF documentation, so that I can understand the idea?