
How to train a model in Matlab, save it to disk, and load it in a C++ program?

I am using libsvm version 3.16. I have done some training in Matlab and created a model. Now I would like to save this model to disk and load it in my C++ program. So far I have found the following alternatives:

  1. This answer explains how to save a model from C++, based on this website. It is not exactly what I need, but it could be adapted (this requires development time).
  2. I could find the best training parameters (kernel, C) in Matlab and retrain everything in C++. (This would require retraining in C++ every time I change a parameter, so it doesn't scale.)

Thus, neither of these options is satisfactory.

Does anyone have an idea?

My solution was to retrain in C++, because I couldn't find a good way to save the model directly. Here's my code. You'll need to adapt it and clean it up a bit. The biggest change you'll have to make is to stop hard-coding the svm_parameter values like I did. You'll also have to replace FilePath with std::string. I copied, pasted and made small edits here on SO, so the formatting won't be perfect:

Used like this:

    auto targetsPath = FilePath("targets.txt");
    auto observationsPath = FilePath("observations.txt");

    auto targetsMat = MatlabMatrixFileReader::Read(targetsPath, ',');
    auto observationsMat = MatlabMatrixFileReader::Read(observationsPath, ',');
    auto v = MiscVector::ConvertVecOfVecToVec(targetsMat);
    auto model = SupportVectorRegressionModel{ observationsMat, v };

    std::vector<double> observation{ { // 32 feature observation
        0.883575729725847,0.919446119013878,0.95359403450317,
        0.968233630936732,0.91891307107125,0.887897763183844,
        0.937588566544751,0.920582702918882,0.888864454119387,
        0.890066735260163,0.87911085669864,0.903745573664995,
        0.861069296586979,0.838606194934074,0.856376230548304,
        0.863011311537075,0.807688936997926,0.740434984165146,
        0.738498042748759,0.736410940165691,0.697228384912424,
        0.608527698289016,0.632994967880269,0.66935784966765,
        0.647761430696238,0.745961037635717,0.560761134660957,
        0.545498063585615,0.590854855113663,0.486827902942118,
        0.187128866890822,-0.0746523069562551
    } };

    double prediction = model.Predict(observation);

miscvector.h

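    #pragma once

    #include <vector>
    using std::vector;
    class MiscVector
    {
    public:
        // Returns the first column of a matrix, e.g. an N x 1 targets matrix read from disk.
        static vector<double> ConvertVecOfVecToVec(const vector<vector<double>> &mat)
        {
            vector<double> targetsVec;
            targetsVec.reserve(mat.size());
            for (size_t i = 0; i < mat.size(); i++)
            {
                targetsVec.push_back(mat[i][0]);
            }
            return targetsVec;
        }
    };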

libsvmtargetobjectconvertor.h

    #pragma once

    #include <cstdlib>   // malloc
    #include <vector>
    #include "machinelearning.h"
    #include "svm.h"     // libsvm: the full definition of svm_node is needed here

    class LibSvmTargetObservationConvertor
    {
    public:
        // Converts dense observations to libsvm's sparse format.
        // The caller owns the result: free() each row, then the array.
        svm_node **ConvertObservations(const vector<MlObservation> &observations, size_t numFeatures) const
        {
            svm_node **svmObservations = (svm_node **)malloc(sizeof(svm_node *) * observations.size());
            for (size_t rowI = 0; rowI < observations.size(); rowI++)
            {
                // numFeatures + 1: one extra node for the index = -1 terminator
                svm_node *row = (svm_node *)malloc(sizeof(svm_node) * (numFeatures + 1));
                for (size_t colI = 0; colI < numFeatures; colI++)
                {
                    row[colI].index = static_cast<int>(colI) + 1; // libsvm feature indices start at 1
                    row[colI].value = observations[rowI][colI];
                }
                row[numFeatures].index = -1; // libsvm requires a terminating node with index -1
                svmObservations[rowI] = row;
            }
            return svmObservations;
        }

        // Converts a single dense observation; the caller must free() the result.
        svm_node *ConvertMatToSvmNode(const MlObservation &observation) const
        {
            size_t numFeatures = observation.size();
            svm_node *obsNode = (svm_node *)malloc(sizeof(svm_node) * (numFeatures + 1));
            for (size_t i = 0; i < numFeatures; i++)
            {
                obsNode[i].index = static_cast<int>(i) + 1; // libsvm feature indices start at 1
                obsNode[i].value = observation[i];
            }
            obsNode[numFeatures].index = -1; // terminator
            return obsNode;
        }
    };
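Why the feature indices matter: libsvm's kernels compare two sparse vectors index by index. The sketch below computes the RBF kernel selected by kernel_type = 2, K(x, y) = exp(-gamma * ||x - y||^2). It walks two index-sorted, -1-terminated node arrays, similar in spirit to libsvm's Kernel::k_function; it is an illustration, not libsvm's code. SvmNode is a hypothetical stand-in with the same fields as libsvm's svm_node, so that the sketch compiles without libsvm.

```cpp
#include <cmath>

// Hypothetical stand-in with the same fields as libsvm's struct svm_node.
struct SvmNode
{
    int index;    // feature index; -1 marks the end of the vector
    double value; // feature value
};

// Squared Euclidean distance between two sparse, index-sorted, -1-terminated
// vectors. An entry present in only one vector counts as (value - 0)^2,
// so features are matched purely by their index.
double SquaredDistance(const SvmNode *x, const SvmNode *y)
{
    double sum = 0.0;
    while (x->index != -1 && y->index != -1)
    {
        if (x->index == y->index)
        {
            double d = x->value - y->value;
            sum += d * d;
            ++x;
            ++y;
        }
        else if (x->index < y->index)
        {
            sum += x->value * x->value;
            ++x;
        }
        else
        {
            sum += y->value * y->value;
            ++y;
        }
    }
    for (; x->index != -1; ++x) sum += x->value * x->value; // leftovers of x
    for (; y->index != -1; ++y) sum += y->value * y->value; // leftovers of y
    return sum;
}

// RBF kernel: K(x, y) = exp(-gamma * ||x - y||^2)
double RbfKernel(const SvmNode *x, const SvmNode *y, double gamma)
{
    return std::exp(-gamma * SquaredDistance(x, y));
}
```

Shifting every index by one (0-based instead of 1-based) changes the distance between vectors that hold the same values. Training data, prediction inputs and any model trained elsewhere therefore have to agree on the indexing. libsvm's own tools use 1-based indices.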

machinelearning.h

    #pragma once

    #include <vector>
    using std::vector;

    using MlObservation = vector<double>; // one observation = one row of features
    using MlTarget = double;

machinelearningmodel.h

    #pragma once

    #include "machinelearning.h"

    class MachineLearningModel
    {
    public:
        virtual ~MachineLearningModel() {}
        virtual double Predict(const MlObservation &observation) const = 0;
    };

matlabmatrixfilereader.h

    #pragma once

    #include <cstdlib>   // atof
    #include <fstream>
    #include <sstream>
    #include <string>
    #include <vector>
    using std::vector;

    #include "filepath.h" // hypothetical: the author's FilePath class (not shown), which provides Path(); or use std::string

    // Matrix created with command:
    // dlmwrite('my_matrix.txt', somematrix, 'delimiter', ',', 'precision', 15);
    // In these files, each row is a matrix row. Commas separate elements on a row.
    // There is no space at the end of a row. There is a blank line at the bottom of the file.
    // File format:
    // 0.4,0.7,0.8
    // 0.9,0.3,0.5
    // etc.
    class MatlabMatrixFileReader
    {
    public:
        static vector<vector<double>> Read(const FilePath &asciiFilePath, char delimiter)
        {
            vector<vector<double>> values;
            std::ifstream fin(asciiFilePath.Path());
            std::string line;
            while (std::getline(fin, line))
            {
                if (line.empty())
                    continue; // skip blank lines, e.g. the one at the bottom of the file
                std::istringstream in(line);
                vector<double> valueline;
                std::string item;
                while (std::getline(in, item, delimiter))
                {
                    valueline.push_back(std::atof(item.c_str()));
                }
                values.push_back(valueline);
            }
            return values; // fin is closed automatically when it goes out of scope
        }
    };
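The answer suggests replacing FilePath with std::string. Here is a self-contained variant of the reader along those lines, offered as a sketch. It parses from any std::istream, which also makes it easy to test, and skips blank lines such as the trailing one described in the comments. ReadMatrix and ReadMatrixFile are hypothetical names, not part of the original code.

```cpp
#include <cstdlib>
#include <fstream>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Parses delimiter-separated rows (e.g. written by Matlab's dlmwrite) from a stream.
std::vector<std::vector<double>> ReadMatrix(std::istream &in, char delimiter)
{
    std::vector<std::vector<double>> values;
    std::string line;
    while (std::getline(in, line))
    {
        if (!line.empty() && line.back() == '\r')
            line.pop_back(); // tolerate Windows line endings
        if (line.empty())
            continue;        // skip blank lines
        std::istringstream row(line);
        std::vector<double> rowValues;
        std::string item;
        while (std::getline(row, item, delimiter))
            rowValues.push_back(std::strtod(item.c_str(), nullptr));
        values.push_back(rowValues);
    }
    return values;
}

// File overload taking a std::string path instead of FilePath.
// A file that cannot be opened yields an empty matrix.
std::vector<std::vector<double>> ReadMatrixFile(const std::string &path, char delimiter)
{
    std::ifstream fin(path);
    return ReadMatrix(fin, delimiter);
}
```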

supportvectorregressionmodel.h

    #pragma once

    #include <cstdlib>   // free
    #include <vector>
    using std::vector;

    #include "svm.h"     // libsvm
    #include "machinelearningmodel.h"
    #include "libsvmtargetobjectconvertor.h"
    #include "filepath.h" // hypothetical: the author's FilePath class (not shown); or use std::string

    class SupportVectorRegressionModel : public MachineLearningModel
    {
    public:
        // Trains a nu-SVR model in C++ from observations and targets.
        SupportVectorRegressionModel(const vector<MlObservation> &observations, const vector<MlTarget> &targets)
            : targets_(targets)
        {
            // assumes observations is non-empty and all observations have the same number of features
            size_t numFeatures = observations[0].size();

            // setup observations; the trained model keeps pointers into these nodes,
            // so they are kept alive until the destructor
            LibSvmTargetObservationConvertor conv;
            observations_ = conv.ConvertObservations(observations, numFeatures);
            numObservations_ = observations.size();

            // setup problem; svm_problem::y is a non-const double *, so point it at our own copy
            svm_problem problem;
            problem.l = static_cast<int>(targets_.size());
            problem.y = targets_.data();
            problem.x = observations_;

            // specific to our training sets
            // TODO: these values are hard coded; pass them into the constructor instead
            param_.svm_type = NU_SVR;       // 4
            param_.kernel_type = RBF;       // 2, radial basis function
            param_.C = 0.4;                 // cost
            param_.nu = 0.6;                // nu of nu-SVR
            // These values are the defaults used in the Matlab version
            // as found in svm_model_matlab.c
            param_.gamma = 1.0 / (double)numFeatures;
            param_.coef0 = 0;
            param_.cache_size = 100;        // in MB
            param_.degree = 3;
            param_.eps = 1e-3;
            param_.p = 0.1;
            param_.shrinking = 1;
            param_.probability = 0;
            param_.nr_weight = 0;
            param_.weight_label = NULL;
            param_.weight = NULL;

            // suppress command line output
            svm_set_print_string_function([](const char *) {});

            model_ = svm_train(&problem, &param_);
        }

        // Loads a model saved in libsvm's text format; model_ is NULL if loading fails.
        explicit SupportVectorRegressionModel(const FilePath &modelFile)
        {
            model_ = svm_load_model(modelFile.Path().c_str());
        }

        // Owns raw libsvm memory, so copying would lead to a double free.
        SupportVectorRegressionModel(const SupportVectorRegressionModel &) = delete;
        SupportVectorRegressionModel &operator=(const SupportVectorRegressionModel &) = delete;

        ~SupportVectorRegressionModel()
        {
            // svm_free_and_destroy_model also frees the model content,
            // so svm_free_model_content must not be called as well
            svm_free_and_destroy_model(&model_);
            svm_destroy_param(&param_);
            // a trained model points into these nodes, so free them after the model
            for (size_t i = 0; i < numObservations_; i++)
                free(observations_[i]);
            free(observations_);
        }

        double Predict(const MlObservation &observation) const override
        {
            LibSvmTargetObservationConvertor conv;
            svm_node *obsNode = conv.ConvertMatToSvmNode(observation);
            double prediction = svm_predict(model_, obsNode);
            free(obsNode); // release the temporary node array on every call
            return prediction;
        }

    private:
        svm_model *model_ = nullptr;
        svm_parameter param_ = {};          // zeroed, so svm_destroy_param is safe for a loaded model
        vector<double> targets_;            // copy of the training targets
        svm_node **observations_ = nullptr; // training nodes (NULL for a loaded model)
        size_t numObservations_ = 0;
    };
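The TODO about hard-coded values, and option 2 in the question, can be handled by keeping the tuned values outside the code. This is a minimal sketch that assumes the values found in Matlab are written to a small text file of "name value" lines (for example with fopen/fprintf). SvrSettings, ReadSvrSettings and EffectiveGamma are hypothetical names, and only C, nu and gamma are covered.

```cpp
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>

// Hypothetical holder for the values the constructor hard-codes.
// Defaults mirror the hard-coded C and nu; gamma <= 0 means "use 1 / numFeatures".
struct SvrSettings
{
    double C = 0.4;
    double nu = 0.6;
    double gamma = 0.0;
};

// Reads "name value" lines such as "C 0.4"; unknown names and malformed lines are ignored.
SvrSettings ReadSvrSettings(std::istream &in)
{
    SvrSettings s;
    std::string line;
    while (std::getline(in, line))
    {
        std::istringstream fields(line);
        std::string name;
        double value;
        if (!(fields >> name >> value))
            continue;
        if (name == "C") s.C = value;
        else if (name == "nu") s.nu = value;
        else if (name == "gamma") s.gamma = value;
    }
    return s;
}

// Applies the same gamma default as the code above: 1 / number of features.
double EffectiveGamma(const SvrSettings &s, std::size_t numFeatures)
{
    return s.gamma > 0.0 ? s.gamma : 1.0 / static_cast<double>(numFeatures);
}
```

With settings passed in as an extra constructor argument, the hard-coded assignments would become param_.C = settings.C; param_.nu = settings.nu; param_.gamma = EffectiveGamma(settings, numFeatures);. Changing a parameter then no longer requires recompiling, although the model is still retrained in C++.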

Option 1 is actually pretty reasonable. If you save the model in libsvm's C format through Matlab, it is straightforward to work with the model in C/C++ using the functions libsvm provides. Trying to work with Matlab-formatted data in C++ will probably be much more difficult.

The main function in "svm-predict.c" (located in the root directory of the libsvm package) probably has most of what you need:

    if ((model = svm_load_model(argv[i+1])) == 0)
    {
        fprintf(stderr, "can't open model file %s\n", argv[i+1]);
        exit(1);
    }
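If svm_load_model returns 0 for a file exported from Matlab, it can help to look at the file itself. A libsvm model file is plain text. It starts with a header of keyword/value lines (svm_type, kernel_type, gamma, rho, ...), and a line containing only SV marks the start of the support vectors. Below is a minimal sketch that reads just the header, for example to confirm svm_type and kernel_type before loading. ReadModelHeader is a hypothetical helper, not a libsvm function.

```cpp
#include <istream>
#include <map>
#include <sstream>
#include <string>

// Reads the header of a libsvm model file: keyword -> rest of the line,
// stopping at the "SV" line that starts the support vectors.
std::map<std::string, std::string> ReadModelHeader(std::istream &in)
{
    std::map<std::string, std::string> header;
    std::string line;
    while (std::getline(in, line))
    {
        if (!line.empty() && line.back() == '\r')
            line.pop_back(); // tolerate Windows line endings
        if (line == "SV")
            break;           // header complete; support vectors follow
        std::istringstream fields(line);
        std::string keyword;
        if (!(fields >> keyword))
            continue;        // skip empty lines
        std::string rest;
        std::getline(fields >> std::ws, rest);
        header[keyword] = rest;
    }
    return header;
}
```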

To predict the label of an example x using the model, you can run:

    double predict_label = svm_predict(model, x); // svm_predict returns a double

The trickiest part will be converting your data into the libsvm format (unless your data is already in the libsvm text file format, in which case you can just use the predict function in "svm-predict.c").
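If your data is still dense (like the comma-separated files exported from Matlab), one way to get it into the libsvm text format is to write each example as a line of the form "label index:value index:value ...". The indices are 1-based and ascending, and zero entries can be left out. A minimal sketch follows; ToLibsvmLine is a hypothetical helper, not part of libsvm.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Formats one example in libsvm's text data format:
//   <label> <index1>:<value1> <index2>:<value2> ...
// Indices are 1-based and ascending; zero entries are omitted (sparse).
std::string ToLibsvmLine(double label, const std::vector<double> &features)
{
    std::ostringstream out;
    out.precision(17); // enough digits to round-trip a double
    out << label;
    for (std::size_t i = 0; i < features.size(); ++i)
    {
        if (features[i] != 0.0)
            out << ' ' << (i + 1) << ':' << features[i];
    }
    return out.str();
}
```

You can write one such line per example to a file and pass it to the command-line tool as svm-predict test_file model_file output_file.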

A libsvm vector x is an array of struct svm_node that represents a sparse array of data. Each svm_node has an index and a value, and the vector must be terminated by a node whose index is set to -1. For instance, to encode the vector [0,1,0,5], you could do the following:

    struct svm_node *x = (struct svm_node *) malloc(3 * sizeof(struct svm_node));
    x[0].index = 2; // NOTE: libsvm indices start at 1
    x[0].value = 1.0;
    x[1].index = 4;
    x[1].value = 5.0;
    x[2].index = -1; // terminator
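The same encoding can be done for any dense vector without manual malloc/free. This is a sketch in which a std::vector owns the nodes. SvmNode is a hypothetical stand-in with the same fields as libsvm's svm_node, so the example compiles without libsvm; with libsvm, use the real type and pass nodes.data() wherever a const svm_node * is expected (e.g. svm_predict).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in with the same fields as libsvm's struct svm_node.
struct SvmNode
{
    int index;
    double value;
};

// Encodes a dense vector as a libsvm-style sparse vector: only non-zero
// entries are stored, indices start at 1, and a node with index -1 ends it.
std::vector<SvmNode> ToSparse(const std::vector<double> &dense)
{
    std::vector<SvmNode> nodes;
    for (std::size_t i = 0; i < dense.size(); ++i)
    {
        if (dense[i] != 0.0)
            nodes.push_back({static_cast<int>(i) + 1, dense[i]});
    }
    nodes.push_back({-1, 0.0}); // terminator required by libsvm
    return nodes;
}
```

For [0,1,0,5] this produces the same three nodes as the hand-written example: (2, 1.0), (4, 5.0), (-1).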

For SVM types other than the classifier (C_SVC), look at the predict function in "svm-predict.c".

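Statement: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.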

 
© 2020-2024 STACKOOM.COM