
Tensorflow / keras: How to train CNN with custom image files (.FITS images)

I am currently trying to train a CNN (EfficientNet) with .FITS image files, but I believe this also applies to other image types. This kind of image requires the astropy library to be opened; in my case, to access the image data I simply type:

from astropy.io import fits
path = "path/to/file.fits"
hdul = fits.open(path)
image = hdul[1].data

This variable image is then of type numpy.ndarray. I first tried keras's image_dataset_from_directory, which, as expected, was not successful. I then looked at tf.data here: https://www.tensorflow.org/tutorials/load_data/images#using_tfdata_for_finer_control. I tried to create a similar pipeline, and everything worked up to the decode_img function. Since I am not dealing with jpegs, I tried to put together a workaround, so that I end up with:

import os
import pathlib

import numpy as np
import tensorflow as tf
from astropy.io import fits

data_dir = pathlib.Path("home/astro/train")
class_names = np.array(sorted([item.name for item in data_dir.glob('*')]))
# class_names = ["stars", "galaxies"]

def get_label(file_path):
    # The label is encoded in the name of the parent directory
    parts = tf.strings.split(file_path, os.path.sep)
    one_hot = parts[-2] == class_names
    return tf.argmax(one_hot)

def decode_img(img):
    # Read the image data from the FITS file with astropy
    hdul = fits.open(img)
    data = hdul[1].data
    data = data.reshape((data.shape[0], data.shape[1], 1))
    data = np.pad(data, [(0, 0), (0, 0), (0, 2)], 'constant')  # padding to create 3 channels
    img = tf.convert_to_tensor(data, np.float32)
    # img_height and img_width are defined earlier, as in the tutorial
    return tf.image.resize(img, [img_height, img_width])

def process_path(file_path):
    label = get_label(file_path)
    img = decode_img(file_path)
    return img, label

This actually works quite well, up to a point: when I print the output of process_path, I get two tensors, one for the image and one for the label, with the correct shapes and values I want.
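Roughly, the eager check I mean looks like the sketch below; the list_files call and the *.fits glob pattern are assumptions based on the tutorial, not exactly what I have in my code.

# Build a dataset of FITS file paths, following the tf.data tutorial
list_ds = tf.data.Dataset.list_files(str(data_dir / '*/*.fits'), shuffle=False)

# Calling process_path eagerly on one concrete path works,
# because fits.open receives a real filename string here
for file_path in list_ds.take(1):
    img, label = process_path(file_path.numpy().decode())
    print(img.shape, label)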

The problem:

Following the tutorial, when I get to:

AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)

I get the following error:

TypeError                                 Traceback (most recent call last)
 in 
      1 AUTOTUNE = tf.data.experimental.AUTOTUNE
      2 
----> 3 train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
      4 val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in map(self, map_func, num_parallel_calls, deterministic)
   1700           num_parallel_calls,
   1701           deterministic,
-> 1702           preserve_cardinality=True)
   1703 
   1704   def flat_map(self, map_func):

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, input_dataset, map_func, num_parallel_calls, deterministic, use_inter_op_parallelism, preserve_cardinality, use_legacy_function)
   4082         self._transformation_name(),
   4083         dataset=input_dataset,
-> 4084         use_legacy_function=use_legacy_function)
   4085     if deterministic is None:
   4086       self._deterministic = "default"

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in __init__(self, func, transformation_name, dataset, input_classes, input_shapes, input_types, input_structure, add_to_graph, use_legacy_function, defun_kwargs)
   3369       with tracking.resource_tracker_scope(resource_tracker):
   3370         # TODO(b/141462134): Switch to using garbage collection.
-> 3371         self._function = wrapper_fn.get_concrete_function()
   3372         if add_to_graph:
   3373           self._function.add_to_graph(ops.get_default_graph())

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in get_concrete_function(self, *args, **kwargs)
   2937     """
   2938     graph_function = self._get_concrete_function_garbage_collected(
-> 2939         *args, **kwargs)
   2940     graph_function._garbage_collector.release()  # pylint: disable=protected-access
   2941     return graph_function

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   2904       args, kwargs = None, None
   2905     with self._lock:
-> 2906       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2907       seen_names = set()
   2908       captured = object_identity.ObjectIdentitySet(

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3211 
   3212       self._function_cache.missed.add(call_context_key)
-> 3213       graph_function = self._create_graph_function(args, kwargs)
   3214       self._function_cache.primary[cache_key] = graph_function
   3215       return graph_function, args, kwargs

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3073             arg_names=arg_names,
   3074             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3075             capture_by_value=self._capture_by_value),
   3076         self._function_attributes,
   3077         function_spec=self.function_spec,

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    984         _, original_func = tf_decorator.unwrap(python_func)
    985 
--> 986       func_outputs = python_func(*func_args, **func_kwargs)
    987 
    988       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in wrapper_fn(*args)
   3362           attributes=defun_kwargs)
   3363       def wrapper_fn(*args):  # pylint: disable=missing-docstring
-> 3364         ret = _wrapper_helper(*args)
   3365         ret = structure.to_tensor_list(self._output_structure, ret)
   3366         return [ops.convert_to_tensor(t) for t in ret]

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py in _wrapper_helper(*args)
   3297         nested_args = (nested_args,)
   3298 
-> 3299       ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
   3300       # If `func` returns a list of tensors, `nest.flatten()` and
   3301       # `ops.convert_to_tensor()` would conspire to attempt to stack

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e, 'ag_error_metadata'):
--> 258           raise e.ag_error_metadata.to_exception(e)
    259         else:
    260           raise

TypeError: in user code:

    :17 process_path  *
        img = decode_img(file_path)
    :7 decode_img  *
        hdul = fits.open(img)
    /home/marcostidball/anaconda3/lib/python3.7/site-packages/astropy/io/fits/hdu/hdulist.py:154 fitsopen  *
        if not name:
    /home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/operators/logical.py:29 not_
        return _tf_not(a)
    /home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/autograph/operators/logical.py:35 _tf_not
        return gen_math_ops.logical_not(a)
    /home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:5481 logical_not
        "LogicalNot", x=x, name=name)
    /home/marcostidball/anaconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:493 _apply_op_helper
        (prefix, dtypes.as_dtype(input_arg.type).name))

    TypeError: Input 'x' of 'LogicalNot' Op has type string that does not match expected type of bool.

Does anyone know a way around this problem? I have looked around for ways to train the CNN directly with numpy arrays, such as the one I get before the tensor conversion, and found some examples using MNIST and standalone keras. I would like to apply the usual data augmentation and train in batches, though, and I am not sure that is possible with what I have seen.
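For reference, the MNIST-style pattern I have seen looks roughly like the sketch below, with ImageDataGenerator doing the batching and augmentation; the arrays and the commented-out model.fit call are placeholders, not my actual data or model.

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Toy stand-ins for the real arrays; in practice x would come from applying
# decode_img to every file and y from get_label (both are placeholders here)
x = np.random.rand(100, 64, 64, 3).astype(np.float32)
y = np.random.randint(0, 2, size=(100,))

# Batching plus the usual augmentations, handled by ImageDataGenerator
datagen = ImageDataGenerator(rotation_range=15,
                             horizontal_flip=True,
                             vertical_flip=True,
                             validation_split=0.2)
train_gen = datagen.flow(x, y, batch_size=32, subset="training")
val_gen = datagen.flow(x, y, batch_size=32, subset="validation")

# model.fit(train_gen, validation_data=val_gen, epochs=10)  # model defined elsewhere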

Thanks a lot!

I ran into the same problem. I think the issue is (but I am not sure) that when you build train_ds, the map is really saying "this is what will happen when we actually want the next batch". astropy.io.fits expects to run at the moment it is called, so it complains when the file name it is given does not exist yet (or is a placeholder).

The solution I came up with was to write code to load the FITS files without using astropy. You can get it from this GitHub repo. As you might expect, it may have issues, but it has worked for me for opening FITS images (I have not worked out tables yet).
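Another workaround that keeps astropy, based on the explanation above, would be to wrap the loading code in tf.py_function so it runs eagerly per element instead of being traced into the graph. The sketch below is only illustrative: it reuses get_label, train_ds, val_ds, img_height and img_width from the question, and the load_fits helper and the static-shape line are my own names, not tested against your data.

def load_fits(file_path):
    # Runs eagerly inside tf.py_function, so file_path is a concrete tensor
    # whose value can be handed to astropy as an ordinary filename
    path = file_path.numpy().decode("utf-8")
    hdul = fits.open(path)
    data = hdul[1].data.astype(np.float32)
    hdul.close()
    data = data.reshape((data.shape[0], data.shape[1], 1))
    data = np.pad(data, [(0, 0), (0, 0), (0, 2)], 'constant')  # pad to 3 channels
    return tf.image.resize(data, [img_height, img_width])

def process_path(file_path):
    label = get_label(file_path)  # pure tf ops, safe to trace
    img = tf.py_function(load_fits, [file_path], tf.float32)
    img.set_shape([img_height, img_width, 3])  # py_function drops static shape info
    return img, label

train_ds = train_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)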
