
How to call another loss function in a GAN discriminator?

I am interested in GANs. I tried to adjust the DCGAN discriminator using the method at https://github.com/vasily789/adaptive-weighted-gans/blob/main/aw_loss.py , which is called the aw method. I found a DCGAN implementation on Kaggle ( https://www.kaggle.com/vatsalmavani/deep-convolutional-gan-in-pytorch ) and tried to edit the discriminator so that it calls aw_loss.

Here is my code:

https://colab.research.google.com/drive/1AsZztd0Af0UMzBXXkI9QKQZhAUoK01bk?usp=sharing

It seems that I cannot call the aw loss correctly, because the discriminator's loss is still 0 during training. Can anyone help me, please?

The code you provided does raise the expected error when you try to use aw_method() directly. You need to instantiate the class first, as shown below; after that you should be able to call the method.

aw_instance = aw_method()
aw_loss = aw_instance.aw_loss(D_real_loss, D_fake_loss, D_opt, D)

Note that this uses the default parameters for the class. I am not familiar enough with the aw loss to tell you whether you should tweak them.
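To illustrate why the class has to be instantiated first, here is a minimal sketch. AwMethodStandIn is a hypothetical stand-in, not the real aw_method class: its constructor argument and its fixed weighted sum are assumptions made only so the example runs without PyTorch. The exact error in your notebook may differ.

```python
class AwMethodStandIn:
    """Hypothetical stand-in for aw_method; NOT the real implementation."""

    def __init__(self, alpha=0.5):
        # `alpha` is an assumed parameter, used only for this illustration.
        self.alpha = alpha

    def aw_loss(self, real_loss, fake_loss):
        # Placeholder: a fixed weighted sum, not the adaptive weighting
        # performed by the real aw loss.
        return self.alpha * real_loss + (1 - self.alpha) * fake_loss


# Calling the method through the class binds 0.7 to `self`,
# so one positional argument is missing and Python raises TypeError.
try:
    AwMethodStandIn.aw_loss(0.7, 0.3)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Instantiating first lets Python pass the instance as `self`.
aw_instance = AwMethodStandIn()
combined = aw_instance.aw_loss(0.7, 0.3)

print(error_message)
print(combined)
```

The same pattern applies to the real class: create one instance (probably once, outside the training loop) and call aw_loss on that instance.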

Regarding your discriminator's loss: correct code relies on aw_cost to work. It doesn't seem like you're providing both the real and the fake losses, so the discriminator is only learning to output 1s or 0s (which can easily be verified by printing those values or monitoring them with wandb or a similar tool). Again, I didn't go deep enough into the aw loss algorithm, so check this specifically.
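One way to check whether the discriminator has collapsed to outputting only 0s or 1s is sketched below. The helper is_saturated and its tolerance are hypothetical names chosen for this illustration. It assumes the discriminator ends in a sigmoid and that its scores have been copied into a plain Python list (in PyTorch, for example, with .detach().flatten().tolist()).

```python
from statistics import mean


def is_saturated(scores, tol=1e-3):
    """Return True if every score lies within `tol` of 0 or 1."""
    return all(s < tol or s > 1 - tol for s in scores)


# Made-up example values standing in for discriminator outputs.
real_scores = [0.9999, 1.0, 0.99995]
fake_scores = [0.0, 0.00001, 0.0002]
healthy_scores = [0.4, 0.9, 0.65]

print("mean D(real):", mean(real_scores))
print("mean D(fake):", mean(fake_scores))
print("real saturated:", is_saturated(real_scores))
print("fake saturated:", is_saturated(fake_scores))
print("healthy saturated:", is_saturated(healthy_scores))
```

If both batches come back saturated early in training, that is consistent with a discriminator loss stuck at 0.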

You could also try a linear combination with your normal losses: D_loss = (D_fake_loss + D_real_loss + aw_loss) / 3 .
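The averaged loss can be written as a small helper. This is only a sketch: combined_discriminator_loss is a hypothetical name, and the numbers are made up. Because it uses only + and /, the same function should work unchanged if the three arguments are PyTorch scalar tensors instead of floats.

```python
def combined_discriminator_loss(d_real_loss, d_fake_loss, aw_loss):
    """Equal-weight average of the two standard losses and the aw loss."""
    return (d_fake_loss + d_real_loss + aw_loss) / 3


# Made-up loss values for illustration.
d_loss = combined_discriminator_loss(0.6, 0.9, 0.3)
print(d_loss)
```

With tensors, you would call backward() on the returned value as usual.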


 