LibSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
#include "classifier/svm/LibSVM.h"

#include <vector>

#include "lib/io.h"

using namespace shogun;
00015 
// Only when shogun is built with Boost.Serialization support: register
// CLibSVM's export key so Boost can (de)serialize it through a pointer to
// one of its base classes.
#ifdef HAVE_BOOST_SERIALIZATION
#include <boost/serialization/export.hpp>
BOOST_CLASS_EXPORT(CLibSVM);
#endif //HAVE_BOOST_SERIALIZATION
00020 
00021 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st)
00022 : CSVM(), model(NULL), solver_type(st)
00023 {
00024 }
00025 
00026 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab)
00027 : CSVM(C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00028 {
00029     problem = svm_problem();
00030 }
00031 
00032 CLibSVM::~CLibSVM()
00033 {
00034 }
00035 
00036 
00037 bool CLibSVM::train(CFeatures* data)
00038 {
00039     struct svm_node* x_space;
00040 
00041     ASSERT(labels && labels->get_num_labels());
00042     ASSERT(labels->is_two_class_labeling());
00043 
00044     if (data)
00045     {
00046         if (labels->get_num_labels() != data->get_num_vectors())
00047             SG_ERROR("Number of training vectors does not match number of labels\n");
00048         kernel->init(data, data);
00049     }
00050 
00051     problem.l=labels->get_num_labels();
00052     SG_INFO( "%d trainlabels\n", problem.l);
00053 
00054 
00055     // check length of linear term
00056     if (!linear_term.empty() && labels->get_num_labels() != (int32_t)linear_term.size())
00057         SG_ERROR("Number of training vectors does not match length of linear term\n");
00058 
00059     // set linear term
00060     if (!linear_term.empty()) {
00061 
00062         // set with linear term from base class
00063         problem.pv = get_linear_term_array();
00064 
00065     } else {
00066 
00067         // fill with minus ones
00068         problem.pv = new float64_t[problem.l];
00069 
00070         for (int i=0; i!=problem.l; i++) {
00071             problem.pv[i] = -1.0;
00072         }
00073     }
00074 
00075 
00076 
00077     problem.y=new float64_t[problem.l];
00078     problem.x=new struct svm_node*[problem.l];
00079     problem.C=new float64_t[problem.l];
00080 
00081 
00082     x_space=new struct svm_node[2*problem.l];
00083 
00084     for (int32_t i=0; i<problem.l; i++)
00085     {
00086         problem.y[i]=labels->get_label(i);
00087         problem.x[i]=&x_space[2*i];
00088         x_space[2*i].index=i;
00089         x_space[2*i+1].index=-1;
00090     }
00091 
00092     int32_t weights_label[2]={-1,+1};
00093     float64_t weights[2]={1.0,get_C2()/get_C1()};
00094 
00095     ASSERT(kernel && kernel->has_features());
00096     ASSERT(kernel->get_num_vec_lhs()==problem.l);
00097 
00098     param.svm_type=solver_type; // C SVM or NU_SVM
00099     param.kernel_type = LINEAR;
00100     param.degree = 3;
00101     param.gamma = 0;    // 1/k
00102     param.coef0 = 0;
00103     param.nu = get_nu();
00104     param.kernel=kernel;
00105     param.cache_size = kernel->get_cache_size();
00106     param.C = get_C1();
00107     param.eps = epsilon;
00108     param.p = 0.1;
00109     param.shrinking = 1;
00110     param.nr_weight = 2;
00111     param.weight_label = weights_label;
00112     param.weight = weights;
00113     param.use_bias = get_bias_enabled();
00114 
00115     const char* error_msg = svm_check_parameter(&problem, &param);
00116 
00117     if(error_msg)
00118         SG_ERROR("Error: %s\n",error_msg);
00119 
00120     model = svm_train(&problem, &param);
00121 
00122     if (model)
00123     {
00124         ASSERT(model->nr_class==2);
00125         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00126 
00127         int32_t num_sv=model->l;
00128 
00129         create_new_model(num_sv);
00130         CSVM::set_objective(model->objective);
00131 
00132         float64_t sgn=model->label[0];
00133 
00134         set_bias(-sgn*model->rho[0]);
00135 
00136         for (int32_t i=0; i<num_sv; i++)
00137         {
00138             set_support_vector(i, (model->SV[i])->index);
00139             set_alpha(i, sgn*model->sv_coef[0][i]);
00140         }
00141 
00142         delete[] problem.x;
00143         delete[] problem.y;
00144         delete[] problem.pv;
00145         delete[] problem.C;
00146 
00147 
00148         delete[] x_space;
00149 
00150         svm_destroy_model(model);
00151         model=NULL;
00152         return true;
00153     }
00154     else
00155         return false;
00156 }

SHOGUN Machine Learning Toolbox - Documentation