How to get the value of a tensor in TF 2.0?

import tensorflow as tf

IMAGE_FEATURE_MAP = {
    'image/filename': tf.io.FixedLenFeature([], tf.string),
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
    'image/object/class/text': tf.io.VarLenFeature(tf.string),
}
yolo_max_boxes = 100

def parse_tfrecord(tfrecord, class_table, size):

    x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
    x_train = tf.image.decode_jpeg(x['image/encoded'], channels=3)
    x_train = tf.image.resize(x_train, (size, size))
    class_text = tf.sparse.to_dense(
        x['image/object/class/text'], default_value='')
    labels = tf.cast(class_table.lookup(class_text), tf.float32)
    y_train = tf.stack([tf.sparse.to_dense(x['image/object/bbox/xmin']),
                        tf.sparse.to_dense(x['image/object/bbox/ymin']),
                        tf.sparse.to_dense(x['image/object/bbox/xmax']),
                        tf.sparse.to_dense(x['image/object/bbox/ymax']),
                        labels], axis=1)
    paddings = [[0, yolo_max_boxes - tf.shape(y_train)[0]], [0, 0]]

    y_train = tf.pad(y_train, paddings)

    return x_train, y_train


def load_tfrecord_dataset(file_pattern, class_file, size=416):

    LINE_NUMBER = -1

    class_table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
        class_file, tf.string, 0, tf.int64, LINE_NUMBER, delimiter="\n"), -1)

    files = tf.data.Dataset.list_files(file_pattern)
    dataset = files.flat_map(tf.data.TFRecordDataset)

    result = dataset.map(lambda x: parse_tfrecord(x, class_table, size))

    return result

def main():
    load_tfrecord_dataset('../data/facemask2020_train.tfrecord', '../data/mask2020.names', size=416)


if __name__ == '__main__':
    main()

When I print x['image/filename'], I only get the shape and data type of the tensor. I want to see the file names stored in the TFRecord file, but I can't see the actual value of this tensor. How can I check its value? I am a newbie, please help me.
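
Assuming the print was placed inside parse_tfrecord, the behaviour comes from Dataset.map: the function it receives is traced into a graph, so a plain Python print there only sees a symbolic tensor. The sketch below uses a made-up two-element dataset (the filenames and the helper name show are hypothetical, not from the question) to contrast Python's print, tf.print, and eager iteration:

```python
import tensorflow as tf


def show(filename):
    # Python's print runs only while Dataset.map traces this function,
    # so it sees a symbolic tensor (shape and dtype, no value).
    print("traced:", filename)
    # tf.print becomes part of the traced graph and prints the actual
    # value for every element when the dataset is iterated.
    tf.print("value:", filename)
    return filename


# Hypothetical filenames standing in for 'image/filename' values.
ds = tf.data.Dataset.from_tensor_slices([b"a.jpg", b"b.jpg"]).map(show)

# Iterating in eager mode yields concrete tensors; .numpy() gives bytes.
values = [t.numpy() for t in ds]
print(values)  # [b'a.jpg', b'b.jpg']
```

The "traced:" line appears once, during tracing, while "value:" appears once per element.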

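There's no print in the code you posted. If you put one inside parse_tfrecord, it only runs while dataset.map traces the function into a graph, so it sees a symbolic tensor with a shape and dtype but no value. I believe iterating through the dataset in eager mode should work instead: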

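def main():
    # parse_tfrecord returns (x_train, y_train) and drops the filename, so
    # iterate over the raw records and parse each one eagerly instead.
    raw_ds = tf.data.TFRecordDataset('../data/facemask2020_train.tfrecord')
    for raw in raw_ds:
        x = tf.io.parse_single_example(raw, IMAGE_FEATURE_MAP)
        # .numpy() returns the concrete value as a bytes object.
        print(x['image/filename'].numpy())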
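
To try this without the facemask files, here is a self-contained sketch that writes a tiny TFRecord with made-up filenames to a temporary directory and reads them back through a Dataset.map pipeline. The file name demo.tfrecord, the filenames, and the helpers write_demo_record, parse_filename and read_filenames are hypothetical; only the 'image/filename' key comes from the question.

```python
import os
import tempfile

import tensorflow as tf

# Only the 'image/filename' key is taken from the question's IMAGE_FEATURE_MAP.
FEATURES = {'image/filename': tf.io.FixedLenFeature([], tf.string)}


def write_demo_record(path, names):
    """Write one tf.train.Example per (made-up) filename."""
    with tf.io.TFRecordWriter(path) as writer:
        for name in names:
            feature = {'image/filename': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[name.encode()]))}
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())


def parse_filename(raw):
    # Traced by Dataset.map: the tensor here is symbolic, so no .numpy().
    x = tf.io.parse_single_example(raw, FEATURES)
    return x['image/filename']


def read_filenames(path):
    ds = tf.data.TFRecordDataset(path).map(parse_filename)
    # Iterating eagerly yields concrete tensors, so .numpy() works here.
    return [name.numpy().decode() for name in ds]


path = os.path.join(tempfile.mkdtemp(), 'demo.tfrecord')
write_demo_record(path, ['mask_001.jpg', 'mask_002.jpg'])
print(read_filenames(path))  # ['mask_001.jpg', 'mask_002.jpg']
```

Because the loop in read_filenames runs eagerly, .numpy() is available on each element; calling it inside parse_filename itself would fail, for the same reason the print in the question shows no value.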

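Statement: technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM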