Why is a DeepLearning4J CNN returning only 0s and 1s instead of probabilities in the INDArray output?
I am playing with DL4J version 1.0.0-beta3 and trying to create a convolutional neural network for recognizing 32x32 images of chess pieces. Here is the code I use to create and train the net:
public class BuildNetwork1 {
    public static void main(String[] args) throws Exception {
        File rootDir = new File("./CNNinput/chesscom1");
        File locationToSave = new File(rootDir, "trained.chesscom1.bin");
        int height = 32;
        int width = 32;
        int channels = 1;
        int rngseed = 777;
        int numEpochs = 100;
        File trainData = new File(rootDir, "training");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngseed)
                .updater(new Adam.Builder().learningRate(0.01).build())
                .activation(Activation.IDENTITY)
                .weightInit(WeightInit.XAVIER)
                .list()
                //.layer(new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {1, 1}, new int[]{0, 0}).name("cnn1").nIn(1).nOut(64).biasInit(0).build())
                //.layer(new SubsamplingLayer.Builder(new int[] {2, 2}, new int[] {2, 2}).name("maxpool1").build())
                //.layer(new ConvolutionLayer.Builder(new int[] {5, 5}, new int[] {1, 1}, new int[]{0, 0}).name("cnn2").nIn(64).nOut(16).biasInit(0).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(13)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutional(32, 32, 1))
                .build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setListeners(new ScoreIterationListener(10));
        ImageLoader loader = new ImageLoader(height, width, channels);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        for (int e = 0; e < numEpochs; e++) {
            File[] labels = trainData.listFiles();
            for (int i = 0; i < labels.length; i++) {
                File label = labels[i];
                File[] images = label.listFiles();
                for (int j = 0; j < images.length; j++) {
                    File imageFile = images[j];
                    BufferedImage image = ImageIO.read(imageFile);
                    INDArray input = loader.asMatrix(image).reshape(1, channels, height, width);
                    scaler.fit(new DataSet(input, null));
                    scaler.transform(input);
                    double[][] outputArray = new double[1][13];
                    outputArray[0][Integer.parseInt(label.getName())] = 1d;
                    INDArray output = Nd4j.create(outputArray);
                    network.fit(input, output);
                }
            }
        }
        boolean saveUpdater = true;
        ModelSerializer.writeModel(network, locationToSave, saveUpdater);
    }
}
And here is the code I use to get the result:
public class CalcNetworkAll {
    public static void main(String[] args) throws Exception {
        int height = 32;
        int width = 32;
        int channels = 1;
        File rootDir = new File("./CNNinput/chesscom1");
        File locationToLoad = new File("./CNNinput/chesscom1/trained.chesscom1.bin");
        File testData = new File(rootDir, "testing");
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(locationToLoad);
        ImageLoader loader = new ImageLoader(height, width, channels);
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        File[] labels = testData.listFiles();
        for (int i = 0; i < labels.length; i++) {
            File label = labels[i];
            File[] images = label.listFiles();
            for (int j = 0; j < images.length; j++) {
                File imageFile = images[j];
                BufferedImage image = ImageIO.read(imageFile);
                INDArray input = loader.asMatrix(image).reshape(1, channels, height, width);
                scaler.fit(new DataSet(input, null));
                scaler.transform(input);
                INDArray output = network.output(input, false);
                System.out.println(label.getName() + " => " + output);
            }
        }
    }
}
It works well and provides the expected outcome, but my problem is that the output consists of only 0s and 1s instead of probabilities:
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000,8.1707e-37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
0 => [[ 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
1 => [[ 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
10 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
11 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
12 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
2 => [[ 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
3 => [[ 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
4 => [[ 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
5 => [[ 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
6 => [[ 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
7 => [[ 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
8 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
9 => [[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0000, 0, 0, 0]]
Do you have any idea why this happens? Thanks a lot in advance!
Your model is very confident in its output. This can happen when you show it data that it may have seen before and you have trained the model to fit that data very closely (often called overfitting). The values are still softmax probabilities; they are saturated, so the winning class rounds to 1 and the others shrink toward 0, as the single 8.1707e-37 entry in your output suggests.
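To illustrate how a softmax layer can print what look like hard 0s and 1s while still computing probabilities, here is a minimal plain-Java sketch. It does not use DL4J and is not the network from the question: the class name `SoftmaxSaturation` and the logit values are made up for illustration, and single-precision floats are used on the assumption that the printed output above comes from a float network. When the gap between the largest pre-softmax value and the rest grows, the smaller exponentials underflow and the result looks one-hot.

```java
import java.util.Arrays;

// Hypothetical demo class (not part of DL4J): shows softmax saturation in float precision.
class SoftmaxSaturation {

    // Numerically stable softmax: subtract the largest logit before exponentiating.
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        float[] out = new float[logits.length];
        float sum = 0f;
        for (int i = 0; i < logits.length; i++) {
            out[i] = (float) Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // Small gaps between logits: every class keeps a visible share of the probability.
        float[] moderate = softmax(new float[] {2f, 1f, 0f});
        // Large gaps (illustrative values): the top class rounds to 1, the rest underflow.
        float[] saturated = softmax(new float[] {200f, 100f, 0f});

        System.out.println("moderate  => " + Arrays.toString(moderate));
        System.out.println("saturated => " + Arrays.toString(saturated));
    }
}
```

With the moderate logits the three probabilities come out at roughly 0.665, 0.245 and 0.090. With the large logits the first entry is exactly 1.0 in float precision and the last is exactly 0, which is the same pattern as in the question's output. If you want to check whether this confidence reflects overfitting, evaluating on images the network has never seen during training is the usual test.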