
Tensorflow 2: apply one hot encoding on masks for semantic segmentation

I'm trying to process my ground-truth images to create one-hot encoded tensors:

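import numpy as np

def one_hot(img, nclasses):
  # img: H x W x 3 colour-coded mask; result: H x W x nclasses
  result = np.zeros((img.shape[0], img.shape[1], nclasses))
  # Flatten to a list of pixels (for any image size) and collect the distinct colours
  img_unique = img.reshape(-1, img.shape[2])
  unique = np.unique(img_unique, axis=0)
  for i in range(img.shape[0]):
    for j in range(img.shape[1]):
      # Channel k is set for the k-th distinct colour
      for k, unique_val in enumerate(unique):
        if np.array_equal(img[i, j], unique_val):
          result[i, j, k] = 1
          break

  return result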

This creates a WxHxN tensor from a WxHx3 image. I don't like this approach because the nested Python loops make it slow. Could you advise a more efficient way?

I tried to use tf.one_hot, but it converts the image into a WxHx3xN tensor.

For your specific scenario, where the 3 class colours are known, this should be faster:

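import numpy as np

def one_hot2(img):
    class1 = [255, 0, 0]      # label 0 (the default, so it needs no explicit test)
    class2 = [0, 0, 255]      # label 1
    class3 = [255, 255, 255]  # label 2
    # Integer label map of shape H x W (not H x W x 3), so np.eye can be indexed with it
    label = np.zeros(img.shape[:2], dtype=np.int64)
    # A pixel belongs to a class when all three channels match its colour
    label[np.all(img == class2, axis=2)] = 1
    label[np.all(img == class3, axis=2)] = 2
    # Row i of the identity matrix is the one-hot vector for label i
    onehot = np.eye(3)[label]
    return onehot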
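When the class colours are not known in advance, they can be derived from the image itself, as the function in the question does. The following is only a sketch of a loop-free variant: the name one_hot_unique is made up for illustration, and it assumes the mask is an HxWxC array in which every distinct colour is one class. Channels follow the order in which np.unique sorts the colours.

```python
import numpy as np

def one_hot_unique(img):
    """Hypothetical helper: one-hot encode an H x W x C colour mask without Python loops."""
    h, w, c = img.shape
    # Distinct colours (sorted) and, for every pixel, the index of its colour
    colours, inverse = np.unique(img.reshape(-1, c), axis=0, return_inverse=True)
    # reshape(-1) guards against NumPy versions that return a non-flat inverse
    inverse = inverse.reshape(-1)
    # Row k of the identity matrix is the one-hot vector of class k
    onehot = np.eye(len(colours))[inverse].reshape(h, w, len(colours))
    return onehot, colours
```

Because the channel order depends on which colours occur in a given image, a fixed list of class colours is probably the safer choice when encoding a whole dataset.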

For a purely TF 2.x approach, you could also do the following (here the mask already holds one label id per pixel rather than an RGB colour):

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # silence TensorFlow's C++ log output

import tensorflow as tf

@tf.function
def get_one_hot():
  label_ids = [0, 5, 10]
  mask_orig = tf.constant([[0, 10], [0, 10]], dtype=tf.float32) # [2,2]
  # One boolean channel per label id, joined along a new last axis
  mask_onehot = tf.concat([tf.expand_dims(tf.math.equal(mask_orig, label_id), axis=-1) for label_id in label_ids], axis=-1) # [2,2,3]
  # True for every label id that occurs at least once in the mask
  mask_label_present = tf.reduce_any(mask_onehot, axis=[0, 1]) # [3]

  # Tensors are passed as arguments so tf.print shows their values inside tf.function too
  tf.print('\n - label_ids:', label_ids)
  tf.print('\n - mask_orig:\n', mask_orig)
  for id_, label_id in enumerate(label_ids):
      tf.print(' - mask_onehot:', label_id, '\n', mask_onehot[:, :, id_])
  tf.print('\n - mask_label_present:\n ', mask_label_present)

get_one_hot()
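The same comparison idea can be applied directly to the WxHx3 colour masks from the question. tf.one_hot expects integer class indices and appends a new axis to whatever it is given, which is why feeding it the RGB image produced a WxHx3xN tensor. The sketch below compares each pixel with a list of class colours instead; the function name rgb_mask_to_one_hot and the palette argument are assumptions made for illustration, not part of TensorFlow.

```python
import tensorflow as tf

def rgb_mask_to_one_hot(mask, palette):
    """Hypothetical helper: [H, W, 3] colour mask -> [H, W, N] one-hot tensor."""
    mask = tf.convert_to_tensor(mask)
    palette = tf.constant(palette, dtype=mask.dtype)  # [N, 3], one colour per class
    # Broadcast [H, W, 1, 3] against [N, 3] -> [H, W, N, 3]
    same_channel = tf.equal(mask[:, :, tf.newaxis, :], palette)
    # A pixel belongs to class n only if all three channels match
    matches = tf.reduce_all(same_channel, axis=-1)  # [H, W, N]
    return tf.cast(matches, tf.float32)

mask = tf.constant([[[255, 0, 0], [0, 0, 255]],
                    [[255, 255, 255], [0, 0, 0]]], dtype=tf.uint8)
onehot = rgb_mask_to_one_hot(mask, [[255, 0, 0], [0, 0, 255], [255, 255, 255]])
```

A pixel whose colour is not in the palette, such as the black pixel in this example, ends up with an all-zero vector rather than being assigned to a class.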


 