DL4J LSTM - Contradictory Errors

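I'm trying to create a simple LSTM in Java using Deeplearning4J, with 2 input features and a time series length of 1. However, when I call predict() I get an error about the number of input dimensions.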

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LSTMRegression {
    public static final int inputSize = 2,
                            lstmLayerSize = 4,
                            outputSize = 1;
    
    public static final double learningRate = 0.0001;

    public static void main(String[] args) {
        int miniBatchSize = 29;
        
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Adam(learningRate))
                .list()
                .layer(0, new LSTM.Builder().nIn(inputSize).nOut(lstmLayerSize)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.IDENTITY).build())
                .layer(1, new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SIGMOID).build())
                .layer(2, new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SIGMOID).build())
                .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.IDENTITY)
                        .nIn(lstmLayerSize).nOut(outputSize).build())
                
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(miniBatchSize)
                .tBPTTBackwardLength(miniBatchSize)
                .build();
        
        var network = new MultiLayerNetwork(conf);
        
        network.init();
        network.fit(getTrain());
        
        System.out.println(network.predict(getTest()));
    }
    
    public static DataSet getTest() {
        INDArray input = Nd4j.zeros(29, 2, 1);

        INDArray labels = Nd4j.zeros(29, 1);
        
        return new DataSet(input, labels);
    }
    
    public static DataSet getTrain() {
        INDArray input = Nd4j.zeros(29, 2, 1);
        INDArray labels = Nd4j.zeros(29, 1);
        
        return new DataSet(input, labels);
    }
}

Running it produces the following error:

22:38:28.803 [main] INFO  o.d.nn.multilayer.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
22:38:29.755 [main] WARN  o.d.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [29, 2, 1] and labels with shape [29, 1]
Exception in thread "main" java.lang.IllegalStateException: predict(INDArray) method can only be used on rank 2 output - got array with rank 3
    at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:639)
    at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:274)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.predict(MultiLayerNetwork.java:2275)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.predict(MultiLayerNetwork.java:2286)
    at LSTMRegression.main(LSTMRegression.java:78)

I found this strange, but I tried reshaping it anyway:

    public static DataSet getTest() {
        INDArray input = Nd4j.zeros(29, 2, 1).reshape(29, 2);

        INDArray labels = Nd4j.zeros(29, 1);
        
        return new DataSet(input, labels);
    }

...which leads to the opposite problem:

22:45:28.232 [main] WARN  o.d.nn.multilayer.MultiLayerNetwork - Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got [29, 2, 1] and labels with shape [29, 1]
Exception in thread "main" java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2
    at org.nd4j.common.base.Preconditions.throwStateEx(Preconditions.java:639)
    at org.nd4j.common.base.Preconditions.checkState(Preconditions.java:265)
    at org.deeplearning4j.nn.layers.recurrent.LSTM.activateHelper(LSTM.java:121)
    at org.deeplearning4j.nn.layers.recurrent.LSTM.activate(LSTM.java:110)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.outputOfLayerDetached(MultiLayerNetwork.java:1349)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2467)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2430)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2421)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:2408)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.predict(MultiLayerNetwork.java:2270)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.predict(MultiLayerNetwork.java:2286)
    at LSTMRegression.main(LSTMRegression.java:78)

What am I doing wrong?

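Edit: I obviously used zeros to make the code easier to read. Here is what my training and test data actually look like, as multidimensional double arrays: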

public static DataSet getData() {
        double[][][] inputArray = {
            {{18.7}, {181}},
            {{17.4}, {186}},
            {{18}, {195}},
            {{19.3}, {193}},
            {{20.6}, {190}},
            {{17.8}, {181}},
            {{19.6}, {195}},
            {{18.1}, {193}},
            {{20.2}, {190}},
            {{17.1}, {186}},
            {{17.3}, {180}},
            ...
        };
        double[][] outputArray = {
                {3750},
                {3800},
                {3250},
                {3450},
                {3650},
                {3625},
                {4675},
                {3475},
                {4250},
                {3300},
                {3700},
                {3200},
                {3800},
                {4400},
                {3700},
                {3450},
                {4500},
                ...
        };
        INDArray input = Nd4j.create(inputArray);
        INDArray labels = Nd4j.create(outputArray);
        
        return new DataSet(input, labels);
}
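In both versions above the labels are rank 2 ([N, 1]) while the inputs are rank 3 ([N, 2, 1]). RnnOutputLayer works with labels shaped [miniBatchSize, nOut, timeSeriesLength], and that mismatch is what the truncated-BPTT warning in the log points at. A minimal sketch, assuming you want to keep building the data from plain double arrays: wrap every label row into a one-step sequence before calling Nd4j.create, so the labels come out as [N, 1, 1]. With ND4J you could equally reshape after creation, for example Nd4j.create(outputArray).reshape(outputArray.length, 1, 1). The names RnnShapes and toSingleStepSequence are hypothetical, not part of DL4J.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DL4J: turns per-example rows into one-step sequences.
final class RnnShapes {
    private RnnShapes() {}

    /**
     * Converts a [miniBatch][features] array into [miniBatch][features][1],
     * so every example becomes a time series of length 1. Passing the result to
     * Nd4j.create(double[][][]) gives an INDArray of shape [miniBatch, features, 1].
     */
    static double[][][] toSingleStepSequence(double[][] rows) {
        double[][][] seq = new double[rows.length][][];
        for (int i = 0; i < rows.length; i++) {
            seq[i] = new double[rows[i].length][1];
            for (int f = 0; f < rows[i].length; f++) {
                seq[i][f][0] = rows[i][f];
            }
        }
        return seq;
    }

    public static void main(String[] args) {
        // First three label values from the question's outputArray.
        double[][] outputArray = {{3750}, {3800}, {3250}};
        double[][][] labels3d = toSingleStepSequence(outputArray);
        // Shape is now [3][1][1], matching [miniBatchSize, nOut, timeSeriesLength].
        System.out.println(labels3d.length + " x " + labels3d[0].length + " x " + labels3d[0][0].length);
        System.out.println(Arrays.deepToString(labels3d));
    }
}
```

Running main prints "3 x 1 x 1" followed by [[[3750.0]], [[3800.0]], [[3250.0]]]. The same helper works for the inputs if they start out as [N][2] rows.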

...and my test data (updated to contain only the inputs):

public static INDArray getTest() {
        double[][][] test = new double[][][]{
            {{20}, {203}},
            {{16}, {183}},
            {{20}, {190}},
            {{18.6}, {193}},
            {{18.9}, {184}},
            {{17.2}, {199}},
            {{20}, {190}},
            {{17}, {181}},
            {{19}, {197}},
            {{16.5}, {198}},
            {{20.3}, {191}},
            {{17.7}, {193}},
            {{19.5}, {197}},
            {{20.7}, {191}},
            {{18.3}, {196}},
            {{17}, {188}},
            {{20.5}, {199}},
            {{17}, {189}},
            {{18.6}, {189}},
            {{17.2}, {187}},
            {{19.8}, {198}},
            {{17}, {176}},
            {{18.5}, {202}},
            {{15.9}, {186}},
            {{19}, {199}},
            {{17.6}, {191}},
            {{18.3}, {195}},
            {{17.1}, {191}},
            {{18}, {210}}
        };
        
        INDArray input = Nd4j.create(test);
        
        return input;
}

You have several problems here. If you read the documentation for predict, it says:

Usable only for classification networks in conjunction with OutputLayer. Cannot be used with RnnOutputLayer, CnnLossLayer, or networks used for regression.

That is why the error message tells you it only works on rank 2 output; your RnnOutputLayer produces rank 3 output.

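In your attempted fix you reshaped the input to rank 2, and the network then complained that it did not get the input it expected: LSTM layers require rank 3 input of shape [miniBatchSize, nIn, timeSeriesLength], even when the time series length is 1.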

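You want to use rnnTimeStep (to step through one time step at a time) or output (for the whole sequence) to get the raw output, and then apply argMax accordingly if you need class indices. For a regression network like yours (MSE loss, IDENTITY output activation) the raw output values are already the predictions, so no argMax is needed.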
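For this regression network, output() on the test input returns a rank 3 array of shape [miniBatchSize, nOut, timeSeriesLength]. Since every sequence here has length 1, something like out.reshape(out.size(0), out.size(1)) would turn it into [miniBatchSize, nOut]; for longer sequences you would usually pick the last time step instead. The plain-Java sketch below spells out that last-step slicing on a double[][][] so the indexing is explicit; LastStep and lastTimeStep are hypothetical names, and the numbers in main are made up for illustration.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DL4J: slices the final time step out of RNN output.
final class LastStep {
    private LastStep() {}

    /**
     * Takes a [miniBatch][nOut][timeSteps] array (timeSteps >= 1) and returns
     * the [miniBatch][nOut] values at the final time step.
     */
    static double[][] lastTimeStep(double[][][] out) {
        double[][] result = new double[out.length][];
        for (int i = 0; i < out.length; i++) {
            result[i] = new double[out[i].length];
            for (int o = 0; o < out[i].length; o++) {
                double[] series = out[i][o];
                result[i][o] = series[series.length - 1];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // Two examples, one output unit, two time steps each (made-up values).
        double[][][] out = {{{1.0, 2.0}}, {{3.0, 4.0}}};
        System.out.println(Arrays.deepToString(lastTimeStep(out)));
    }
}
```

Running main prints [[2.0], [4.0]]: one regression value per example, taken from the last step.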

The output of rnnTimeStep() is only one step of the full output, so to get the same result as predict you should be able to call output.argMax(1).toIntVector() on it.
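As a sketch of what output.argMax(1).toIntVector() computes, assuming a classification network (the regression network in the question has no classes to pick from): when rnnTimeStep() is given a 2D [miniBatchSize, nIn] array for a single step it returns a 2D [miniBatchSize, nOut] array, and the argmax along axis 1 picks the highest-scoring output unit for each example. Keep in mind that rnnTimeStep() carries the recurrent state over between calls; rnnClearPreviousState() resets it. RowArgMax is a hypothetical plain-Java stand-in, and the probabilities in main are made up.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DL4J: plain-Java version of argMax along axis 1 of a 2D array.
final class RowArgMax {
    private RowArgMax() {}

    /** For each row of a [miniBatch][nOut] array, returns the index of the largest value (first index on ties). */
    static int[] argMaxPerRow(double[][] output) {
        int[] classes = new int[output.length];
        for (int i = 0; i < output.length; i++) {
            int best = 0;
            for (int c = 1; c < output[i].length; c++) {
                if (output[i][c] > output[i][best]) {
                    best = c;
                }
            }
            classes[i] = best;
        }
        return classes;
    }

    public static void main(String[] args) {
        // Made-up softmax outputs for 3 examples and 3 classes.
        double[][] probs = {{0.1, 0.7, 0.2}, {0.6, 0.3, 0.1}, {0.2, 0.2, 0.6}};
        System.out.println(Arrays.toString(argMaxPerRow(probs)));
    }
}
```

Running main prints [1, 0, 2], one predicted class per example, which is the kind of int[] that predict() returns for an OutputLayer network.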

The output of output() will be a rank 3 array of shape [miniBatchSize, nOut, timeSeriesLength], so you need to specify the correct axis.
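For a classification network ending in an RnnOutputLayer, output() returns a rank 3 array of shape [miniBatchSize, nOut, timeSeriesLength], with the class scores on axis 1. Calling out.argMax(1) then yields one class index per example and time step, shape [miniBatchSize, timeSeriesLength]. The plain-Java sketch below spells out that reduction; SequenceArgMax is a hypothetical name and the values in main are made up.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DL4J: plain-Java version of argMax(1) on a rank 3 RNN output.
final class SequenceArgMax {
    private SequenceArgMax() {}

    /**
     * Reduces a [miniBatch][nOut][timeSteps] array over the nOut axis (axis 1),
     * returning [miniBatch][timeSteps] class indices (first index on ties).
     * Assumes nOut >= 1 so out[i][0] exists.
     */
    static int[][] argMaxOverOutputs(double[][][] out) {
        int[][] classes = new int[out.length][];
        for (int i = 0; i < out.length; i++) {
            int nOut = out[i].length;
            int steps = out[i][0].length;
            classes[i] = new int[steps];
            for (int t = 0; t < steps; t++) {
                int best = 0;
                for (int c = 1; c < nOut; c++) {
                    if (out[i][c][t] > out[i][best][t]) {
                        best = c;
                    }
                }
                classes[i][t] = best;
            }
        }
        return classes;
    }

    public static void main(String[] args) {
        // One example, 2 classes, 3 time steps (made-up values), indexed as out[example][class][step].
        double[][][] out = {{{0.9, 0.2, 0.4}, {0.1, 0.8, 0.6}}};
        System.out.println(Arrays.deepToString(argMaxOverOutputs(out)));
    }
}
```

Running main prints [[0, 1, 1]]: class 0 wins at the first step, class 1 at the other two. Reducing over axis 2 instead would mix up time steps and classes, which is why the axis matters.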
