Why does torch.utils.save_image overwrite saved images in my folder?
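I'm running an adversarial attack on 10 images, and I need to save all of the perturbed images in one folder. To do that I use `torchvision.utils.save_image` in PyTorch, and saving itself works fine. I expect every image to end up in the folder, but instead each new file overwrites the previous one, so the last image processed is the only one saved. Here is my `attack()` function, which takes a single image to perturb: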
```python
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from scipy.optimize import differential_evolution

# predict_classes, attack_success, perturb_image and device are defined elsewhere in my script.

def attack(img, label, net, target=None, pixels=1, maxiter=75, popsize=400, verbose=False):
    # img: 1*3*W*H tensor
    # label: a number
    targeted_attack = target is not None
    target_calss = target if targeted_attack else label
    # one (x, y, R, G, B) block per perturbed pixel of a 32x32 image
    bounds = [(0,32), (0,32), (0,255), (0,255), (0,255)] * pixels
    popmul = max(1, popsize//len(bounds))
    predict_fn = lambda xs: predict_classes(
        xs, img, target_calss, net, target is None)
    callback_fn = lambda x, convergence: attack_success(
        x, img, target_calss, net, targeted_attack, verbose)
    inits = np.zeros([popmul*len(bounds), len(bounds)])
    count = 1  # local variable: re-created and reset to 1 on every call to attack()
    for init in inits:
        for i in range(pixels):
            init[i*5+0] = np.random.random()*32
            init[i*5+1] = np.random.random()*32
            init[i*5+2] = np.random.normal(128,127)
            init[i*5+3] = np.random.normal(128,127)
            init[i*5+4] = np.random.normal(128,127)
    attack_result = differential_evolution(predict_fn, bounds, maxiter=maxiter, popsize=popmul,
        recombination=1, atol=-1, callback=callback_fn, polish=False, init=inits)
    attack_image = perturb_image(attack_result.x, img)
    # attack_var = Variable(attack_image, volatile=True).cuda()
    with torch.no_grad():
        attack_var = attack_image.to(device)
        predicted_probs = F.softmax(net(attack_var), dim=1).data.cpu().numpy()[0]
    predicted_class = np.argmax(predicted_probs)
    vutils.save_image(vutils.make_grid(attack_image, normalize=True, scale_each=True), 'result_img/adversarial' + str(count) + '.png')
    vutils.save_image(vutils.make_grid(img, normalize=True, scale_each=True), 'result_img/original' + str(count) + '.png')
    count = count + 1
    if (not targeted_attack and predicted_class != label) or (targeted_attack and predicted_class == target_calss):
        return 1, attack_result.x.astype(int)
    return 0, [None]
```
Below is the `attack_all()` function, which perturbs every image the loader yields (the whole test set), which in my case is 10 images:
```python
# Uses torch, F, device, args and attack() from the same script.
def attack_all(net, loader, pixels=1, targeted=False, maxiter=75, popsize=400, verbose=False):
    correct = 0
    success = 0
    for batch_idx, (input, target) in enumerate(loader):
        # img_var = Variable(input, volatile=True).cuda()
        with torch.no_grad():
            img_var = input.to(device)
            target = target
            prior_probs = F.softmax(net(img_var), dim=1)
        _, indices = torch.max(prior_probs, 1)
        # skip images the network already misclassifies
        if target[0] != indices.data.cpu()[0]:
            continue
        correct += 1
        target = target.numpy()
        targets = [None] if not targeted else range(10)
        for target_calss in targets:
            if (targeted):
                if (target_calss == target[0]):
                    continue
            flag, x = attack(input, target[0], net, target_calss, pixels=pixels, maxiter=maxiter, popsize=popsize, verbose=verbose)
            success += flag
            if (targeted):
                success_rate = float(success)/(9*correct)
            else:
                success_rate = float(success)/correct
            if flag == 1:
                print("success rate: %.4f (%d/%d) [(x,y) = (%d,%d) and (R,G,B)=(%d,%d,%d)]"%(
                    success_rate, success, correct, x[0],x[1],x[2],x[3],x[4]))
        if correct == args.samples:
            break
    return success_rate
```
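Finally, here is `main()`, where I call `attack_all()` to attack the 10 images. I expect all 10 images (both the originals and the perturbed versions) to be saved, but only the last one processed is.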
```python
import os
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

# Cifar10Dataset is my own Dataset class; args comes from argparse.

def main():
    print("==> Loading data and model...")
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    # test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    test_set = Cifar10Dataset(csv_file='mydata/cifar10.csv', root_dir='mydata/cifar_selected_10', transform=transform_test)
    testloader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=True, num_workers=2)
    class_names = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/%s.t7'%args.model)
    net = checkpoint['net']
    net.cuda()
    cudnn.benchmark = True
    print("==> Starting attack...")
    results = attack_all(net, testloader, pixels=args.pixels, targeted=args.targeted, maxiter=args.maxiter, popsize=args.popsize, verbose=args.verbose)
    print("Final success rate: %.4f"%results)
```
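The overwrite can be reproduced with just the filename logic from `attack()`. This is a minimal sketch in plain Python: `fake_save` and `saved_files` are hypothetical stand-ins for `vutils.save_image` and the files it would write, so no PyTorch is needed to run it.

```python
# Hypothetical stand-in for vutils.save_image(): records the path it would write to.
saved_files = []


def fake_save(path):
    saved_files.append(path)


def attack(img):
    count = 1  # local variable: re-created and reset to 1 on every call
    fake_save('result_img/adversarial' + str(count) + '.png')
    fake_save('result_img/original' + str(count) + '.png')
    count = count + 1  # this increment is discarded when the function returns


for img in range(10):  # stands in for the 10 images attack_all() visits
    attack(img)

print(len(saved_files))          # number of save calls
print(sorted(set(saved_files)))  # distinct filenames actually written
```

All 20 save calls target the same two paths, `result_img/adversarial1.png` and `result_img/original1.png`, so each call overwrites the files from the previous image.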
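So I figured out how to solve it myself.

I noticed that the variable `count` in `attack()` never actually advances from one image to the next: it is a local variable, so it is reset to 1 every time `attack()` is called, and every call writes to `result_img/adversarial1.png` and `result_img/original1.png`. Instead, I set `count = 1` outside `attack()` and declared `global count` inside `attack()`. That way, each time `attack_all()` calls `attack()`, `count` keeps the value it had after the previous call, and the filenames no longer collide.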
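A minimal sketch of that fix, again in plain Python with a hypothetical `fake_save` stub standing in for `vutils.save_image`: the counter is created once at module level, and `global count` makes `attack()` update that shared name instead of creating a fresh local one.

```python
# Hypothetical stand-in for vutils.save_image(): records the path it would write to.
saved_files = []


def fake_save(path):
    saved_files.append(path)


count = 1  # initialised once, outside attack()


def attack(img):
    global count  # rebind the module-level name instead of creating a local one
    fake_save('result_img/adversarial' + str(count) + '.png')
    fake_save('result_img/original' + str(count) + '.png')
    count = count + 1  # the increment now persists across calls


for img in range(10):  # stands in for the 10 images attack_all() visits
    attack(img)

print(len(set(saved_files)))  # distinct filenames: adversarial1..10 and original1..10
```

A module-level counter is the smallest change to the original code. An alternative that avoids global state is to pass a unique index into `attack()` from `attack_all()` (for example one derived from `batch_idx`); in targeted mode `attack()` runs once per target class for the same image, so such an index would also need to include `target_calss` to keep the filenames distinct.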