簡體   English   中英

Libsvm Java培訓測試示例(也是實時的)

[英]Libsvm java training testing example(also in real time)

任何人都可以通過提供libsvm java示例進行培訓和測試來幫助我。 我是機器學習的新手,因此需要幫助。 @machine學習者提供的較早示例出現錯誤,僅給出了一個班級結果。 我不想使用weka作為先前文章中的建議。

或者您可以糾正此代碼中的錯誤,它總是預測結果中的一類。(我想執行多分類)。

此示例由“機器學習者”給出

import java.io.*;
import java.util.*;
import libsvm.*;

public class Test{
    public static void main(String[] args) throws Exception{

        // Preparing the SVM param
        svm_parameter param=new svm_parameter();
        param.svm_type=svm_parameter.C_SVC;
        param.kernel_type=svm_parameter.RBF;
        param.gamma=0.5;
        param.nu=0.5;
        param.cache_size=20000;
        param.C=1;
        param.eps=0.001;
        param.p=0.1;

        HashMap<Integer, HashMap<Integer, Double>> featuresTraining=new HashMap<Integer, HashMap<Integer, Double>>();
        HashMap<Integer, Integer> labelTraining=new HashMap<Integer, Integer>();
        HashMap<Integer, HashMap<Integer, Double>> featuresTesting=new HashMap<Integer, HashMap<Integer, Double>>();

        HashSet<Integer> features=new HashSet<Integer>();

        //Read in training data
        BufferedReader reader=null;
        try{
            reader=new BufferedReader(new FileReader("a1a.train"));
            String line=null;
            int lineNum=0;
            while((line=reader.readLine())!=null){
                featuresTraining.put(lineNum, new HashMap<Integer,Double>());
                String[] tokens=line.split("\\s+");
                int label=Integer.parseInt(tokens[0]);
                labelTraining.put(lineNum, label);
                for(int i=1;i<tokens.length;i++){
                    String[] fields=tokens[i].split(":");
                    int featureId=Integer.parseInt(fields[0]);
                    double featureValue=Double.parseDouble(fields[1]);
                    features.add(featureId);
                    featuresTraining.get(lineNum).put(featureId, featureValue);
                }
            lineNum++;
            }

            reader.close();
        }catch (Exception e){

        }

        //Read in test data
        try{
            reader=new BufferedReader(new FileReader("a1a.t"));
            String line=null;
            int lineNum=0;
            while((line=reader.readLine())!=null){

                featuresTesting.put(lineNum, new HashMap<Integer,Double>());
                String[] tokens=line.split("\\s+");
                for(int i=1; i<tokens.length;i++){
                    String[] fields=tokens[i].split(":");
                    int featureId=Integer.parseInt(fields[0]);
                    double featureValue=Double.parseDouble(fields[1]);
                    featuresTesting.get(lineNum).put(featureId, featureValue);
                }
            lineNum++;
            }
            reader.close();
        }catch (Exception e){

        }

        //Train the SVM model
        svm_problem prob=new svm_problem();
        int numTrainingInstances=featuresTraining.keySet().size();
        prob.l=numTrainingInstances;
        prob.y=new double[prob.l];
        prob.x=new svm_node[prob.l][];

        for(int i=0;i<numTrainingInstances;i++){
            HashMap<Integer,Double> tmp=featuresTraining.get(i);
            prob.x[i]=new svm_node[tmp.keySet().size()];
            int indx=0;
            for(Integer id:tmp.keySet()){
                svm_node node=new svm_node();
                node.index=id;
                node.value=tmp.get(id);
                prob.x[i][indx]=node;
                indx++;
            }

            prob.y[i]=labelTraining.get(i);
        }

        svm_model model=svm.svm_train(prob,param);

        for(Integer testInstance:featuresTesting.keySet()){
            HashMap<Integer, Double> tmp=new HashMap<Integer, Double>();
            int numFeatures=tmp.keySet().size();
            svm_node[] x=new svm_node[numFeatures];
            int featureIndx=0;
            for(Integer feature:tmp.keySet()){
                x[featureIndx]=new svm_node();
                x[featureIndx].index=feature;
                x[featureIndx].value=tmp.get(feature);
                featureIndx++;
            }

            double d=svm.svm_predict(model, x);

            System.out.println(testInstance+"\t"+d);
        }

    }
}

這是因為從未使用過featuresTesting, HashMap<Integer, Double> tmp=new HashMap<Integer, Double>(); 應該是HashMap<Integer, Double> tmp=featuresTesting.get(testInstance);

您可以使用javaML庫對數據進行分類

它是javaML的示例代碼:

   Classifier clas = new LibSVM();
        clas.buildClassifier(data);
        Dataset dataForClassification= FileHandler.loadDataset(new File(.),            0, ",");
        /* Counters for correct and wrong predictions. */
        int correct = 0, wrong = 0;
        /* Classify all instances and check with the correct class values */
        for (Instance inst : dataForClassification) {
            Object predictedClassValue = clas.classify(inst);
            Map<Object,Double> map = clas.classDistribution(inst);
            Object realClassValue = inst.classValue();
            if (predictedClassValue.equals(realClassValue))
                correct++;
            else
                wrong++;
        }

似乎您在理解自己的工作時遇到了麻煩,只是從這里到那里復制代碼。 它可以幫助您了解基本的機器學習。 例如,您可能應該閱讀LIBSVM(您使用的庫)作者的SVM分類實用指南 您得到的建議可能是您應該在線上入門機器學習課程,甚至可能更好。

讓我也給您兩個大提示,如果您獲得同一個班級的所有結果,則可以節省您的時間:

  1. 您是否要對數據進行歸一化,使所有值都線性或使用均值和標准差位於0到1(或-1到1)之間? 從您的代碼看來,您似乎並非如此。
  2. 您是否在參數中搜索良好的C值(如果是RBF內核,則為C和gamma)? 進行交叉驗證還是套用? 看起來好像不是您的代碼。

A)沒有人知道您正在引用。 如果您不希望別人理解您所指的內容,請提供鏈接。

B)您需要參加機器學習課程。 Coursera有一個免費的。 模型的輸出取決於數據本身,並且受模型參數的影響很大。 模型參數受縮放影響,通常需要搜索它們。 您的代碼沒有包含任何內容-您已經清楚地表明您是機器學習的新手。 通過獲取必要的背景知識,您將在幾分鍾之內完成數小時,數天甚至數周的工作。

C)LIBSVM for Java有很多版本,您沒有提供使用哪個版本的指示。 每個人的工作方式略有不同。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM