
Deeplearning4j neural network only predicting 1 class

For the past week or so, I have been trying to get a neural network to work with RGB images, but no matter what I do it only ever predicts one class. I have read every link I could find on people hitting this problem and experimented with a lot of different things, but it always ends up predicting only one of the two output classes. I have checked the batches going into the model, increased the size of the dataset, increased the original image size (28*28) to 56*56, increased the number of epochs, done a lot of model tuning, and even tried both a simple non-convolutional neural network and a dumbed-down version of my own CNN model, yet nothing changes.

I have also checked the structure of how the data is passed into the training set (specifically ImageRecordReader). This input structure (in terms of folder layout and how the data is fed to the training set) works perfectly when given grayscale images; it was originally built for the MNIST dataset and reached 99% accuracy there.

Some context: I use the folder names as my labels, i.e. folders 0 and 1 for both training and testing data, as there will only be two output classes. The training set contains 320 images of class 0 and 240 images of class 1, whereas the testing set is made up of 79 and 80 images respectively.
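
For reference, this is the on-disk layout that ParentPathLabelGenerator reads the labels from (image counts as above):

ISIC-Images/
├── Training/
│   ├── 0/   (320 images of class 0)
│   └── 1/   (240 images of class 1)
└── Testing/
    ├── 0/   (79 images)
    └── 1/   (80 images)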

Code below:

private static final Logger log = LoggerFactory.getLogger(MnistClassifier.class);
private static final String basePath = System.getProperty("java.io.tmpdir") + "/ISIC-Images";

public static void main(String[] args) throws Exception {
    int height = 56;
    int width = 56;
    int channels = 3; // RGB Images
    int outputNum = 2; // two output classes
    int batchSize = 1;
    int nEpochs = 1;
    int iterations = 1;
    int seed = 1234;
    Random randNumGen = new Random(seed);

    // vectorization of training data
    File trainData = new File(basePath + "/Training");
    FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // parent path as the image label
    ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
    trainRR.initialize(trainSplit);
    DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);

    // vectorization of testing data
    File testData = new File(basePath + "/Testing");
    FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
    testRR.initialize(testSplit);
    DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);

    log.info("Network configuration and training...");
    Map<Integer, Double> lrSchedule = new HashMap<>();
    lrSchedule.put(0, 0.06); // iteration #, learning rate
    lrSchedule.put(200, 0.05);
    lrSchedule.put(600, 0.028);
    lrSchedule.put(800, 0.0060);
    lrSchedule.put(1000, 0.001);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .l2(0.0008)
        .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, lrSchedule)))
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .weightInit(WeightInit.XAVIER)
        .list()
        .layer(0, new ConvolutionLayer.Builder(5, 5)
            .nIn(channels)
            .stride(1, 1)
            .nOut(20)
            .activation(Activation.IDENTITY)
            .build())
        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(2, new ConvolutionLayer.Builder(5, 5)
            .stride(1, 1)
            .nOut(50)
            .activation(Activation.IDENTITY)
            .build())
        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
            .nOut(500).build())
        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS)
            .nOut(outputNum)
            .activation(Activation.SOFTMAX)
            .build())
        .setInputType(InputType.convolutionalFlat(56, 56, 3)) // InputType.convolutional for normal image
        .backprop(true).pretrain(false).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(10));
    log.debug("Total num of params: {}", net.numParams());

    // evaluation while training (the score should go down)
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainIter);
        log.info("Completed epoch {}", i);
        Evaluation eval = net.evaluate(testIter);
        log.info(eval.stats());
        trainIter.reset();
        testIter.reset();
    }
    ModelSerializer.writeModel(net, new File(basePath + "/Isic.model.zip"), true);
}

Output from running the model:

[Screenshots omitted: odd iteration scores and evaluation metrics]

Any insight would be much appreciated.

I would suggest changing the activation functions in the two convolution layers to a non-linear function. You could try the ReLU and tanh functions; refer to the documentation for a list of available activation functions.
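
For example, the first convolution layer from your configuration would become something like this (a sketch reusing your variable names; everything else stays unchanged):

.layer(0, new ConvolutionLayer.Builder(5, 5)
    .nIn(channels)
    .stride(1, 1)
    .nOut(20)
    .activation(Activation.RELU) // non-linear activation instead of IDENTITY
    .build())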

Identity activations on CNNs almost never make sense. Stick to ReLU if you can. I would instead shift your efforts towards gradient normalization or interspersing dropout layers. Almost every time a CNN doesn't learn, it's due to a lack of regularization.
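
In DL4J terms, those two suggestions could look like the following fragments dropped into your existing builder (the dropout value is illustrative, not tuned):

// on the NeuralNetConfiguration.Builder, before .list():
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)

// on the dense layer (layer 4); note that in DL4J the dropOut value
// is the probability of retaining a unit, not of dropping it:
.layer(4, new DenseLayer.Builder()
    .activation(Activation.RELU)
    .nOut(500)
    .dropOut(0.5)
    .build())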

Also: never use squared loss with softmax. It never works; I've never seen squared loss used with softmax in practice. Stick to negative log likelihood.
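
Concretely, the output layer from the question would change to:

.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
    .nOut(outputNum)
    .activation(Activation.SOFTMAX)
    .build())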

You can try L1 and L2 regularization (or both together, which is called elastic net regularization).
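
On the NeuralNetConfiguration.Builder that looks like this (the L1 coefficient is illustrative):

.l1(1e-4)   // L1 penalty on the weights; combining L1 and L2 gives elastic net
.l2(0.0008) // L2 penalty, as already set in the question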

It seems that switching to the Adam updater gave some promising results, as did increasing the batch size (I now have thousands of images); otherwise the net requires an absurd number of epochs (at least 50+) to begin learning.
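
The changes amount to something like this (the learning rate and exact batch size here are illustrative, not the precise values I settled on):

int batchSize = 64; // instead of 1

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .updater(new Adam(1e-3)) // replaces the Nesterovs updater and its MapSchedule
    // ... rest of the configuration as before
    .build();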

Thank you for all the responses regardless.
