
dl4j lstm not successful

I'm trying to copy the exercise about halfway down the page at this link: https://d2l.ai/chapter_recurrent-neural-networks/sequence.html

The exercise uses a sine function to create 1000 data points between -1 and 1, and uses a recurrent network to approximate the function.
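For reference, the kind of series being fit can be generated in a few lines of plain Java (a sketch only; the frequency here is arbitrary, and the d2l.ai version additionally adds Gaussian noise to the sine values):

    // Generate 1000 points of a sine wave, bounded in [-1, 1].
    int T = 1000;
    double[] series = new double[T];
    for (int t = 0; t < T; t++) {
        series[t] = Math.sin(0.01 * t);
    }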

Below is the code I used. I'm going back to study why this isn't working, because it doesn't make much sense to me now, when I was easily able to use a feed-forward network to approximate this function.

    // get data
    ArrayList<DataSet> list = new ArrayList<>();

    DataSet dss = DataSetFetch.getDataSet(Constants.DataTypes.math, "sine", 20, 500, 0, 0);
    DataSet dsMain = dss.copy();

    if (!dss.isEmpty()) {
        list.add(dss);
    }

    if (list.isEmpty()) {
        return;
    }

    // format dataset for recurrent training
    list = DataSetFormatter.formatReccurnent(list, 0);

    // build network: 1 input -> 10 hidden LSTM units -> 1 output
    int history = 10;
    ArrayList<LayerDescription> ldlist = new ArrayList<>();
    LayerDescription l = new LayerDescription(1, history, Activation.RELU);
    ldlist.add(l);
    LayerDescription ll = new LayerDescription(history, 1, Activation.IDENTITY, LossFunctions.LossFunction.MSE);
    ldlist.add(ll);

    ListenerDescription ld = new ListenerDescription(20, true, false);

    MultiLayerNetwork network = Reccurent.getLstm(ldlist, 123, WeightInit.XAVIER, new RmsProp(), ld);

    // train network for 50 epochs, then reset the recurrent state
    final List<DataSet> lister = list.get(0).asList();
    DataSetIterator iter = new ListDataSetIterator<>(lister, 50);
    network.fit(iter, 50);
    network.rnnClearPreviousState();

    // test network
    ArrayList<DataSet> resList = new ArrayList<>();
    DataSet result = new DataSet();
    INDArray arr = Nd4j.zeros(lister.size() + 1); // holds one prediction per time step
    INDArray holder;

    if (list.size() > 1) {
        // test on training data
        System.err.println("oops");
    } else {
        // test on original or scaled data: feed each step through rnnTimeStep
        for (int i = 0; i < lister.size(); i++) {
            holder = network.rnnTimeStep(lister.get(i).getFeatures());
            arr.putScalar(i, holder.getFloat(0));
        }
    }

    // add original data for plotting
    resList.add(dsMain);

    // predicted series, plotted against the original features
    result.setFeatures(dsMain.getFeatures());
    result.setLabels(arr);
    resList.add(result);

    // display
    DisplayData.plot2DScatterGraph(resList);

Can you explain the code I would need for an LSTM network with 1 input, 10 hidden units, and 1 output to approximate a sine function?

I'm not using any normalization, as the function is already in -1:1, and I'm using the Y input as the feature and the following Y input as the label to train the network.
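In other words, the training pairs are one-step-ahead predictions: the feature is y(t) and the label is y(t+1). A minimal sketch of that pairing, continuing the series array from the sketch above:

    // Build one-step-ahead (feature, label) pairs:
    // the value at time t is the input, the value at t+1 is the target.
    double[] features = new double[T - 1];
    double[] labels = new double[T - 1];
    for (int t = 0; t < T - 1; t++) {
        features[t] = series[t];
        labels[t] = series[t + 1];
    }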

You'll notice I am building a class that allows for easier construction of nets. I have tried throwing many changes at the problem, but I am sick of guessing.

Here are some examples of my results. Blue is the data; red is the result.

[image: results plot 1]

[image: results plot 2]

This is one of those times where you go from wondering why this wasn't working to wondering how in the hell my original results were as good as they were.

My failing was not understanding the documentation clearly and also not understanding BPTT.

With feed-forward networks, each iteration is stored as a row and each input as a column. An example is [dataset.size, networkinputs.size].

However, with recurrent input it's reversed: each row is an input and each column is an iteration in time, which is necessary to activate the state of the LSTM chain of events. At minimum my input needed to be [1, networkinputs.size, dataset.size], but it could also be [dataset.size, networkinputs.size, statelength.size].
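To make the shape concrete, here is a sketch of packing the whole series as one sequence in DL4J's [miniBatchSize, nIn, timeSeriesLength] layout, using plain ND4J rather than my helper classes (names and sizes are illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    // One mini-batch, one input feature, T-1 time steps.
    int timeSteps = T - 1;
    INDArray f = Nd4j.zeros(1, 1, timeSteps);
    INDArray l = Nd4j.zeros(1, 1, timeSteps);
    for (int t = 0; t < timeSteps; t++) {
        f.putScalar(new int[]{0, 0, t}, series[t]);     // input y(t)
        l.putScalar(new int[]{0, 0, t}, series[t + 1]); // target y(t+1)
    }
    DataSet sequence = new DataSet(f, l);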

In my previous example I was training the network with data in the format [dataset.size, networkinputs.size, 1]. So from my low-resolution understanding, the LSTM network should never have worked at all, but somehow it produced at least something.

There may have also been some issue with converting the dataset to a list, as I also changed how I feed the network, but I think the bulk of the issue was a data-structure issue.

Below are my new results. Not perfect, but considering this was 5 training epochs, very good.

Hard to tell what is going on without seeing the full code. For a start, I don't see an RnnOutputLayer specified. You could take a look at this, which shows you how to build an RNN in DL4J. If your RNN setup is correct, this could be a tuning issue. You can find more on tuning here. Adam is probably a better choice of updater than RMSProp. And tanh is probably a good choice of activation for your output layer, since its range is (-1, 1). Other things to check/tweak: learning rate, number of epochs, and the set-up of your data (like, are you trying to predict too far out?).
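For concreteness, a minimal configuration along those lines might look like the sketch below (1 input, 10 LSTM units, 1 output through an RnnOutputLayer; the learning rate is a guess, not a tuned value, and on older DL4J versions the LSTM layer class was GravesLSTM):

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.learning.config.Adam;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(0.005)) // learning rate is a guess
            .list()
            .layer(new LSTM.Builder()
                    .nIn(1).nOut(10)
                    .activation(Activation.TANH)
                    .build())
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .nIn(10).nOut(1)
                    .activation(Activation.TANH) // output range (-1, 1) matches the sine target
                    .build())
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();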
