
Save trained neural network

I need help with saving my neural network.

Let me explain: I programmed a multi-layer network in C#. One part of the application is for training and the other part is for testing the neural network. Everything works exactly as it should. When I want to train my network, I load a set of data from a file. When the training is over, I test it on a smaller sample of data and it gives me the correct output. But now I would like to be able to train my network and save it, so that I can load it again and use it for further testing.

I will assume you have your machine learning class called NaiveBayes (or whatever). Typically you would mark this as [Serializable]:

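using System;      // SerializableAttribute lives in the System namespace
using System.IO;   // Stream, used by the Save method

[Serializable]
public class NaiveBayes
{
    ...
}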

In this class you could then have a method to do the saving:

public void Save(Stream stream)
{
    YourBinaryFormatter b = new YourBinaryFormatter();
    b.Serialize(stream, this);
}

YourBinaryFormatter here is just a placeholder for a serializer of your choice; you can use another serializer if you wish. Reading these files back is the reverse operation and is just as easy.
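As one possible choice of serializer, here is a minimal sketch using System.Text.Json (available in .NET 6 and later) instead of a binary formatter; note that BinaryFormatter itself is marked obsolete in recent .NET versions for security reasons. The class name TrainedModel and its Topology and Weights properties are assumptions made for illustration, not part of the original answer.

```csharp
using System;
using System.IO;
using System.Text.Json;

// Hypothetical model class: the name and properties are illustrative assumptions.
public class TrainedModel
{
    // Public get/set properties so System.Text.Json can read and write them.
    public int[] Topology { get; set; } = Array.Empty<int>();
    public double[][] Weights { get; set; } = Array.Empty<double[]>();

    // Write this model to any stream (file, memory, network) as JSON.
    public void Save(Stream stream)
    {
        JsonSerializer.Serialize(stream, this);
    }

    // Read a model back from a stream previously written by Save.
    public static TrainedModel Load(Stream stream)
    {
        return JsonSerializer.Deserialize<TrainedModel>(stream)
            ?? throw new InvalidDataException("Stream did not contain a model.");
    }
}
```

To use it, open a file stream with File.Create, pass it to Save, and later pass a stream from File.OpenRead to Load.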

Yes, you could use Serializable as mentioned. But in robotics, there is often a need for language-independent (i.e. easily parsable) knowledge storage, which is why I'm adding this answer.

So, how do you save the current state of any data type?
1) Write the type, its state and a possible descriptor. 2) Read it back.

For an integer int a = 3, you could write a file with the following content:

integer
a
3
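The three-line record above could be written and read back roughly as follows. This is only a sketch: the class TypedValueFile and its method names are hypothetical, and only the "integer" type is handled.

```csharp
using System;
using System.Globalization;
using System.IO;

// Hypothetical helper for the three-line "type / name / value" record.
public static class TypedValueFile
{
    // Write one record: type on line 1, descriptor (name) on line 2, state on line 3.
    public static void Write(TextWriter writer, string type, string name, string value)
    {
        writer.WriteLine(type);
        writer.WriteLine(name);
        writer.WriteLine(value);
    }

    // Read one record back; this sketch only understands the "integer" type.
    public static (string Name, int Value) ReadInteger(TextReader reader)
    {
        var type = reader.ReadLine();
        var name = reader.ReadLine();
        var value = reader.ReadLine();
        if (type != "integer" || name == null || value == null)
            throw new InvalidDataException("Expected an integer record.");
        return (name, int.Parse(value, CultureInfo.InvariantCulture));
    }
}
```

Because both methods take a TextWriter or TextReader, the same code works with a file (StreamWriter, StreamReader) or with an in-memory string.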

A neural network is an abstract data type, just like an integer. It is defined by a topology and the final weights after training. Let's say you have an MLP with in=3, hid=6, out=2; then you could write a file with the following contents:

3-6-2 // topology
test1 // name of neural network, could also be in filename (or timestamp)
weight matrix [in->hid]
weight matrix [hid->out]

Here you would of course write the actual weights instead of "weight matrix". You can obtain the topology either in the initialization part of your program or together with the weights at the end of the training stage.

If you want to reconstruct your network, simply parse the written file and use everything you've read to initialize the network as before, but now skip training. You should be able to write files on your robot. If you're unable to do so, send the information over wifi to your local computer and write it there.
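A rough sketch of writing and parsing such a file is shown below. The concrete layout is an assumption: line 1 holds the topology, line 2 the name, and each following line holds one row of a weight matrix as space-separated numbers, with the matrices in layer order. The trailing "//" comments from the example file and any bias terms are left out, and the class name MlpTextFormat is hypothetical.

```csharp
using System;
using System.Globalization;
using System.IO;
using System.Linq;

// Hypothetical reader/writer for the plain-text layout described above.
public static class MlpTextFormat
{
    // weights[k] is the matrix between layer k and layer k+1,
    // with topology[k] rows and topology[k+1] columns.
    public static void Write(TextWriter w, string name, int[] topology, double[][][] weights)
    {
        w.WriteLine(string.Join("-", topology));   // e.g. "3-6-2"
        w.WriteLine(name);
        foreach (var matrix in weights)
            foreach (var row in matrix)
                w.WriteLine(string.Join(" ",
                    row.Select(x => x.ToString("R", CultureInfo.InvariantCulture))));
    }

    public static (string Name, int[] Topology, double[][][] Weights) Read(TextReader r)
    {
        int[] topology = NextLine(r).Split('-')
            .Select(s => int.Parse(s, CultureInfo.InvariantCulture)).ToArray();
        string name = NextLine(r);

        // The topology tells us how many rows and columns each matrix has.
        var weights = new double[topology.Length - 1][][];
        for (int layer = 0; layer < weights.Length; layer++)
        {
            weights[layer] = new double[topology[layer]][];
            for (int row = 0; row < topology[layer]; row++)
            {
                double[] values = NextLine(r).Split(' ')
                    .Select(s => double.Parse(s, CultureInfo.InvariantCulture)).ToArray();
                if (values.Length != topology[layer + 1])
                    throw new InvalidDataException("Row length does not match topology.");
                weights[layer][row] = values;
            }
        }
        return (name, topology, weights);
    }

    private static string NextLine(TextReader r) =>
        r.ReadLine() ?? throw new InvalidDataException("Unexpected end of file.");
}
```

Using the invariant culture keeps the decimal separator fixed, so a file written on the robot parses the same way on another machine. The weights returned by Read can then be copied into a freshly constructed network in place of the training step.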

You can check my implementation of a multi-layer network in C# here.

You need to serialize the network structure, weights and biases. The Ann library has helper methods for that:

Step 1. Define layer configuration:

var layerConfig = new LayerConfiguration()
  .AddInputLayer(2)
  .AddHiddenLayer(5)
  .AddHiddenLayer(5)
  .AddOutputLayer(1);

Step 2. Train model:

model.TrainModel(new List<double> { 0.25, 0.50 }, new List<double> { 1 });
model.TrainModel(new List<double> { 0.75, 0.15 }, new List<double> { 0 });
model.TrainModel(new List<double> { 0.60, 0.40 }, new List<double> { 1 });
...

Step 3. Save the trained model to a JSON file:

model.SaveModelToJson("model.json");

You can then instantiate a new Network object and use the previously trained model:

var model2 = new Network("model.json");
List<double> output = model2.UseModel(new List<double> { 0.35, 0.45 });
