MultiClassSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "lib/io.h"
00013 #include "classifier/svm/MultiClassSVM.h"
00014 
00015 using namespace shogun;
00016 
00017 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type)
00018 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00019 {
00020 }
00021 
00022 CMultiClassSVM::CMultiClassSVM(
00023     EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab)
00024 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL)
00025 {
00026 }
00027 
00028 CMultiClassSVM::~CMultiClassSVM()
00029 {
00030     cleanup();
00031 }
00032 
00033 void CMultiClassSVM::cleanup()
00034 {
00035     for (int32_t i=0; i<m_num_svms; i++)
00036         SG_UNREF(m_svms[i]);
00037 
00038     delete[] m_svms;
00039     m_num_svms=0;
00040     m_svms=NULL;
00041 }
00042 
00043 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes)
00044 {
00045     if (num_classes>0)
00046     {
00047         cleanup();
00048 
00049         m_num_classes=num_classes;
00050 
00051         if (multiclass_type==ONE_VS_REST)
00052             m_num_svms=num_classes;
00053         else if (multiclass_type==ONE_VS_ONE)
00054             m_num_svms=num_classes*(num_classes-1)/2;
00055         else
00056             SG_ERROR("unknown multiclass type\n");
00057 
00058         m_svms=new CSVM*[m_num_svms];
00059         if (m_svms)
00060         {
00061             memset(m_svms,0, m_num_svms*sizeof(CSVM*));
00062             return true;
00063         }
00064     }
00065     return false;
00066 }
00067 
00068 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm)
00069 {
00070     if (m_num_svms>0 && m_num_svms>num && num>=0 && svm)
00071     {
00072         SG_REF(svm);
00073         m_svms[num]=svm;
00074         return true;
00075     }
00076     return false;
00077 }
00078 
00079 CLabels* CMultiClassSVM::classify()
00080 {
00081     if (multiclass_type==ONE_VS_REST)
00082         return classify_one_vs_rest();
00083     else if (multiclass_type==ONE_VS_ONE)
00084         return classify_one_vs_one();
00085     else
00086         SG_ERROR("unknown multiclass type\n");
00087 
00088     return NULL;
00089 }
00090 
00091 CLabels* CMultiClassSVM::classify_one_vs_one()
00092 {
00093     ASSERT(m_num_svms>0);
00094     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00095     CLabels* result=NULL;
00096 
00097     if (!kernel)
00098     {
00099         SG_ERROR( "SVM can not proceed without kernel!\n");
00100         return false ;
00101     }
00102 
00103     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00104     {
00105         int32_t num_vectors=kernel->get_num_vec_rhs();
00106 
00107         result=new CLabels(num_vectors);
00108         SG_REF(result);
00109 
00110         ASSERT(num_vectors==result->get_num_labels());
00111         CLabels** outputs=new CLabels*[m_num_svms];
00112 
00113         for (int32_t i=0; i<m_num_svms; i++)
00114         {
00115             SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]);
00116             ASSERT(m_svms[i]);
00117             m_svms[i]->set_kernel(kernel);
00118             m_svms[i]->set_labels(labels);
00119             outputs[i]=m_svms[i]->classify();
00120         }
00121 
00122         int32_t* votes=new int32_t[m_num_classes];
00123         for (int32_t v=0; v<num_vectors; v++)
00124         {
00125             int32_t s=0;
00126             memset(votes, 0, sizeof(int32_t)*m_num_classes);
00127 
00128             for (int32_t i=0; i<m_num_classes; i++)
00129             {
00130                 for (int32_t j=i+1; j<m_num_classes; j++)
00131                 {
00132                     if (outputs[s++]->get_label(v)>0)
00133                         votes[i]++;
00134                     else
00135                         votes[j]++;
00136                 }
00137             }
00138 
00139             int32_t winner=0;
00140             int32_t max_votes=votes[0];
00141 
00142             for (int32_t i=1; i<m_num_classes; i++)
00143             {
00144                 if (votes[i]>max_votes)
00145                 {
00146                     max_votes=votes[i];
00147                     winner=i;
00148                 }
00149             }
00150 
00151             result->set_label(v, winner);
00152         }
00153 
00154         delete[] votes;
00155 
00156         for (int32_t i=0; i<m_num_svms; i++)
00157             SG_UNREF(outputs[i]);
00158         delete[] outputs;
00159     }
00160 
00161     return result;
00162 }
00163 
00164 CLabels* CMultiClassSVM::classify_one_vs_rest()
00165 {
00166     ASSERT(m_num_svms>0);
00167     CLabels* result=NULL;
00168 
00169     if (!kernel)
00170     {
00171         SG_ERROR( "SVM can not proceed without kernel!\n");
00172         return false ;
00173     }
00174 
00175     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00176     {
00177         int32_t num_vectors=kernel->get_num_vec_rhs();
00178 
00179         result=new CLabels(num_vectors);
00180         SG_REF(result);
00181 
00182         ASSERT(num_vectors==result->get_num_labels());
00183         CLabels** outputs=new CLabels*[m_num_svms];
00184 
00185         for (int32_t i=0; i<m_num_svms; i++)
00186         {
00187             ASSERT(m_svms[i]);
00188             m_svms[i]->set_kernel(kernel);
00189             m_svms[i]->set_labels(labels);
00190             outputs[i]=m_svms[i]->classify();
00191         }
00192 
00193         for (int32_t i=0; i<num_vectors; i++)
00194         {
00195             int32_t winner=0;
00196             float64_t max_out=outputs[0]->get_label(i);
00197 
00198             for (int32_t j=1; j<m_num_svms; j++)
00199             {
00200                 float64_t out=outputs[j]->get_label(i);
00201 
00202                 if (out>max_out)
00203                 {
00204                     winner=j;
00205                     max_out=out;
00206                 }
00207             }
00208 
00209             result->set_label(i, winner);
00210         }
00211 
00212         for (int32_t i=0; i<m_num_svms; i++)
00213             SG_UNREF(outputs[i]);
00214 
00215         delete[] outputs;
00216     }
00217 
00218     return result;
00219 }
00220 
00221 float64_t CMultiClassSVM::classify_example(int32_t num)
00222 {
00223     if (multiclass_type==ONE_VS_REST)
00224         return classify_example_one_vs_rest(num);
00225     else if (multiclass_type==ONE_VS_ONE)
00226         return classify_example_one_vs_one(num);
00227     else
00228         SG_ERROR("unknown multiclass type\n");
00229 
00230     return 0;
00231 }
00232 
00233 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num)
00234 {
00235     ASSERT(m_num_svms>0);
00236     float64_t* outputs=new float64_t[m_num_svms];
00237     int32_t winner=0;
00238     float64_t max_out=m_svms[0]->classify_example(num);
00239 
00240     for (int32_t i=1; i<m_num_svms; i++)
00241     {
00242         outputs[i]=m_svms[i]->classify_example(num);
00243         if (outputs[i]>max_out)
00244         {
00245             winner=i;
00246             max_out=outputs[i];
00247         }
00248     }
00249     delete[] outputs;
00250 
00251     return winner;
00252 }
00253 
00254 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num)
00255 {
00256     ASSERT(m_num_svms>0);
00257     ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2);
00258 
00259     int32_t* votes=new int32_t[m_num_classes];
00260     int32_t s=0;
00261 
00262     for (int32_t i=0; i<m_num_classes; i++)
00263     {
00264         for (int32_t j=i+1; j<m_num_classes; j++)
00265         {
00266             if (m_svms[s++]->classify_example(num)>0)
00267                 votes[i]++;
00268             else
00269                 votes[j]++;
00270         }
00271     }
00272 
00273     int32_t winner=0;
00274     int32_t max_votes=votes[0];
00275 
00276     for (int32_t i=1; i<m_num_classes; i++)
00277     {
00278         if (votes[i]>max_votes)
00279         {
00280             max_votes=votes[i];
00281             winner=i;
00282         }
00283     }
00284 
00285     delete[] votes;
00286 
00287     return winner;
00288 }
00289 
00290 bool CMultiClassSVM::load(FILE* modelfl)
00291 {
00292     bool result=true;
00293     char char_buffer[1024];
00294     int32_t int_buffer;
00295     float64_t double_buffer;
00296     int32_t line_number=1;
00297     int32_t svm_idx=-1;
00298 
00299     if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
00300         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00301     else
00302     {
00303         char_buffer[15]='\0';
00304         if (strcmp("%MultiClassSVM", char_buffer)!=0)
00305             SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
00306 
00307         line_number++;
00308     }
00309 
00310     int_buffer=0;
00311     if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1)
00312         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00313 
00314     if (!feof(modelfl))
00315         line_number++;
00316 
00317     if (int_buffer != multiclass_type)
00318         SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type);
00319 
00320     int_buffer=0;
00321     if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
00322         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00323 
00324     if (!feof(modelfl))
00325         line_number++;
00326 
00327     if (int_buffer < 2)
00328         SG_ERROR("less than 2 classes - how is this multiclass?\n");
00329 
00330     create_multiclass_svm(int_buffer);
00331 
00332     int_buffer=0;
00333     if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
00334         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00335 
00336     if (!feof(modelfl))
00337         line_number++;
00338 
00339     if (m_num_svms != int_buffer)
00340         SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer);
00341 
00342     if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
00343         SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00344 
00345     if (!feof(modelfl))
00346         line_number++;
00347 
00348     for (int32_t n=0; n<m_num_svms; n++)
00349     {
00350         svm_idx=-1;
00351         if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
00352         {
00353             result=false;
00354             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00355         }
00356         else
00357         {
00358             char_buffer[4]='\0';
00359             if (strncmp("%SVM", char_buffer, 4)!=0)
00360             {
00361                 result=false;
00362                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00363             }
00364 
00365             if (svm_idx != n)
00366                 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00367 
00368             line_number++;
00369         }
00370 
00371         int_buffer=0;
00372         if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
00373             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00374 
00375         if (svm_idx != n)
00376             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00377 
00378         if (!feof(modelfl))
00379             line_number++;
00380 
00381         SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
00382         CSVM* svm=new CSVM(int_buffer);
00383 
00384         double_buffer=0;
00385 
00386         if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
00387             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00388 
00389         if (svm_idx != n)
00390             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00391 
00392         if (!feof(modelfl))
00393             line_number++;
00394 
00395         svm->set_bias(double_buffer);
00396 
00397         if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
00398             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00399 
00400         if (svm_idx != n)
00401             SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
00402 
00403         if (!feof(modelfl))
00404             line_number++;
00405 
00406         for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
00407         {
00408             double_buffer=0;
00409             int_buffer=0;
00410 
00411             if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
00412                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00413 
00414             if (!feof(modelfl))
00415                 line_number++;
00416 
00417             svm->set_support_vector(i, int_buffer);
00418             svm->set_alpha(i, double_buffer);
00419         }
00420 
00421         if (fscanf(modelfl,"%2s", char_buffer) == EOF)
00422         {
00423             result=false;
00424             SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00425         }
00426         else
00427         {
00428             char_buffer[3]='\0';
00429             if (strcmp("];", char_buffer)!=0)
00430             {
00431                 result=false;
00432                 SG_ERROR( "error in svm file, line nr:%d\n", line_number);
00433             }
00434             line_number++;
00435         }
00436 
00437         set_svm(n, svm);
00438     }
00439 
00440     svm_loaded=result;
00441     return result;
00442 }
00443 
00444 bool CMultiClassSVM::save(FILE* modelfl)
00445 {
00446     if (!kernel)
00447         SG_ERROR("Kernel not defined!\n");
00448 
00449     if (!m_svms || m_num_svms<1 || m_num_classes <=2)
00450         SG_ERROR("Multiclass SVM not trained!\n");
00451 
00452     SG_INFO( "Writing model file...");
00453     fprintf(modelfl,"%%MultiClassSVM\n");
00454     fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type);
00455     fprintf(modelfl,"num_classes=%d;\n", m_num_classes);
00456     fprintf(modelfl,"num_svms=%d;\n", m_num_svms);
00457     fprintf(modelfl,"kernel='%s';\n", kernel->get_name());
00458 
00459     for (int32_t i=0; i<m_num_svms; i++)
00460     {
00461         CSVM* svm=m_svms[i];
00462         ASSERT(svm);
00463         fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1);
00464         fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
00465         fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
00466 
00467         fprintf(modelfl, "alphas%d=[\n", i);
00468 
00469         for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
00470         {
00471             fprintf(modelfl,"\t[%+10.16e,%d];\n",
00472                     svm->get_alpha(j), svm->get_support_vector(j));
00473         }
00474 
00475         fprintf(modelfl, "];\n");
00476     }
00477 
00478     SG_DONE();
00479     return true ;
00480 }

SHOGUN Machine Learning Toolbox - Documentation