SVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "base/Parallel.h"
00014 
00015 #include "classifier/svm/SVM.h"
00016 #include "classifier/mkl/MKL.h"
00017 
00018 #include <string.h>
00019 
00020 #ifndef WIN32
00021 #include <pthread.h>
00022 #endif
00023 
00024 #ifdef HAVE_BOOST_SERIALIZATION
00025 #include <boost/serialization/export.hpp>
00026 //BOOST_SERIALIZATION_ASSUME_ABSTRACT(CSVM);
00027 #endif //HAVE_BOOST_SERIALIZATION
00028 
00029 using namespace shogun;
00030 
00031 CSVM::CSVM(int32_t num_sv)
00032 : CKernelMachine()
00033 {
00034     set_defaults(num_sv);
00035 }
00036 
00037 CSVM::CSVM(float64_t C, CKernel* k, CLabels* lab)
00038 : CKernelMachine()
00039 {
00040     set_defaults();
00041     set_C(C,C);
00042     set_labels(lab);
00043     set_kernel(k);
00044 }
00045 
00046 CSVM::~CSVM()
00047 {
00048     SG_UNREF(mkl);
00049 }
00050 
00051 void CSVM::set_defaults(int32_t num_sv)
00052 {
00053     callback=NULL;
00054     mkl=NULL;
00055 
00056     svm_loaded=false;
00057 
00058     epsilon=1e-5;
00059     tube_epsilon=1e-2;
00060 
00061     nu=0.5;
00062     C1=1;
00063     C2=1;
00064 
00065     objective=0;
00066 
00067     qpsize=41;
00068     use_bias=true;
00069     use_shrinking=true;
00070     use_batch_computation=true;
00071     use_linadd=true;
00072 
00073     if (num_sv>0)
00074         create_new_model(num_sv);
00075 }
00076 
00077 bool CSVM::load(FILE* modelfl)
00078 {
00079     bool result=true;
00080     char char_buffer[1024];
00081     int32_t int_buffer;
00082     float64_t double_buffer;
00083     int32_t line_number=1;
00084 
00085     if (fscanf(modelfl,"%4s\n", char_buffer)==EOF)
00086     {
00087         result=false;
00088         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00089     }
00090     else
00091     {
00092         char_buffer[4]='\0';
00093         if (strcmp("%SVM", char_buffer)!=0)
00094         {
00095             result=false;
00096             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00097         }
00098         line_number++;
00099     }
00100 
00101     int_buffer=0;
00102     if (fscanf(modelfl," numsv=%d; \n", &int_buffer) != 1)
00103     {
00104         result=false;
00105         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00106     }
00107 
00108     if (!feof(modelfl))
00109         line_number++;
00110 
00111     SG_INFO( "loading %ld support vectors\n",int_buffer);
00112     create_new_model(int_buffer);
00113 
00114     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00115     {
00116         result=false;
00117         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00118     }
00119 
00120     if (!feof(modelfl))
00121         line_number++;
00122 
00123     double_buffer=0;
00124 
00125     if (fscanf(modelfl," b=%lf; \n", &double_buffer) != 1)
00126     {
00127         result=false;
00128         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00129     }
00130 
00131     if (!feof(modelfl))
00132         line_number++;
00133 
00134     set_bias(double_buffer);
00135 
00136     if (fscanf(modelfl,"%8s\n", char_buffer) == EOF)
00137     {
00138         result=false;
00139         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00140     }
00141     else
00142     {
00143         char_buffer[9]='\0';
00144         if (strcmp("alphas=[", char_buffer)!=0)
00145         {
00146             result=false;
00147             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00148         }
00149         line_number++;
00150     }
00151 
00152     for (int32_t i=0; i<get_num_support_vectors(); i++)
00153     {
00154         double_buffer=0;
00155         int_buffer=0;
00156 
00157         if (fscanf(modelfl," \[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00158         {
00159             result=false;
00160             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00161         }
00162 
00163         if (!feof(modelfl))
00164             line_number++;
00165 
00166         set_support_vector(i, int_buffer);
00167         set_alpha(i, double_buffer);
00168     }
00169 
00170     if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00171     {
00172         result=false;
00173         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00174     }
00175     else
00176     {
00177         char_buffer[3]='\0';
00178         if (strcmp("];", char_buffer)!=0)
00179         {
00180             result=false;
00181             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00182         }
00183         line_number++;
00184     }
00185 
00186     svm_loaded=result;
00187     return result;
00188 }
00189 
00190 bool CSVM::save(FILE* modelfl)
00191 {
00192     if (!kernel)
00193         SG_ERROR("Kernel not defined!\n");
00194 
00195     SG_INFO( "Writing model file...");
00196     fprintf(modelfl,"%%SVM\n");
00197     fprintf(modelfl,"numsv=%d;\n", get_num_support_vectors());
00198     fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00199     fprintf(modelfl,"b=%+10.16e;\n",get_bias());
00200 
00201     fprintf(modelfl, "alphas=\[\n");
00202 
00203     for(int32_t i=0; i<get_num_support_vectors(); i++)
00204         fprintf(modelfl,"\t[%+10.16e,%d];\n",
00205                 CSVM::get_alpha(i), get_support_vector(i));
00206 
00207     fprintf(modelfl, "];\n");
00208 
00209     SG_DONE();
00210     return true ;
00211 }
00212 
00213 void CSVM::set_callback_function(CMKL* m, bool (*cb)
00214         (CMKL* mkl, const float64_t* sumw, const float64_t suma))
00215 {
00216     SG_UNREF(mkl);
00217     mkl=m;
00218     SG_REF(mkl);
00219 
00220     callback=cb;
00221 }
00222 
00223 float64_t CSVM::compute_svm_dual_objective()
00224 {
00225     int32_t n=get_num_support_vectors();
00226 
00227     if (labels && kernel)
00228     {
00229         objective=0;
00230         for (int32_t i=0; i<n; i++)
00231         {
00232             int32_t ii=get_support_vector(i);
00233             objective-=get_alpha(i)*labels->get_label(ii);
00234 
00235             for (int32_t j=0; j<n; j++)
00236             {
00237                 int32_t jj=get_support_vector(j);
00238                 objective+=0.5*get_alpha(i)*get_alpha(j)*kernel->kernel(ii,jj);
00239             }
00240         }
00241     }
00242     else
00243         SG_ERROR( "cannot compute objective, labels or kernel not set\n");
00244 
00245     return objective;
00246 }
00247 
00248 float64_t CSVM::compute_svm_primal_objective()
00249 {
00250     int32_t n=get_num_support_vectors();
00251     float64_t regularizer=0;
00252     float64_t loss=0;
00253 
00254     if (labels && kernel)
00255     {
00256         for (int32_t i=0; i<n; i++)
00257         {
00258             int32_t ii=get_support_vector(i);
00259             for (int32_t j=0; j<n; j++)
00260             {
00261                 int32_t jj=get_support_vector(j);
00262                 regularizer-=0.5*get_alpha(i)*get_alpha(j)*kernel->kernel(ii,jj);
00263             }
00264 
00265             loss-=C1*CMath::max(0.0, 1.0-get_label(ii)*classify_example(ii));
00266         }
00267     }
00268     else
00269         SG_ERROR( "cannot compute objective, labels or kernel not set\n");
00270 
00271     return regularizer+loss;
00272 }
00273 
00274 
00275 float64_t* CSVM::get_linear_term_array() {
00276 
00277     float64_t* a = new float64_t[linear_term.size()];
00278     std::copy( linear_term.begin(), linear_term.end(), a);
00279 
00280     return a;
00281 
00282 }
00283 
00284 
00285 
00286 void CSVM::set_linear_term(std::vector<float64_t> lin)
00287 {
00288 
00289     if (!labels)
00290         SG_ERROR("Please assign labels first!\n");
00291 
00292     int32_t num_labels=labels->get_num_labels();
00293 
00294     if (num_labels!=(int32_t) lin.size())
00295     {
00296         SG_ERROR("Number of labels (%d) does not match number"
00297                 "of entries (%d) in linear term \n", num_labels, lin.size());
00298     }
00299 
00300     linear_term = lin;
00301 
00302 }
00303 
00304 
00305 std::vector<float64_t> CSVM::get_linear_term() {
00306 
00307     return linear_term;
00308 
00309 }

SHOGUN Machine Learning Toolbox - Documentation