
InvalidArgumentError when yielding List of Dicts inside TensorFlow 2.3.1 Dataset graph

Problem Statement

I have to read several JSON Lines files (.jsonl) from Google Cloud Storage. To do this, I have created a dataset from the records I want to read, which is a numpy array containing [[<gs:// url>, id], ...], where id is the row number used to check whether a line belongs to train/test/validation.

Code

The main function, which creates the TF Dataset from a generator that yields the previously described np.ndarray and then runs a map function to download and parse the file, is:

import numpy as np
import tensorflow as tf


def load_dataset(records: np.ndarray) -> tf.data.Dataset:
    """Create Tensorflow Dataset MapDataset (generator) from a list of gs:// data URLs.

    Args:
        records (np.ndarray): Array of [url, line index] pairs, where each url is a gs://<foo>/foo<N>/*.jsonl.gz file

    Returns:
        tf.data.Dataset: MapDataset generator which can be used for training Keras models.
    """
    dataset = tf.data.Dataset.from_generator(lambda: _generator(records), (tf.string, tf.int8))
    return dataset


def _generator(records):
    for r in records:
        yield r[0], r[1]

As you can see, the generator simply iterates through the np.ndarray to get the url and a 'line index'.

Then I have to load and preprocess the file from the URL to get a list of json -> Dict objects.

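import zlib
from typing import Dict, Iterable, List

import tensorflow as tf


def _load_and_preprocess(filepath, selected_sample):
    """Read a file from a GCS or local path and process it into a tensor.

    Args:
        filepath (tensor): path string, pointer to GCS or local path
        selected_sample (tensor): line index (or indices) of the samples to keep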

    Returns:
        tensor: processed input
    """
    sample_raw_input = tf.io.read_file(filepath)
    uncompressed_inputs = tf.py_function(_get_uncompressed_inputs, [sample_raw_input], tf.string)
    sample = tf.py_function(_load_sampled_sample, [uncompressed_inputs, selected_sample], tf.float32) #This `tf.float32` is definitely wrong
    return sample #This is not a tensor, but a List of Dictionaries which I will process later


def _get_uncompressed_inputs(record):
    return zlib.decompress(record.numpy(), 16 + zlib.MAX_WBITS)


def _load_sampled_sample(inputs: Iterable, selected_sample: List[int]) -> List[Dict[str, str]]:
    if not tf.executing_eagerly():
        raise RuntimeError("TensorFlow must be executing eagerly.")
    inputs = inputs.numpy()
    selected_sample = selected_sample.numpy()
    sample = _load__sampled_sample_from_jsonl(inputs, selected_sample)
    return sample


def _load__sampled_sample_from_jsonl(jsonl: bytes, selected_sample: List[int]) -> List[Dict[str, str]]:
    json_lines = _read_jsonl(jsonl).split("\n")
    sample = list()
    for n, sample_json in enumerate(json_lines):
        sample_obj = _read_json(sample_json) if n in selected_sample else None
        if sample_obj:
            sample.append(sample_obj)
    return sample


def _read_jsonl(jsonl: bytes) -> str:
    return jsonl.decode()
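The decompression and line-selection logic does not depend on TensorFlow, so it can be checked in plain Python before wrapping it in `tf.py_function`. A minimal sketch, assuming the helper `_read_json` (not shown in the post) behaves like `json.loads`; `decompress_gzip` and `select_json_lines` are hypothetical names. Passing `16 + zlib.MAX_WBITS` tells zlib to expect a gzip header and trailer, which is what lets `_get_uncompressed_inputs` read `.jsonl.gz` files.

```python
import gzip
import json
import zlib
from typing import Any, Dict, Iterable, List


def decompress_gzip(payload: bytes) -> bytes:
    """Decompress a gzip payload the same way _get_uncompressed_inputs does."""
    # 16 + MAX_WBITS: expect a gzip header/trailer rather than a raw zlib stream.
    return zlib.decompress(payload, 16 + zlib.MAX_WBITS)


def select_json_lines(jsonl: bytes, selected: Iterable[int]) -> List[Dict[str, Any]]:
    """Parse only the selected 0-based lines of a JSON Lines payload.

    Assumes _read_json is equivalent to json.loads. Blank lines are skipped
    explicitly rather than being passed to the parser.
    """
    wanted = set(selected)  # set lookup instead of scanning a list for every line
    samples = []
    for n, line in enumerate(jsonl.decode("utf-8").split("\n")):
        if n in wanted and line.strip():
            samples.append(json.loads(line))
    return samples


if __name__ == "__main__":
    # Build an in-memory .jsonl.gz payload instead of reading from GCS.
    payload = gzip.compress(b'{"a": "1"}\n{"a": "2"}\n{"a": "3"}\n')
    print(select_json_lines(decompress_gzip(payload), [0, 2]))
    # -> [{'a': '1'}, {'a': '3'}]
```

The return value is an ordinary Python list of dicts, which is exactly the kind of value the traceback below shows `tf.py_function` refusing to convert.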

Executing

I then create the dataset with the above code and try to retrieve a single sample from it as a test.

val_ds = load_dataset(validation_records)
samples = tf.data.experimental.get_single_element(
    val_ds
) # This should be a list of Dicts

Which raises:

InvalidArgumentError: ValueError: Attempt to convert a value ({...}) with an unsupported type (<class 'dict'>) to a Tensor.
# ... are the dict values, which is really big so I've shortened it to `...`
Traceback (most recent call last):

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 242, in __call__
    return func(device, token, args)

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 140, in __call__
    outputs = [

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 141, in <listcomp>
    _maybe_copy_to_context_device(self._convert(x, dtype=dtype),

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 120, in _convert
    return ops.convert_to_tensor(value, dtype=dtype)

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 1499, in convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 338, in _constant_tensor_conversion_function
    return constant(v, dtype=dtype, name=name)

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 263, in constant
    return _constant_impl(value, dtype, shape, name, verify_shape=False,

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 275, in _constant_impl
    return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 300, in _constant_eager_impl
    t = convert_to_eager_tensor(value, ctx, dtype)

  File "/home/victor/anaconda3/lib/python3.8/site-packages/tensorflow/python/framework/constant_op.py", line 98, in convert_to_eager_tensor
    return ops.EagerTensor(value, ctx.device_name, dtype)

ValueError: Attempt to convert a value ({...}) with an unsupported type (<class 'dict'>) to a Tensor.
# ... are the dict values, which is really big so I've shortened it to `...`


     [[{{node EagerPyFunc_1}}]] [Op:DatasetToSingleElement]
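The traceback points at the cause: `tf.py_function` passes every value the Python function returns through `ops.convert_to_tensor`, and a Python `dict` cannot be converted to a tensor of any dtype, so replacing `tf.float32` with another `Tout` alone is unlikely to help. One common workaround, offered here as a sketch and not as the post's own solution, is to have the Python function return tensor-convertible values, for example one JSON-encoded string per selected line, declare `Tout=tf.string`, and decode the strings back into dicts wherever they are actually consumed (for instance on `.numpy()` values when iterating the dataset eagerly). The helper names below are hypothetical.

```python
import json
from typing import Any, Dict, List


def dicts_to_json_strings(samples: List[Dict[str, Any]]) -> List[bytes]:
    """Serialize each dict to UTF-8 JSON bytes.

    A non-empty list of bytes should convert to a 1-D tf.string tensor, so a
    function returning this could be wrapped with tf.py_function(..., Tout=tf.string).
    """
    return [json.dumps(sample, sort_keys=True).encode("utf-8") for sample in samples]


def json_strings_to_dicts(encoded: List[bytes]) -> List[Dict[str, Any]]:
    """Inverse of dicts_to_json_strings, e.g. applied to tensor.numpy() values."""
    return [json.loads(item.decode("utf-8")) for item in encoded]
```

In the original pipeline this would roughly mean returning `dicts_to_json_strings(sample)` from `_load_sampled_sample` and passing `tf.string` instead of `tf.float32` to the second `tf.py_function` call; that wiring is an untested assumption.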

Conclusion

Is there any way I can work with a List of Dicts without eager execution (which is not allowed inside a TF Dataset)?

This list of dicts is not the input for my model; however, I cannot work with it in the preprocessing function at all, because this error is raised before the values are passed to any other function.

Additional Information:

  • Python Version: 3.8
  • TensorFlow Version: 2.3.1

Ok, I think I have fixed it by running the dataset.map function eagerly:

dataset.map(lambda file, samples: tf.py_function(_load_and_preprocess, [file, samples], tf.variant))

Which is described here: How can you map values in a tf.data.Dataset using a dictionary

Statement: the technical posts on this site are licensed under CC BY-SA 4.0. If you need to republish this page, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.