
Iterate over a tf.Tensor inside a tf.function to generate a list of NamedTuple-based dataset items

I'm using typing.NamedTuple-based element types in a tf.data.Dataset. An example of this is below.

# You can run all the code in this question by pasting all
# the code blocks consecutively into a Python file

import tensorflow as tf
from typing import Callable, List, NamedTuple
from random import gauss, randint

class Coord(NamedTuple):
    x: float
    y: float

    @classmethod
    def random(cls): return cls(gauss(10., 1.), gauss(10., 1.))

class Box(NamedTuple):
    min: Coord
    max: Coord

    @classmethod
    def random(cls): return cls(Coord.random(), Coord.random())

class Boxes(NamedTuple):
    boxes: List[Box]

    @classmethod
    def random(cls): return cls([Box.random() for _ in range(randint(3, 5))])

def test_dataset():
    for _ in range(randint(3, 5)): yield Boxes.random()

tf_dataset = tf.data.Dataset.from_generator(test_dataset, output_types=(tf.float32,))

As you may know, tf.data.Dataset.from_generator() converts each dataset element (originally of type Boxes) into a one-element tuple containing a tf.Tensor of shape (n, 2, 2), where n is the number of boxes in that element. For example, one element of the dataset might be the following:

(<tf.Tensor: shape=(4, 2, 2), dtype=float32, numpy=
array([[[11.642379,  9.937152],
        [ 8.998009,  8.387287]],

       [[10.649337, 10.028358],
        [ 8.507834,  9.84779 ]],

       [[11.10263 , 11.3706  ],
        [ 9.20623 , 10.44905 ]],

       [[ 9.591406,  9.560486],
        [ 9.461394,  9.256082]]], dtype=float32)>,)
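On TensorFlow 2.4 and later, from_generator also accepts output_signature, which lets the variable-length (None, 2, 2) shape be declared statically; with output_types alone the static shape is left unknown. A minimal sketch, assuming TF 2.4+; random_boxes_generator is a hypothetical stand-in for test_dataset that yields plain nested lists instead of NamedTuples.

```python
import random

import tensorflow as tf


def random_boxes_generator():
    # Hypothetical generator: each element is a one-element tuple holding an
    # (n, 2, 2) nested list, i.e. n boxes of [[min.x, min.y], [max.x, max.y]].
    for _ in range(random.randint(3, 5)):
        n = random.randint(3, 5)
        boxes = [
            [[random.gauss(10.0, 1.0), random.gauss(10.0, 1.0)],
             [random.gauss(10.0, 1.0), random.gauss(10.0, 1.0)]]
            for _ in range(n)
        ]
        yield (boxes,)


# output_signature declares dtype and a partially known shape: the number of
# boxes (axis 0) may vary, the (2, 2) corner/coordinate layout is fixed.
dataset = tf.data.Dataset.from_generator(
    random_boxes_generator,
    output_signature=(tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32),),
)

print(dataset.element_spec)
```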

I have regular Python functions, not decorated with @tf.function, that transform the data in its original types, for example:

def flip_boxes(boxes: Boxes):
    def flip_coord(c: Coord): return Coord(-c.x, c.y)
    def flip_box(b: Box): return Box(flip_coord(b.min), flip_coord(b.max))
    return Boxes(boxes=list(map(flip_box, boxes.boxes)))
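In plain Python (no TensorFlow), the relationship between the Boxes type and the (n, 2, 2) layout can be sketched as a round trip through nested lists, with index order [box][corner][axis]. boxes_to_nested and nested_to_boxes are hypothetical helper names; the corner-then-x/y ordering is assumed from the sample tensor above.

```python
from typing import List, NamedTuple


class Coord(NamedTuple):
    x: float
    y: float


class Box(NamedTuple):
    min: Coord
    max: Coord


class Boxes(NamedTuple):
    boxes: List[Box]


def boxes_to_nested(boxes: Boxes) -> List[List[List[float]]]:
    # Index order [box][corner][axis]: corner 0 = min, 1 = max; axis 0 = x, 1 = y.
    return [[[b.min.x, b.min.y], [b.max.x, b.max.y]] for b in boxes.boxes]


def nested_to_boxes(rows: List[List[List[float]]]) -> Boxes:
    # Inverse of boxes_to_nested: rebuild the NamedTuples from the (n, 2, 2) layout.
    return Boxes([Box(Coord(*r[0]), Coord(*r[1])) for r in rows])


def flip_boxes(boxes: Boxes) -> Boxes:
    # Same logic as flip_boxes in the question.
    def flip_coord(c: Coord): return Coord(-c.x, c.y)
    def flip_box(b: Box): return Box(flip_coord(b.min), flip_coord(b.max))
    return Boxes(boxes=list(map(flip_box, boxes.boxes)))


original = Boxes([Box(Coord(1.0, 2.0), Coord(3.0, 4.0))])
print(boxes_to_nested(flip_boxes(original)))  # [[[-1.0, 2.0], [-3.0, 4.0]]]
```

In this layout, flipping a box amounts to negating element [i][corner][0] for every box and corner.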

I want to apply this Python function (and others like it) to the tf.data.Dataset via tf.data.Dataset.map(map_func). Dataset.map expects map_func to take the members of the dataset element type in their tf.Tensor form. The original element type is Boxes, which has a single member, boxes: List[Box]. When the dataset is created, that list is converted into a tensor like the (4, 2, 2)-shaped one above. It is not converted back when tf.data.Dataset.map() calls map_func; the tensor is passed directly as the first argument to map_func. (If Boxes had more members, they would be passed to map_func as separate arguments, not as a single tuple.)
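The unpacking behavior described above can be checked with a small sketch, assuming TF 2.x eager mode; from_tensors is used here only to build tiny single-element datasets.

```python
import tensorflow as tf

# Element is a one-element tuple: map_func receives the tensor itself.
one = tf.data.Dataset.from_tensors((tf.zeros([4, 2, 2]),))
one = one.map(lambda boxes: tf.shape(boxes)[0])

# Element is a two-element tuple: map_func receives two positional arguments.
two = tf.data.Dataset.from_tensors((tf.zeros([4, 2, 2]), tf.constant(7)))
two = two.map(lambda boxes, label: label + tf.shape(boxes)[0])

print(int(next(iter(one))), int(next(iter(two))))  # 4 11
```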

Question: What adapter function do I implement to make a regular Python function (like flip_boxes) usable with tf.data.Dataset.map()?

I tried iterating over the tensor and using tf.split to recover a List[Box] from the input tf.Tensor, but I ran into the errors shown as comments below.

# Question: How do I implement this function?
def to_tf_mappable_function(fn: Callable) -> Callable:

    def function(tensor: tf.Tensor):
        boxes: List[Box] = [Box(Coord(10.0, 0.0), Coord(10.0, 0.0)), Box(Coord(10.0, 0.0), Coord(10.0, 0.0))]
        # TODO calculate `boxes` from `tensor`, not use this dummy constant above

        # Trivial Python code does not work, it results in this error on the commented-out line:
        #   OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed:
        #   AutoGraph is disabled in this function. Try decorating it directly with @tf.function.
        # boxes = [Box(Coord(row[0][0], row[0][1]), Coord(row[1][0], row[1][1])) for row in tensor]
        # Decorating any of flip_boxes, to_tf_mappable_function and to_tf_mappable_function.function
        # does not eliminate the error.

        # I thought tf.split might help, but it results in this error on the commented-out line:
        #   ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split.
        #   Argument provided: Tensor("cond/Identity:0", shape=(), dtype=int32)
        # boxes = tf.split(tensor, len(tensor))

        return fn(Boxes(boxes))

    return function

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
# The line above should be morally equivalent to `dataset = map(flip_boxes, dataset)`,
# given a `dataset: Iterable[Boxes]` and the builtin `map` function in Python.
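Inside Dataset.map the number of boxes is generally not known while the function is being traced, so building a Python list of Box objects (whose length must be a concrete Python int) is not possible there; that appears consistent with both errors above. If the per-box logic can be re-expressed with TensorFlow ops, tf.map_fn is one graph-compatible way to iterate over the boxes. A minimal sketch, assuming TF 2.x; flip_each_box is a hypothetical name, and it re-implements flip_boxes with tensor ops rather than calling the Python version.

```python
import tensorflow as tf


# The first dimension (number of boxes) is left as None, as it is inside Dataset.map.
@tf.function(input_signature=[tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32)])
def flip_each_box(boxes: tf.Tensor) -> tf.Tensor:
    def flip_one(box: tf.Tensor) -> tf.Tensor:
        # box has shape (2, 2): row 0 = min corner, row 1 = max corner;
        # column 0 = x, column 1 = y. Negate x, keep y.
        return tf.stack([-box[:, 0], box[:, 1]], axis=1)

    # tf.map_fn calls flip_one on every (2, 2) slice along axis 0, even though
    # the number of slices is only known at run time.
    return tf.map_fn(flip_one, boxes)


print(flip_each_box(tf.constant([[[1.0, 2.0], [3.0, 4.0]]])).numpy().tolist())
# [[[-1.0, 2.0], [-3.0, 4.0]]]
```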

Perhaps I'm not asking the right question, but please give me some slack.
* The high-level task is to apply flip_boxes and similar functions to a tf.data.Dataset in an efficient way.
* The place where I'm stuck is recovering a List[Box] from a tf.Tensor that's shaped exactly like a list of box coordinates, so maybe my question should be scoped to this problem.
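For the efficiency goal, a transformation as simple as flip_boxes can also be written with no per-box iteration at all: multiplying the (n, 2, 2) tensor by [-1, 1] broadcasts over the last (x, y) axis and negates every x. A minimal sketch, assuming TF 2.x; flip_boxes_tensor is a hypothetical name, and this approach only applies to functions that can be rewritten as tensor ops.

```python
import tensorflow as tf


def flip_boxes_tensor(boxes: tf.Tensor) -> tf.Tensor:
    # boxes: (n, 2, 2) with axis 1 = min/max corner and axis 2 = (x, y).
    # The shape-(2,) factor broadcasts over the last axis: x is negated, y kept.
    return boxes * tf.constant([-1.0, 1.0], dtype=boxes.dtype)


dataset = tf.data.Dataset.from_tensors(
    (tf.constant([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]),)
)
# The one-element tuple is unpacked, so flip_boxes_tensor receives the tensor itself.
flipped = dataset.map(flip_boxes_tensor)
print(next(iter(flipped)).numpy().tolist())
# [[[-1.0, 2.0], [-3.0, 4.0]], [[-5.0, 6.0], [-7.0, 8.0]]]
```

Because the whole computation stays inside the TensorFlow graph, Dataset.map can run it without calling back into Python for each element.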

I'm not sure whether you are looking for something more general, but for the exact question you ask here, this is one possible way to implement it:

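# Helper function to translate a tensor back into the Boxes type.
# It calls t.numpy(), so it only works on eager tensors (e.g. inside tf.py_function).
def boxes_from_tensor(t: tf.Tensor) -> Boxes:
    n_boxes = t.shape[0]
    t = t.numpy()
    boxes = Boxes(boxes=[Box(Coord(t[i,0,0], t[i,0,1]), Coord(t[i,1,0], t[i,1,1])) for i in range(n_boxes)])
    return boxes

def to_tf_mappable_function(fn: Callable) -> Callable:
    def function(tensor: tf.Tensor):
        # tf.py_function runs the lambda eagerly, so boxes_from_tensor can call .numpy();
        # fn's result is converted back into a tensor of dtype tensor.dtype.
        return tf.py_function(lambda t: fn(boxes_from_tensor(t)), [tensor], tensor.dtype)
    return function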

tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
list(tf_dataset)
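One side effect of tf.py_function is that TensorFlow cannot infer the shape of its output, so the mapped dataset's element_spec has an unknown shape. If downstream code needs the shape, it can be re-attached with set_shape. A minimal sketch, assuming TF 2.x; with_static_shape and negate_x are hypothetical names, with negate_x standing in for a regular NumPy-level Python function.

```python
import tensorflow as tf


def with_static_shape(py_fn, out_shape, out_dtype=tf.float32):
    # Hypothetical helper: wrap a NumPy-level function with tf.py_function and
    # re-attach the static shape that tf.py_function cannot infer by itself.
    def mapped(tensor):
        out = tf.py_function(lambda t: py_fn(t.numpy()), [tensor], out_dtype)
        out.set_shape(out_shape)
        return out
    return mapped


def negate_x(array):
    # Plain NumPy code: negate the x column of every (min, max) corner.
    flipped = array.copy()
    flipped[..., 0] = -flipped[..., 0]
    return flipped


dataset = tf.data.Dataset.from_tensors((tf.constant([[[1.0, 2.0], [3.0, 4.0]]]),))
mapped = dataset.map(with_static_shape(negate_x, (None, 2, 2)))
print(mapped.element_spec)
```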
