SHOGUN v0.9.0
ScatterSVM.cpp
Browse the documentation for this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009 Soeren Sonnenburg
00008  * Written (W) 2009 Marius Kloft
00009  * Copyright (C) 2009 TU Berlin and Max-Planck-Society
00010  */
00011 
00012 
00013 #include "kernel/Kernel.h"
00014 #include "classifier/svm/ScatterSVM.h"
00015 #include "kernel/ScatterKernelNormalizer.h"
00016 #include "lib/io.h"
00017 
00018 using namespace shogun;
00019 
/** Default constructor.
 *
 * Sets up a one-vs-rest multiclass SVM using the NO_BIAS_LIBSVM scatter
 * variant with all internal buffers cleared (no model, no cached norms).
 * Flagged via SG_UNSTABLE: prefer the parameterized constructors.
 */
CScatterSVM::CScatterSVM(void)
: CMultiClassSVM(ONE_VS_REST), scatter_type(NO_BIAS_LIBSVM),
  model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
    SG_UNSTABLE("CScatterSVM::CScatterSVM(void)", "\n");
}
00026 
/** Constructor selecting the scatter training variant.
 *
 * @param type which scatter formulation to use when train() is called
 *             (NO_BIAS_LIBSVM, TEST_RULE1, TEST_RULE2, ...)
 */
CScatterSVM::CScatterSVM(SCATTER_TYPE type)
: CMultiClassSVM(ONE_VS_REST), scatter_type(type), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}
00032 
/** Constructor with regularization, kernel and labels.
 *
 * Uses the NO_BIAS_LIBSVM scatter variant.
 *
 * @param C   regularization constant (forwarded to CMultiClassSVM)
 * @param k   kernel to train with
 * @param lab training labels
 */
CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
: CMultiClassSVM(ONE_VS_REST, C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL),
    norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
{
}
00038 
/** Destructor: releases the cached per-class norm buffers.
 *
 * `model` is not freed here; both training paths destroy it via
 * svm_destroy_model() and reset it to NULL before returning.
 */
CScatterSVM::~CScatterSVM()
{
    delete[] norm_wc;
    delete[] norm_wcw;
}
00044 
00045 bool CScatterSVM::train(CFeatures* data)
00046 {
00047     ASSERT(labels && labels->get_num_labels());
00048     m_num_classes = labels->get_num_classes();
00049     int32_t num_vectors = labels->get_num_labels();
00050 
00051     if (data)
00052     {
00053         if (labels->get_num_labels() != data->get_num_vectors())
00054             SG_ERROR("Number of training vectors does not match number of labels\n");
00055         kernel->init(data, data);
00056     }
00057 
00058     int32_t* numc=new int32_t[m_num_classes];
00059     CMath::fill_vector(numc, m_num_classes, 0);
00060 
00061     for (int32_t i=0; i<num_vectors; i++)
00062         numc[(int32_t) labels->get_int_label(i)]++;
00063 
00064     int32_t Nc=0;
00065     int32_t Nmin=num_vectors;
00066     for (int32_t i=0; i<m_num_classes; i++)
00067     {
00068         if (numc[i]>0)
00069         {
00070             Nc++;
00071             Nmin=CMath::min(Nmin, numc[i]);
00072         }
00073 
00074     }
00075     delete[] numc;
00076     m_num_classes=m_num_classes;
00077 
00078     bool result=false;
00079 
00080     if (scatter_type==NO_BIAS_LIBSVM)
00081     {
00082         result=train_no_bias_libsvm();
00083     }
00084 
00085     else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2) 
00086     {
00087         float64_t nu_min=((float64_t) Nc)/num_vectors;
00088         float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors;
00089 
00090         SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max);
00091 
00092         if (get_nu()<nu_min || get_nu()>nu_max)
00093             SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max);
00094 
00095         result=train_testrule12();
00096     }
00097     else
00098         SG_ERROR("Unknown Scatter type\n"); 
00099 
00100     return result;
00101 }
00102 
00103 bool CScatterSVM::train_no_bias_libsvm()
00104 {
00105     struct svm_node* x_space;
00106 
00107     problem.l=labels->get_num_labels();
00108     SG_INFO( "%d trainlabels\n", problem.l);
00109 
00110     problem.y=new float64_t[problem.l];
00111     problem.x=new struct svm_node*[problem.l];
00112     x_space=new struct svm_node[2*problem.l];
00113 
00114     for (int32_t i=0; i<problem.l; i++)
00115     {
00116         problem.y[i]=+1;
00117         problem.x[i]=&x_space[2*i];
00118         x_space[2*i].index=i;
00119         x_space[2*i+1].index=-1;
00120     }
00121 
00122     int32_t weights_label[2]={-1,+1};
00123     float64_t weights[2]={1.0,get_C2()/get_C1()};
00124 
00125     ASSERT(kernel && kernel->has_features());
00126     ASSERT(kernel->get_num_vec_lhs()==problem.l);
00127 
00128     param.svm_type=C_SVC; // Nu MC SVM
00129     param.kernel_type = LINEAR;
00130     param.degree = 3;
00131     param.gamma = 0;    // 1/k
00132     param.coef0 = 0;
00133     param.nu = get_nu(); // Nu
00134     CKernelNormalizer* prev_normalizer=kernel->get_normalizer();
00135     kernel->set_normalizer(new CScatterKernelNormalizer(
00136                 m_num_classes-1, -1, labels, prev_normalizer));
00137     param.kernel=kernel;
00138     param.cache_size = kernel->get_cache_size();
00139     param.C = 0;
00140     param.eps = epsilon;
00141     param.p = 0.1;
00142     param.shrinking = 0;
00143     param.nr_weight = 2;
00144     param.weight_label = weights_label;
00145     param.weight = weights;
00146     param.nr_class=m_num_classes;
00147     param.use_bias = get_bias_enabled();
00148 
00149     const char* error_msg = svm_check_parameter(&problem,&param);
00150 
00151     if(error_msg)
00152         SG_ERROR("Error: %s\n",error_msg);
00153 
00154     model = svm_train(&problem, &param);
00155     kernel->set_normalizer(prev_normalizer);
00156     SG_UNREF(prev_normalizer);
00157 
00158     if (model)
00159     {
00160         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef));
00161 
00162         ASSERT(model->nr_class==m_num_classes);
00163         create_multiclass_svm(m_num_classes);
00164 
00165         rho=model->rho[0];
00166 
00167         delete[] norm_wcw;
00168         norm_wcw = new float64_t[m_num_svms];
00169 
00170         for (int32_t i=0; i<m_num_classes; i++)
00171         {
00172             int32_t num_sv=model->nSV[i];
00173 
00174             CSVM* svm=new CSVM(num_sv);
00175             svm->set_bias(model->rho[i+1]);
00176             norm_wcw[i]=model->normwcw[i];
00177 
00178 
00179             for (int32_t j=0; j<num_sv; j++)
00180             {
00181                 svm->set_alpha(j, model->sv_coef[i][j]);
00182                 svm->set_support_vector(j, model->SV[i][j].index);
00183             }
00184 
00185             set_svm(i, svm);
00186         }
00187 
00188         delete[] problem.x;
00189         delete[] problem.y;
00190         delete[] x_space;
00191         for (int32_t i=0; i<m_num_classes; i++)
00192         {
00193             free(model->SV[i]);
00194             model->SV[i]=NULL;
00195         }
00196         svm_destroy_model(model);
00197 
00198         if (scatter_type==TEST_RULE2)
00199             compute_norm_wc();
00200 
00201         model=NULL;
00202         return true;
00203     }
00204     else
00205         return false;
00206 }
00207 
00208 
00209 
00210 bool CScatterSVM::train_testrule12()
00211 {
00212     struct svm_node* x_space;
00213     problem.l=labels->get_num_labels();
00214     SG_INFO( "%d trainlabels\n", problem.l);
00215 
00216     problem.y=new float64_t[problem.l];
00217     problem.x=new struct svm_node*[problem.l];
00218     x_space=new struct svm_node[2*problem.l];
00219 
00220     for (int32_t i=0; i<problem.l; i++)
00221     {
00222         problem.y[i]=labels->get_label(i);
00223         problem.x[i]=&x_space[2*i];
00224         x_space[2*i].index=i;
00225         x_space[2*i+1].index=-1;
00226     }
00227 
00228     int32_t weights_label[2]={-1,+1};
00229     float64_t weights[2]={1.0,get_C2()/get_C1()};
00230 
00231     ASSERT(kernel && kernel->has_features());
00232     ASSERT(kernel->get_num_vec_lhs()==problem.l);
00233 
00234     param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM
00235     param.kernel_type = LINEAR;
00236     param.degree = 3;
00237     param.gamma = 0;    // 1/k
00238     param.coef0 = 0;
00239     param.nu = get_nu(); // Nu
00240     param.kernel=kernel;
00241     param.cache_size = kernel->get_cache_size();
00242     param.C = 0;
00243     param.eps = epsilon;
00244     param.p = 0.1;
00245     param.shrinking = 0;
00246     param.nr_weight = 2;
00247     param.weight_label = weights_label;
00248     param.weight = weights;
00249     param.nr_class=m_num_classes;
00250     param.use_bias = get_bias_enabled();
00251 
00252     const char* error_msg = svm_check_parameter(&problem,&param);
00253 
00254     if(error_msg)
00255         SG_ERROR("Error: %s\n",error_msg);
00256 
00257     model = svm_train(&problem, &param);
00258 
00259     if (model)
00260     {
00261         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef));
00262 
00263         ASSERT(model->nr_class==m_num_classes);
00264         create_multiclass_svm(m_num_classes);
00265 
00266         rho=model->rho[0];
00267 
00268         delete[] norm_wcw;
00269         norm_wcw = new float64_t[m_num_svms];
00270 
00271         for (int32_t i=0; i<m_num_classes; i++)
00272         {
00273             int32_t num_sv=model->nSV[i];
00274 
00275             CSVM* svm=new CSVM(num_sv);
00276             svm->set_bias(model->rho[i+1]);
00277             norm_wcw[i]=model->normwcw[i];
00278 
00279 
00280             for (int32_t j=0; j<num_sv; j++)
00281             {
00282                 svm->set_alpha(j, model->sv_coef[i][j]);
00283                 svm->set_support_vector(j, model->SV[i][j].index);
00284             }
00285 
00286             set_svm(i, svm);
00287         }
00288 
00289         delete[] problem.x;
00290         delete[] problem.y;
00291         delete[] x_space;
00292         for (int32_t i=0; i<m_num_classes; i++)
00293         {
00294             free(model->SV[i]);
00295             model->SV[i]=NULL;
00296         }
00297         svm_destroy_model(model);
00298 
00299         if (scatter_type==TEST_RULE2)
00300             compute_norm_wc();
00301 
00302         model=NULL;
00303         return true;
00304     }
00305     else
00306         return false;
00307 }
00308 
00309 void CScatterSVM::compute_norm_wc()
00310 {
00311     delete[] norm_wc;
00312     norm_wc = new float64_t[m_num_svms];
00313     for (int32_t i=0; i<m_num_svms; i++)
00314         norm_wc[i]=0;
00315 
00316 
00317     for (int c=0; c<m_num_svms; c++)
00318     {
00319         CSVM* svm=m_svms[c];
00320         int32_t num_sv = svm->get_num_support_vectors();
00321 
00322         for (int32_t i=0; i<num_sv; i++)
00323         {
00324             int32_t ii=svm->get_support_vector(i);
00325             for (int32_t j=0; j<num_sv; j++)
00326             {
00327                 int32_t jj=svm->get_support_vector(j);
00328                 norm_wc[c]+=svm->get_alpha(i)*kernel->kernel(ii,jj)*svm->get_alpha(j);
00329             }
00330         }
00331     }
00332 
00333     for (int32_t i=0; i<m_num_svms; i++)
00334         norm_wc[i]=CMath::sqrt(norm_wc[i]);
00335 
00336     CMath::display_vector(norm_wc, m_num_svms, "norm_wc");
00337 }
00338 
00339 CLabels* CScatterSVM::classify_one_vs_rest()
00340 {
00341     CLabels* output=NULL;
00342     if (!kernel)
00343     {
00344         SG_ERROR( "SVM can not proceed without kernel!\n");
00345         return false ;
00346     }
00347 
00348     if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs())
00349     {
00350         int32_t num_vectors=kernel->get_num_vec_rhs();
00351 
00352         output=new CLabels(num_vectors);
00353         SG_REF(output);
00354 
00355         if (scatter_type == TEST_RULE1)
00356         {
00357             ASSERT(m_num_svms>0);
00358             for (int32_t i=0; i<num_vectors; i++)
00359                 output->set_label(i, classify_example(i));
00360         }
00361 #ifdef USE_SVMLIGHT
00362         else if (scatter_type == NO_BIAS_SVMLIGHT)
00363         {
00364             float64_t* outputs=new float64_t[num_vectors*m_num_classes];
00365             CMath::fill_vector(outputs,num_vectors*m_num_classes,0.0);
00366 
00367             for (int32_t i=0; i<num_vectors; i++)
00368             {
00369                 for (int32_t j=0; j<get_num_support_vectors(); j++)
00370                 {
00371                     float64_t score=kernel->kernel(get_support_vector(j), i)*get_alpha(j);
00372                     int32_t label=labels->get_int_label(get_support_vector(j));
00373                     for (int32_t c=0; c<m_num_classes; c++)
00374                     {
00375                         float64_t s= (label==c) ? (m_num_classes-1) : (-1);
00376                         outputs[c+i*m_num_classes]+=s*score;
00377                     }
00378                 }
00379             }
00380 
00381             for (int32_t i=0; i<num_vectors; i++)
00382             {
00383                 int32_t winner=0;
00384                 float64_t max_out=outputs[i*m_num_classes+0];
00385 
00386                 for (int32_t j=1; j<m_num_classes; j++)
00387                 {
00388                     float64_t out=outputs[i*m_num_classes+j];
00389 
00390                     if (out>max_out)
00391                     {
00392                         winner=j;
00393                         max_out=out;
00394                     }
00395                 }
00396 
00397                 output->set_label(i, winner);
00398             }
00399 
00400             delete[] outputs;
00401         }
00402 #endif //USE_SVMLIGHT
00403         else
00404         {
00405             ASSERT(m_num_svms>0);
00406             ASSERT(num_vectors==output->get_num_labels());
00407             CLabels** outputs=new CLabels*[m_num_svms];
00408 
00409             for (int32_t i=0; i<m_num_svms; i++)
00410             {
00411                 //SG_PRINT("svm %d\n", i);
00412                 ASSERT(m_svms[i]);
00413                 m_svms[i]->set_kernel(kernel);
00414                 m_svms[i]->set_labels(labels);
00415                 outputs[i]=m_svms[i]->classify();
00416             }
00417 
00418             for (int32_t i=0; i<num_vectors; i++)
00419             {
00420                 int32_t winner=0;
00421                 float64_t max_out=outputs[0]->get_label(i)/norm_wc[0];
00422 
00423                 for (int32_t j=1; j<m_num_svms; j++)
00424                 {
00425                     float64_t out=outputs[j]->get_label(i)/norm_wc[j];
00426 
00427                     if (out>max_out)
00428                     {
00429                         winner=j;
00430                         max_out=out;
00431                     }
00432                 }
00433 
00434                 output->set_label(i, winner);
00435             }
00436 
00437             for (int32_t i=0; i<m_num_svms; i++)
00438                 SG_UNREF(outputs[i]);
00439 
00440             delete[] outputs;
00441         }
00442     }
00443 
00444     return output;
00445 }
00446 
/** Classify a single example and return the winning class index.
 *
 * TEST_RULE1: each class score starts at bias_c - rho; the kernel
 * expansion v of machine c is added to its own score and v/m_num_svms is
 * subtracted from every score (centering the scores), then scores are
 * divided by norm_wcw[] before taking the argmax. The statement order of
 * this redistribution is significant.
 *
 * Fallback (e.g. TEST_RULE2): argmax over the per-machine outputs scaled
 * by norm_wc[] (outputs[0] is never read on this path; max_out is seeded
 * directly from machine 0).
 *
 * NO_BIAS_SVMLIGHT must use classify_one_vs_rest() instead.
 *
 * @param num index of the example on the kernel rhs
 * @return index of the predicted class (as float64_t)
 */
float64_t CScatterSVM::classify_example(int32_t num)
{
    ASSERT(m_num_svms>0);
    float64_t* outputs=new float64_t[m_num_svms];
    int32_t winner=0;

    if (scatter_type == TEST_RULE1)
    {
        // initialize every class score with its bias minus the shared rho
        for (int32_t c=0; c<m_num_svms; c++)
            outputs[c]=m_svms[c]->get_bias()-rho;

        for (int32_t c=0; c<m_num_svms; c++)
        {
            float64_t v=0;

            // kernel expansion of machine c at the query point
            for (int32_t i=0; i<m_svms[c]->get_num_support_vectors(); i++)
            {
                float64_t alpha=m_svms[c]->get_alpha(i);
                int32_t svidx=m_svms[c]->get_support_vector(i);
                v += alpha*kernel->kernel(svidx, num);
            }

            // add to own class, subtract the mean contribution from all
            outputs[c] += v;
            for (int32_t j=0; j<m_num_svms; j++)
                outputs[j] -= v/m_num_svms;
        }

        // scale by the per-class norms before the argmax
        for (int32_t j=0; j<m_num_svms; j++)
            outputs[j]/=norm_wcw[j];

        float64_t max_out=outputs[0];
        for (int32_t j=0; j<m_num_svms; j++)
        {
            if (outputs[j]>max_out)
            {
                max_out=outputs[j];
                winner=j;
            }
        }
    }
#ifdef USE_SVMLIGHT
    else if (scatter_type == NO_BIAS_SVMLIGHT)
    {
        SG_ERROR("Use classify...\n");
    }
#endif //USE_SVMLIGHT
    else
    {
        // argmax over norm-scaled per-machine outputs; machine 0 seeds max_out
        float64_t max_out=m_svms[0]->classify_example(num)/norm_wc[0];

        for (int32_t i=1; i<m_num_svms; i++)
        {
            outputs[i]=m_svms[i]->classify_example(num)/norm_wc[i];
            if (outputs[i]>max_out)
            {
                winner=i;
                max_out=outputs[i];
            }
        }
    }

    delete[] outputs;
    return winner;
}

SHOGUN Machine Learning Toolbox - Documentation