
Unable to retrain the instance segmentation model

I'm trying to train the instance segmentation model. I'm using the following code to generate the tfrecord.

import hashlib
import io
import logging
import os
import random
import re

from lxml import etree
import numpy as np
import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw pet dataset.')
flags.DEFINE_string('output_dir', '', 'Path to directory to output TFRecords.')
flags.DEFINE_string('label_map_path', 'data/pet_label_map.pbtxt',
                    'Path to label map proto')
flags.DEFINE_boolean('faces_only', True, 'If True, generates bounding boxes '
                     'for pet faces. Otherwise generates bounding boxes (as '
                     'well as segmentations) for full pet bodies. Note that '
                     'in the latter case, the resulting files are much larger.')
flags.DEFINE_string('mask_type', 'png', 'How to represent instance '
                    'segmentation masks. Options are "png" or "numerical".')
FLAGS = flags.FLAGS


def get_class_name_from_filename(file_name):
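  """Extracts the class (breed) name from a filename such as 'Abyssinian_14.jpg'."""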

  match = re.match(r'([A-Za-z_]+)(_[0-9]+\.jpg)', file_name, re.I)
  return match.groups()[0]


def dict_to_tf_example(data,
                       mask_path,
                       label_map_dict,
                       image_subdirectory,
                       ignore_difficult_instances=False,
                       faces_only=True,
                       mask_type='png'):

  img_path = os.path.join(image_subdirectory, data['filename'])
  with tf.gfile.GFile(img_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()

  with tf.gfile.GFile(mask_path, 'rb') as fid:
    encoded_mask_png = fid.read()
  encoded_png_io = io.BytesIO(encoded_mask_png)
  mask = PIL.Image.open(encoded_png_io)
  if mask.format != 'PNG':
    raise ValueError('Mask format not PNG')

  mask_np = np.asarray(mask)
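  # In the pet trimaps, pixel value 2 marks the background, so anything
  # that is not 2 is treated as part of the object.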
  nonbackground_indices_x = np.any(mask_np != 2, axis=0)
  nonbackground_indices_y = np.any(mask_np != 2, axis=1)
  nonzero_x_indices = np.where(nonbackground_indices_x)
  nonzero_y_indices = np.where(nonbackground_indices_y)

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  xmins = []
  ymins = []
  xmaxs = []
  ymaxs = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  masks = []
  if 'object' in data:
    for obj in data['object']:
      difficult = bool(int(obj['difficult']))
      if ignore_difficult_instances and difficult:
        continue
      difficult_obj.append(int(difficult))

      if faces_only:
        xmin = float(obj['bndbox']['xmin'])
        xmax = float(obj['bndbox']['xmax'])
        ymin = float(obj['bndbox']['ymin'])
        ymax = float(obj['bndbox']['ymax'])
      else:
        xmin = float(np.min(nonzero_x_indices))
        xmax = float(np.max(nonzero_x_indices))
        ymin = float(np.min(nonzero_y_indices))
        ymax = float(np.max(nonzero_y_indices))

      xmins.append(xmin / width)
      ymins.append(ymin / height)
      xmaxs.append(xmax / width)
      ymaxs.append(ymax / height)
      class_name = get_class_name_from_filename(data['filename'])
      classes_text.append(class_name.encode('utf8'))
      classes.append(label_map_dict[class_name])
      truncated.append(int(obj['truncated']))
      poses.append(obj['pose'].encode('utf8'))
      if not faces_only:
        mask_remapped = (mask_np != 2).astype(np.uint8)
        masks.append(mask_remapped)

  feature_dict = {
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }
  if not faces_only:
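    # 'numerical' flattens the stacked binary masks into a float list;
    # 'png' encodes each instance mask as a separate PNG string.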
    if mask_type == 'numerical':
      mask_stack = np.stack(masks).astype(np.float32)
      masks_flattened = np.reshape(mask_stack, [-1])
      feature_dict['image/object/mask'] = (
          dataset_util.float_list_feature(masks_flattened.tolist()))
    elif mask_type == 'png':
      encoded_mask_png_list = []
      for mask in masks:
        img = PIL.Image.fromarray(mask)
        output = io.BytesIO()
        img.save(output, format='PNG')
        encoded_mask_png_list.append(output.getvalue())
      feature_dict['image/object/mask'] = (
          dataset_util.bytes_list_feature(encoded_mask_png_list))

  example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  return example


def create_tf_record(output_filename,
                     label_map_dict,
                     annotations_dir,
                     image_dir,
                     examples,
                     faces_only=True,
                     mask_type='png'):

  writer = tf.python_io.TFRecordWriter(output_filename)
  for idx, example in enumerate(examples):
    if idx % 100 == 0:
      logging.info('On image %d of %d', idx, len(examples))
    xml_path = os.path.join(annotations_dir, 'xmls', example + '.xml')
    mask_path = os.path.join(annotations_dir, 'trimaps', example + '.png')

    if not os.path.exists(xml_path):
      logging.warning('Could not find %s, ignoring example.', xml_path)
      continue
    with tf.gfile.GFile(xml_path, 'r') as fid:
      xml_str = fid.read()
    xml = etree.fromstring(xml_str)
    data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

    try:
      tf_example = dict_to_tf_example(
          data,
          mask_path,
          label_map_dict,
          image_dir,
          faces_only=faces_only,
          mask_type=mask_type)
      writer.write(tf_example.SerializeToString())
    except ValueError:
      logging.warning('Invalid example: %s, ignoring.', xml_path)

  writer.close()


# TODO(derekjchow): Add test for pet/PASCAL main files.
def main(_):
  data_dir = FLAGS.data_dir
  label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)

  logging.info('Reading from Pet dataset.')
  image_dir = os.path.join(data_dir, 'images')
  annotations_dir = os.path.join(data_dir, 'annotations')
  examples_path = os.path.join(annotations_dir, 'trainval.txt')
  examples_list = dataset_util.read_examples_list(examples_path)

  # Test images are not included in the downloaded data set, so we shall perform
  # our own split.
  random.seed(42)
  random.shuffle(examples_list)
  num_examples = len(examples_list)
  num_train = int(0.7 * num_examples)
  train_examples = examples_list[:num_train]
  val_examples = examples_list[num_train:]
  logging.info('%d training and %d validation examples.',
               len(train_examples), len(val_examples))

  train_output_path = os.path.join(FLAGS.output_dir, 'pet_train.record')
  val_output_path = os.path.join(FLAGS.output_dir, 'pet_val.record')
  # Masks are only generated when faces_only is False, so only in that case
  # do the output files get the '_with_masks' names.
  if not FLAGS.faces_only:
    train_output_path = os.path.join(FLAGS.output_dir,
                                     'pet_train_with_masks.record')
    val_output_path = os.path.join(FLAGS.output_dir,
                                   'pet_val_with_masks.record')
  create_tf_record(
      train_output_path,
      label_map_dict,
      annotations_dir,
      image_dir,
      train_examples,
      faces_only=FLAGS.faces_only,
      mask_type=FLAGS.mask_type)
  create_tf_record(
      val_output_path,
      label_map_dict,
      annotations_dir,
      image_dir,
      val_examples,
      faces_only=FLAGS.faces_only,
      mask_type=FLAGS.mask_type)


if __name__ == '__main__':
  tf.app.run()
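One thing worth noting about the script above: masks are only generated when faces_only is False (see the if not faces_only: branches), so for instance segmentation it presumably has to be invoked along these lines, assuming the script is saved as create_pet_tf_record.py (the script name and paths here are placeholders):

python create_pet_tf_record.py \
    --data_dir=/path/to/pet_dataset \
    --output_dir=/path/to/output \
    --label_map_path=data/pet_label_map.pbtxt \
    --faces_only=False \
    --mask_type=png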

The dataset that I'm using to train has 37 classes, with images and masks. The dataset is here.

However, when I try to train, I'm getting the following error.

Traceback (most recent call last):
  File "train.py", line 167, in <module>
    tf.app.run()
  File "/anaconda3/envs/conda/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "train.py", line 163, in main
    worker_job_name, is_chief, FLAGS.train_dir)
  File "/Users/Documents/research/models/research/object_detection/trainer.py", line 275, in train
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
  File "/Users/Documents/research/models/research/slim/deployment/model_deploy.py", line 193, in create_clones
    outputs = model_fn(*args, **kwargs)
  File "/Users/Documents/research/models/research/object_detection/trainer.py", line 200, in _create_losses
    losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
  File "/Users/Documents/research/models/research/object_detection/meta_architectures/faster_rcnn_meta_arch.py", line 1608, in loss
    groundtruth_masks_list,
  File "/Users/Documents/research/models/research/object_detection/meta_architectures/faster_rcnn_meta_arch.py", line 1837, in _loss_box_classifier
    raise ValueError('Groundtruth instance masks not provided. '
ValueError: Groundtruth instance masks not provided. Please configure

How can I sort this out?

In my case, the tfrecord creation was fine, similar to what the OP has posted. The solution to the problem was that I needed to add two lines to the pipeline.config file.

train_input_reader {
  label_map_path: "path/to/pbtxt"
  tf_record_input_reader {
    input_path: "path/to/traindata.tfrecord"
  }
  load_instance_masks: true
  mask_type: PNG_MASKS
}

The two lines load_instance_masks: true and mask_type: PNG_MASKS must be added to the train_input_reader. They are not present by default in the pipeline.config files from the model zoo.
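If evaluation should also use the masks, the same two lines presumably need to be added to the eval_input_reader as well; a sketch with placeholder paths:

eval_input_reader {
  label_map_path: "path/to/pbtxt"
  tf_record_input_reader {
    input_path: "path/to/valdata.tfrecord"
  }
  load_instance_masks: true
  mask_type: PNG_MASKS
}

To double-check that a generated record actually contains the mask feature before training, here is a minimal inspection sketch (TF 1.x, matching the APIs used in the script above; the record path is a placeholder):

import tensorflow as tf

record_path = 'pet_train_with_masks.record'  # placeholder path
for serialized in tf.python_io.tf_record_iterator(record_path):
  example = tf.train.Example.FromString(serialized)
  # With mask_type='png', each instance mask is stored as one encoded string.
  masks = example.features.feature['image/object/mask'].bytes_list.value
  print('instances with masks in first record:', len(masks))
  break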

Hope this helps.
