How do I apply custom data augmentation as a preprocessing layer in TensorFlow?
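I am performing data augmentation on spectrogram images, masking time and frequency as part of a preprocessing layer in TensorFlow. I get the following error: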
'tensorflow.python.framework.ops.EagerTensor' object does not support item assignment
Here is the code I am using:
import random
import tensorflow as tf
from tensorflow.keras import layers
def random_mask_time(img):
    MAX_OCCURENCE = 5
    MAX_LENGTH = 10
    nums = random.randint(0, MAX_OCCURENCE)  # number of masks
    for n in range(nums):
        length = random.randint(0, MAX_LENGTH)  # number of columns to mask (up to MAX_LENGTH px in time)
        pos = random.randint(0, img.shape[0] - MAX_LENGTH)  # position to start masking
        img[:, pos:(pos + length), :] = 0  # fails: tensors do not support item assignment
    return img
def layer_random_mask_time():
    return layers.Lambda(lambda x: random_mask_time(x))
rnd_time = layer_random_mask_time()
data_augmentation = tf.keras.Sequential([
    rnd_time,
    rnd_freq,  # frequency-mask layer (definition not shown)
    layers.RandomCrop(input_shape[1], input_shape[0]),
])
I then use it as part of my Keras Sequential model.
I know tensors are immutable, but how can I mask out parts of the image?
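A minimal illustration of that immutability, using a small illustrative 3x4 tensor in place of a spectrogram: item assignment raises the same `TypeError`, and two standard ways around it are to build a new tensor with `tf.where` and a boolean mask, or to copy the data into a `tf.Variable`, which does support sliced assignment. This is a sketch only; inside a Keras layer the `tf.where` form is the more natural fit, since it does not create a new variable on every call.

```python
import tensorflow as tf

# A small 3x4 "image" standing in for one spectrogram channel (illustrative values).
img = tf.ones((3, 4))

# In-place item assignment is not supported on tf.Tensor / EagerTensor.
try:
    img[:, 1:3] = 0
except TypeError as err:
    print("TypeError:", err)

# Option 1: build a new tensor with tf.where and a boolean column mask.
cols = tf.range(tf.shape(img)[1])        # [0, 1, 2, 3]
col_mask = (cols >= 1) & (cols < 3)      # True for columns 1 and 2
masked = tf.where(col_mask[tf.newaxis, :], tf.zeros_like(img), img)

# Option 2: copy into a tf.Variable, which supports sliced assignment.
var = tf.Variable(img)
var[:, 1:3].assign(tf.zeros((3, 2)))

print(masked.numpy())
print(var.numpy())
```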
I used this as a reference: https://www.tensorflow.org/tutorials/images/data_augmentation#custom_data_augmentation
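Besides making the augmentation part of the model, the tutorial referenced above also applies preprocessing layers to the dataset with `Dataset.map`. The sketch below shows that second placement. The augmentation pipeline here is a stand-in built only from `RandomCrop` (an assumption, so the example runs without the question's `rnd_time` / `rnd_freq`), and the random toy dataset stands in for real spectrograms.

```python
import tensorflow as tf

# Stand-in augmentation pipeline (assumption): built-in layers only.
data_augmentation = tf.keras.Sequential([
    tf.keras.Input(shape=(128, 128, 3)),
    tf.keras.layers.RandomCrop(64, 64),
])

# Toy (image, label) dataset standing in for real spectrograms.
images = tf.random.uniform((8, 128, 128, 3))
labels = tf.range(8)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Apply the augmentation in the input pipeline instead of inside the model.
# training=True is passed explicitly because no fit() call sets it here.
aug_ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                num_parallel_calls=tf.data.AUTOTUNE)

for x, y in aug_ds.take(1):
    print(x.shape)
```

According to the tutorial, this placement runs the augmentation asynchronously on the CPU, whereas putting the layers inside the model runs them on the device together with the other layers.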
Try something like this, which rebuilds the tensor with tf.concat instead of assigning into it:
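import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
def random_mask_time(img):
    # img is a batch of images shaped (batch, height, width, channels);
    # time runs along axis 2 (width)
    MAX_OCCURENCE = 5
    MAX_LENGTH = 10
    # number of masks (maxval is exclusive, so 0 to MAX_OCCURENCE - 1)
    nums = tf.random.uniform((), minval=0, maxval=MAX_OCCURENCE, dtype=tf.int32)
    for n in tf.range(nums):
        # number of columns to mask (fewer than MAX_LENGTH px in time)
        length = tf.random.uniform((), minval=0, maxval=MAX_LENGTH, dtype=tf.int32)
        # column at which masking starts, drawn from the width (time) axis
        pos = tf.random.uniform((), minval=0, maxval=img.shape[2] - MAX_LENGTH, dtype=tf.int32)
        # tensors cannot be assigned into, so build a new one: the columns before
        # the mask, the masked columns multiplied by 0, and the columns after it
        img = tf.concat([img[:, :, :pos, :],
                         img[:, :, pos:(pos + length), :] * 0,
                         img[:, :, (pos + length):, :]], axis=2)
    return img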
# demo data: the TensorFlow flowers dataset stands in for spectrogram images
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
batch_size = 1
ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    seed=123,
    image_size=(128, 128),
    batch_size=batch_size)
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.Lambda(lambda x: random_mask_time(x)),
    tf.keras.layers.RandomCrop(128, 128),
])
image, _ = next(iter(ds.take(1)))
image = data_augmentation(image)
plt.imshow(image[0].numpy() / 255)
plt.show()  # needed to display the figure when running as a script
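The tf.concat version works when the layer is called eagerly, as in the demo. An alternative that avoids both the Python-level loop over a tensor and the shape-changing slices, which can be a problem once the layer is traced into a graph (for example inside `model.fit`), is to build a 0/1 mask along the time or frequency axis and multiply by it. The sketch below wraps that idea in a subclassed layer, following the subclassing pattern shown in the TensorFlow custom augmentation tutorial. `RandomMask` and its parameters are hypothetical names chosen to mirror `MAX_OCCURENCE` and `MAX_LENGTH`. It draws up to `max_occurrence` stripes of 0 to `max_length` rows or columns (inclusive, like `random.randint` in the question), applies the same mask to every image in the batch (as the tf.concat version does), and skips masking when `training` is false, which a `Lambda` layer does not do. It assumes the masked axis is longer than `max_length`.

```python
import tensorflow as tf


class RandomMask(tf.keras.layers.Layer):
    """Hypothetical layer: zeroes random stripes along one axis of a
    (batch, height, width, channels) batch. axis=2 masks time columns,
    axis=1 masks frequency rows."""

    def __init__(self, axis=2, max_occurrence=5, max_length=10, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.max_occurrence = max_occurrence
        self.max_length = max_length

    def call(self, images, training=None):
        if not training:
            return images  # leave images untouched at inference time
        n = self.max_occurrence
        size = tf.shape(images)[self.axis]
        # How many of the n candidate stripes are used (0..n inclusive).
        nums = tf.random.uniform((), 0, n + 1, dtype=tf.int32)
        # Length (0..max_length inclusive) and start of every candidate stripe;
        # starts stay below size - max_length so every stripe fits.
        lengths = tf.random.uniform((n,), 0, self.max_length + 1, dtype=tf.int32)
        starts = tf.random.uniform((n,), 0, size - self.max_length, dtype=tf.int32)
        active = tf.range(n) < nums                             # (n,)
        idx = tf.range(size)[tf.newaxis, :]                     # (1, size)
        in_stripe = (idx >= starts[:, tf.newaxis]) & (
            idx < (starts + lengths)[:, tf.newaxis])             # (n, size)
        in_stripe = in_stripe & active[:, tf.newaxis]
        keep = tf.logical_not(tf.reduce_any(in_stripe, axis=0))  # (size,)
        # Reshape the 1-D mask so it broadcasts along the chosen axis only.
        shape = [1, 1, 1, 1]
        shape[self.axis] = -1
        keep = tf.reshape(tf.cast(keep, images.dtype), shape)
        return images * keep


data_augmentation = tf.keras.Sequential([
    RandomMask(axis=2, name="time_mask"),
    RandomMask(axis=1, name="freq_mask"),
    tf.keras.layers.RandomCrop(128, 128),
])

batch = tf.ones((2, 128, 128, 3))
augmented = data_augmentation(batch, training=True)
print(augmented.shape)
```

Because the layer uses only tensor ops and its output shape equals its input shape, it should trace cleanly in graph mode. Masks that differ per image would need a per-example mask instead of the shared 1-D one.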