libsvm java实现

我正在尝试使用libsvm的java绑定:

http://www.csie.ntu.edu.tw/~cjlin/libsvm/ 

我已经实现了一个“平凡”的例子,它可以在y中轻松地线性分离。 数据定义为:

 double[][] train = new double[1000][]; double[][] test = new double[10][]; for (int i = 0; i  (train.length/2)){ // 50% positive double[] vals = {1,0,i+i}; train[i] = vals; } else { double[] vals = {0,0,iii-2}; // 50% negative train[i] = vals; } } 

第一个“特征”是类,并且训练集类似地定义。

训练模型:

 private svm_model svmTrain() { svm_problem prob = new svm_problem(); int dataCount = train.length; prob.y = new double[dataCount]; prob.l = dataCount; prob.x = new svm_node[dataCount][]; for (int i = 0; i < dataCount; i++){ double[] features = train[i]; prob.x[i] = new svm_node[features.length-1]; for (int j = 1; j < features.length; j++){ svm_node node = new svm_node(); node.index = j; node.value = features[j]; prob.x[i][j-1] = node; } prob.y[i] = features[0]; } svm_parameter param = new svm_parameter(); param.probability = 1; param.gamma = 0.5; param.nu = 0.5; param.C = 1; param.svm_type = svm_parameter.C_SVC; param.kernel_type = svm_parameter.LINEAR; param.cache_size = 20000; param.eps = 0.001; svm_model model = svm.svm_train(prob, param); return model; } 

然后评估我使用的模型:

 public int evaluate(double[] features) { svm_node node = new svm_node(); for (int i = 1; i < features.length; i++){ node.index = i; node.value = features[i]; } svm_node[] nodes = new svm_node[1]; nodes[0] = node; int totalClasses = 2; int[] labels = new int[totalClasses]; svm.svm_get_labels(_model,labels); double[] prob_estimates = new double[totalClasses]; double v = svm.svm_predict_probability(_model, nodes, prob_estimates); for (int i = 0; i < totalClasses; i++){ System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); } System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")"); return (int)v; } 

传递的数组是测试集中的一个点。

结果总是返回0级。确切的结果是:

 (0:0.9882998314585194)(1:0.011700168541480586)(Actual:0.0 Prediction:0.0) (0:0.9883952943701599)(1:0.011604705629839989)(Actual:0.0 Prediction:0.0) (0:0.9884899803606306)(1:0.011510019639369528)(Actual:0.0 Prediction:0.0) (0:0.9885838957058696)(1:0.011416104294130458)(Actual:0.0 Prediction:0.0) (0:0.9886770466322342)(1:0.011322953367765776)(Actual:0.0 Prediction:0.0) (0:0.9870913229268679)(1:0.012908677073132284)(Actual:1.0 Prediction:0.0) (0:0.9868781382588805)(1:0.013121861741119505)(Actual:1.0 Prediction:0.0) (0:0.986661444476744)(1:0.013338555523255982)(Actual:1.0 Prediction:0.0) (0:0.9864411843906802)(1:0.013558815609319848)(Actual:1.0 Prediction:0.0) (0:0.9862172999068877)(1:0.013782700093112332)(Actual:1.0 Prediction:0.0) 

有人可以解释为什么这个分类器不工作? 我有一个搞砸的步骤,还是我失踪的一步?

谢谢

在我看来,你的评估方法是错误的。 应该是这样的:

 public double evaluate(double[] features, svm_model model) { svm_node[] nodes = new svm_node[features.length-1]; for (int i = 1; i < features.length; i++) { svm_node node = new svm_node(); node.index = i; node.value = features[i]; nodes[i-1] = node; } int totalClasses = 2; int[] labels = new int[totalClasses]; svm.svm_get_labels(model,labels); double[] prob_estimates = new double[totalClasses]; double v = svm.svm_predict_probability(model, nodes, prob_estimates); for (int i = 0; i < totalClasses; i++){ System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")"); } System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")"); return v; } 

以下是我使用以下R代码中的数据进行测试的上述示例的返工: http : //cbio.ensmp.fr/~jvert/svn/tutorials/practical/svmbasic/svmbasic_notes.pdf

 import libsvm.*; public class libsvmTest { public static void main(String [] args) { double[][] xtrain = ... double[][] xtest = ... double[][] ytrain = ... double[][] ytest = ... svm_model m = svmTrain(xtrain,ytrain); double[] ypred = svmPredict(xtest, m); for (int i = 0; i < xtest.length; i++){ System.out.println("(Actual:" + ytest[i][0] + " Prediction:" + ypred[i] + ")"); } } static svm_model svmTrain(double[][] xtrain, double[][] ytrain) { svm_problem prob = new svm_problem(); int recordCount = xtrain.length; int featureCount = xtrain[0].length; prob.y = new double[recordCount]; prob.l = recordCount; prob.x = new svm_node[recordCount][featureCount]; for (int i = 0; i < recordCount; i++){ double[] features = xtrain[i]; prob.x[i] = new svm_node[features.length]; for (int j = 0; j < features.length; j++){ svm_node node = new svm_node(); node.index = j; node.value = features[j]; prob.x[i][j] = node; } prob.y[i] = ytrain[i][0]; } svm_parameter param = new svm_parameter(); param.probability = 1; param.gamma = 0.5; param.nu = 0.5; param.C = 100; param.svm_type = svm_parameter.C_SVC; param.kernel_type = svm_parameter.LINEAR; param.cache_size = 20000; param.eps = 0.001; svm_model model = svm.svm_train(prob, param); return model; } static double[] svmPredict(double[][] xtest, svm_model model) { double[] yPred = new double[xtest.length]; for(int k = 0; k < xtest.length; k++){ double[] fVector = xtest[k]; svm_node[] nodes = new svm_node[fVector.length]; for (int i = 0; i < fVector.length; i++) { svm_node node = new svm_node(); node.index = i; node.value = fVector[i]; nodes[i] = node; } int totalClasses = 2; int[] labels = new int[totalClasses]; svm.svm_get_labels(model,labels); double[] prob_estimates = new double[totalClasses]; yPred[k] = svm.svm_predict_probability(model, nodes, prob_estimates); } return yPred; } } 

这是输出:

 (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:1.0 Prediction:1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) (Actual:-1.0 Prediction:-1.0) 

我做了一个稍微重构的LibSVM java实现版本,您可能会发现它更容易使用: https : //github.com/syeedibnfaiz/libsvm-java-kernel 。 看看Demo.java类,看看如何使用它。