[英]Creating new dictionary with subset data from another dictionary
我有一个 python 字典,其中包含两个键image
和label
。 image
是图像像素的 numpy 数组, label
是图像对应的 label(0 到 9 之间的整数)。 我正在尝试创建一个新字典,其中仅包含原始数据中的某些数据,每个 label 中只有 50 个图像。
我的直觉说有一种简单的方法可以做到这一点,但我现在能想到的唯一方法就是做多个 for 循环。
创建原始字典:
import tensorflow_datasets as tfds
import jax.numpy as jnp
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()
train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
# Convert to floating-points
train_ds['image'] = jnp.float32(train_ds['image']) / 255.0
我想要的是一个新的train_ds2
字典,其中train_ds
中的 10 个标签中的每个标签只有 50 个图像
编辑:
我正在添加我的尝试(尽管我认为必须有更好的方法):
t_im = np.zeros((500,28,28,1))
t_lbl = np.zeros(500)
for k in tqdm(range(10)):
while i < 50:
for j in range(len(train_ds['label'])):
if train_ds['label'][j] == k:
t_im[k*50 + i,:,:,:] = train_ds['image'][j,:,:,:]
t_lbl[k*50 + i] = train_ds['label'][j]
i += 1
train_ds['image']
和train_ds['label']
都是jaxlib.xla_extension.DeviceArray
的实例。 您可以按如下方式对它们进行切片:
import random
import jax.numpy as jnp
idx = jnp.array(random.sample(range(len(train_ds['image'])), 50))
train_ds['image'][:50] # Get 50 items from start
train_ds['image'][idx] # Get arbitrary 50 items
所以要创建一个新字典,你只需要做类似的事情
train_df_part = {
'image': train_ds['image'][:50],
'label': train_ds['label'][:50]
}
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.