
Deeplearning4j Splitting datasets for test and train

Deeplearning4j has functions to support splitting datasets into test and train, as well as mechanisms for shuffling datasets. However, as far as I can tell, either they don't work or I'm doing something wrong.

Example:

    long seed = 42; // fixed seed for reproducibility
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    DataSet next = iter.next();
    // next.shuffle();
    SplitTestAndTrain testAndTrain = next.splitTestAndTrain(120, new Random(seed));
    DataSet train = testAndTrain.getTrain();
    DataSet test = testAndTrain.getTest();

    for (int i = 0; i < 30; i++) {
        String features = test.getFeatures().getRow(i).toString();
        String actual = test.getLabels().getRow(i).toString().trim();
        log.info("features " + features + " -> " + actual );
    }

This returns the last 30 rows of the input dataset as the test set; the Random(seed) parameter to splitTestAndTrain seems to be ignored completely.

If, instead of passing the random seed to splitTestAndTrain, I uncomment the next.shuffle() line, then oddly the 3rd and 4th features get shuffled while the existing order of the 1st and 2nd features and of the labels is preserved, which is even worse than not shuffling the input at all.

So... the question is, am I using it wrong, or is Deeplearning4j just inherently broken?

Bonus question: if Deeplearning4j is broken for something as simple as splitting data into test and train sets, should it be trusted with anything at all? Or would I be better off using a different library?

Deeplearning4j assumes that datasets are minibatches, i.e. they are not all in memory. This contrasts with the Python world, which tends to optimize more for smaller datasets and ease of use.

That only works for toy problems and does not scale to real ones. Instead, we optimize for the DataSetIterator interface for local scenarios (note that this is different for distributed systems like Spark).

This means we rely on datasets either being split beforehand, using DataVec to parse the data (hint: do not write your own iterator; use ours, and use DataVec for custom parsing), or on a DataSetIteratorSplitter for the train/test split: https://deeplearning4j.org/doc/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.html
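
A minimal sketch of the splitter route, assuming the (baseIterator, totalBatches, ratio) constructor described in the linked Javadoc and the Iris iterator from the question:

import org.deeplearning4j.datasets.iterator.DataSetIteratorSplitter;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// Stream the 150 Iris examples in minibatches of 10, i.e. 15 batches total.
DataSetIterator base = new IrisDataSetIterator(10, 150);

// Route 80% of the batches to train and the rest to test.
DataSetIteratorSplitter splitter = new DataSetIteratorSplitter(base, 15, 0.8);
DataSetIterator trainIter = splitter.getTrainIterator();
DataSetIterator testIter  = splitter.getTestIterator();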

The SplitTestAndTrain class will only work if the dataset is already entirely in memory, which may not make sense for most semi-realistic problems (e.g. anything beyond XOR or MNIST).

I recommend running your ETL step once rather than on every run: pre-shuffle your dataset into pre-sliced batches. One way to do this is with a combination of https://github.com/deeplearning4j/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/BalanceMinibatchesTest.java#L40 and https://nd4j.org/doc/org/nd4j/linalg/dataset/ExistingMiniBatchDataSetIterator.html, as sketched below.
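
A rough sketch of that pre-save approach, assuming the Iris data from the question; the directory name and the "%d" file-name pattern here are illustrative choices, not a DL4J convention:

import java.io.File;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.ExistingMiniBatchDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// One-time ETL: load, shuffle once, slice into minibatches, persist to disk.
// Because the shuffled batches are saved, later training runs are reproducible.
DataSet all = new IrisDataSetIterator(150, 150).next();
all.shuffle();

File dir = new File("iris-batches"); // illustrative path
dir.mkdirs();
int i = 0;
for (DataSet batch : all.batchBy(10)) {
    batch.save(new File(dir, "iris-batch-" + i++ + ".bin"));
}

// Training runs then stream the pre-sliced batches straight from disk.
DataSetIterator trainIter = new ExistingMiniBatchDataSetIterator(dir, "iris-batch-%d.bin");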

Another reason to do this is reproducibility. If you want to do something like shuffle your iterator each epoch, you could write some code based on a combination of the above, along the lines of the sketch below. Either way, I would handle the ETL and pre-create the vectors before you do any training; otherwise you will spend a lot of time on data loading with larger datasets.
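
A hedged sketch of per-epoch shuffling on top of pre-sliced batches; ListDataSetIterator is DL4J's in-memory iterator, while 'all' carries over from the sketch above and 'model' stands for an already-configured network (both assumptions, not shown here):

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

long baseSeed = 42;
List<DataSet> batches = new ArrayList<>(all.batchBy(10)); // 'all' from the sketch above
for (int epoch = 0; epoch < 10; epoch++) {
    // Reshuffle the batch order with a per-epoch seed: each epoch sees a
    // different order, yet every run of the program is reproducible.
    Collections.shuffle(batches, new Random(baseSeed + epoch));
    DataSetIterator epochIter = new ListDataSetIterator<>(batches);
    model.fit(epochIter); // 'model' is an assumed MultiLayerNetwork
}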

As far as I can tell, Deeplearning4j is simply broken. Ultimately I created my own implementation of splitTestAndTrain.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import java.util.Random;
import org.nd4j.linalg.factory.Nd4j;

public class TestTrain {  
    protected DataSet test;
    protected DataSet train;

    public TestTrain(DataSet input, int splitSize, Random rng) {
        int inTest = 0;
        int inTrain = 0;
        int testSize = input.numExamples() - splitSize;

        INDArray train_features = Nd4j.create(splitSize, input.getFeatures().columns());
        INDArray train_outcomes = Nd4j.create(splitSize, input.numOutcomes());
        INDArray test_features  = Nd4j.create(testSize, input.getFeatures().columns());
        INDArray test_outcomes  = Nd4j.create(testSize, input.numOutcomes());

        // Selection sampling (Knuth's Algorithm S): take each example into the
        // train set with probability (remaining train slots) / (remaining examples).
        // This selects exactly splitSize examples uniformly at random,
        // preserving their original relative order.
        for (int i = 0; i < input.numExamples(); i++) {
            DataSet D = input.get(i);
            if (rng.nextDouble() < (splitSize - inTrain) / (double) (input.numExamples() - i)) {
                train_features.putRow(inTrain, D.getFeatures());
                train_outcomes.putRow(inTrain, D.getLabels());
                inTrain += 1;
            } else {
                test_features.putRow(inTest, D.getFeatures());
                test_outcomes.putRow(inTest, D.getLabels());
                inTest += 1;
            }
        }

        train = new DataSet(train_features, train_outcomes);
        test  = new DataSet(test_features, test_outcomes);
    }

    public DataSet getTrain() {
        return train;
    }

    public DataSet getTest() {
        return test;
    }
}
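
Usage with the Iris setup from the question looks like this; the selection sampling above guarantees that exactly splitSize examples end up in the train set:

DataSet next = new IrisDataSetIterator(150, 150).next();
TestTrain split = new TestTrain(next, 120, new Random(42));
DataSet train = split.getTrain(); // exactly 120 randomly selected examples
DataSet test  = split.getTest();  // the remaining 30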

This works, but it does not give me confidence in the library. I would still be happy if someone can provide a better answer, but for now this will have to do.

As this question is outdated, for people who might find this: you can see some examples on GitHub, and the split can be done in a simple way:

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
allData.shuffle();
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

DataSet trainingData = testAndTrain.getTrain();
DataSet testData = testAndTrain.getTest();

Where first you create the iterator, iterate over all the data, shuffle it, and then split between test and train.

This is taken from this example.
