SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "lib/common.h" 00012 #include "lib/io.h" 00013 #include "classifier/svm/MultiClassSVM.h" 00014 00015 using namespace shogun; 00016 00017 CMultiClassSVM::CMultiClassSVM(void) 00018 : CSVM(0), multiclass_type(ONE_VS_REST), m_num_svms(0), m_svms(NULL) 00019 { 00020 SG_UNSTABLE("CMultiClassSVM::CMultiClassSVM(void)", "\n"); 00021 init(); 00022 } 00023 00024 CMultiClassSVM::CMultiClassSVM(EMultiClassSVM type) 00025 : CSVM(0), multiclass_type(type), m_num_svms(0), m_svms(NULL) 00026 { 00027 init(); 00028 } 00029 00030 CMultiClassSVM::CMultiClassSVM( 00031 EMultiClassSVM type, float64_t C, CKernel* k, CLabels* lab) 00032 : CSVM(C, k, lab), multiclass_type(type), m_num_svms(0), m_svms(NULL) 00033 { 00034 init(); 00035 } 00036 00037 CMultiClassSVM::~CMultiClassSVM() 00038 { 00039 cleanup(); 00040 } 00041 00042 void 00043 CMultiClassSVM::init(void) 00044 { 00045 m_parameters->add((machine_int_t*) &multiclass_type, 00046 "multiclass_type", "Type of MultiClassSVM."); 00047 m_parameters->add(&m_num_classes, "m_num_classes", 00048 "Number of classes."); 00049 m_parameters->add_vector((CSGObject***) &m_svms, 00050 &m_num_svms, "m_svms"); 00051 } 00052 00053 void CMultiClassSVM::cleanup() 00054 { 00055 for (int32_t i=0; i<m_num_svms; i++) 00056 SG_UNREF(m_svms[i]); 00057 00058 delete[] m_svms; 00059 m_num_svms=0; 00060 m_svms=NULL; 00061 } 00062 00063 bool CMultiClassSVM::create_multiclass_svm(int32_t num_classes) 00064 { 00065 if (num_classes>0) 00066 { 00067 cleanup(); 00068 00069 m_num_classes=num_classes; 00070 00071 if (multiclass_type==ONE_VS_REST) 00072 m_num_svms=num_classes; 00073 else if (multiclass_type==ONE_VS_ONE) 00074 m_num_svms=num_classes*(num_classes-1)/2; 00075 else 00076 SG_ERROR("unknown multiclass type\n"); 00077 00078 m_svms=new CSVM*[m_num_svms]; 00079 if (m_svms) 00080 { 00081 memset(m_svms,0, m_num_svms*sizeof(CSVM*)); 00082 return true; 00083 } 00084 } 00085 return false; 00086 } 00087 00088 bool CMultiClassSVM::set_svm(int32_t num, CSVM* svm) 00089 { 00090 if (m_num_svms>0 && m_num_svms>num && num>=0 && svm) 00091 { 00092 SG_REF(svm); 00093 m_svms[num]=svm; 00094 return true; 00095 } 00096 return false; 00097 } 00098 00099 CLabels* CMultiClassSVM::classify() 00100 { 00101 if (multiclass_type==ONE_VS_REST) 00102 return classify_one_vs_rest(); 00103 else if (multiclass_type==ONE_VS_ONE) 00104 return classify_one_vs_one(); 00105 else 00106 SG_ERROR("unknown multiclass type\n"); 00107 00108 return NULL; 00109 } 00110 00111 CLabels* CMultiClassSVM::classify_one_vs_one() 00112 { 00113 ASSERT(m_num_svms>0); 00114 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2); 00115 CLabels* result=NULL; 00116 00117 if (!kernel) 00118 { 00119 SG_ERROR( "SVM can not proceed without kernel!\n"); 00120 return false ; 00121 } 00122 00123 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00124 { 00125 int32_t num_vectors=kernel->get_num_vec_rhs(); 00126 00127 result=new CLabels(num_vectors); 00128 SG_REF(result); 00129 00130 ASSERT(num_vectors==result->get_num_labels()); 00131 CLabels** outputs=new CLabels*[m_num_svms]; 00132 00133 for (int32_t i=0; i<m_num_svms; i++) 00134 { 00135 SG_INFO("num_svms:%d svm[%d]=0x%0X\n", m_num_svms, i, m_svms[i]); 00136 ASSERT(m_svms[i]); 00137 m_svms[i]->set_kernel(kernel); 00138 outputs[i]=m_svms[i]->classify(); 00139 } 00140 00141 int32_t* votes=new int32_t[m_num_classes]; 00142 for (int32_t v=0; v<num_vectors; v++) 00143 { 00144 int32_t s=0; 00145 memset(votes, 0, sizeof(int32_t)*m_num_classes); 00146 00147 for (int32_t i=0; i<m_num_classes; i++) 00148 { 00149 for (int32_t j=i+1; j<m_num_classes; j++) 00150 { 00151 if (outputs[s++]->get_label(v)>0) 00152 votes[i]++; 00153 else 00154 votes[j]++; 00155 } 00156 } 00157 00158 int32_t winner=0; 00159 int32_t max_votes=votes[0]; 00160 00161 for (int32_t i=1; i<m_num_classes; i++) 00162 { 00163 if (votes[i]>max_votes) 00164 { 00165 max_votes=votes[i]; 00166 winner=i; 00167 } 00168 } 00169 00170 result->set_label(v, winner); 00171 } 00172 00173 delete[] votes; 00174 00175 for (int32_t i=0; i<m_num_svms; i++) 00176 SG_UNREF(outputs[i]); 00177 delete[] outputs; 00178 } 00179 00180 return result; 00181 } 00182 00183 CLabels* CMultiClassSVM::classify_one_vs_rest() 00184 { 00185 ASSERT(m_num_svms>0); 00186 CLabels* result=NULL; 00187 00188 if (!kernel) 00189 { 00190 SG_ERROR( "SVM can not proceed without kernel!\n"); 00191 return false ; 00192 } 00193 00194 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00195 { 00196 int32_t num_vectors=kernel->get_num_vec_rhs(); 00197 00198 result=new CLabels(num_vectors); 00199 SG_REF(result); 00200 00201 ASSERT(num_vectors==result->get_num_labels()); 00202 CLabels** outputs=new CLabels*[m_num_svms]; 00203 00204 for (int32_t i=0; i<m_num_svms; i++) 00205 { 00206 ASSERT(m_svms[i]); 00207 m_svms[i]->set_kernel(kernel); 00208 outputs[i]=m_svms[i]->classify(); 00209 } 00210 00211 for (int32_t i=0; i<num_vectors; i++) 00212 { 00213 int32_t winner=0; 00214 float64_t max_out=outputs[0]->get_label(i); 00215 00216 for (int32_t j=1; j<m_num_svms; j++) 00217 { 00218 float64_t out=outputs[j]->get_label(i); 00219 00220 if (out>max_out) 00221 { 00222 winner=j; 00223 max_out=out; 00224 } 00225 } 00226 00227 result->set_label(i, winner); 00228 } 00229 00230 for (int32_t i=0; i<m_num_svms; i++) 00231 SG_UNREF(outputs[i]); 00232 00233 delete[] outputs; 00234 } 00235 00236 return result; 00237 } 00238 00239 float64_t CMultiClassSVM::classify_example(int32_t num) 00240 { 00241 if (multiclass_type==ONE_VS_REST) 00242 return classify_example_one_vs_rest(num); 00243 else if (multiclass_type==ONE_VS_ONE) 00244 return classify_example_one_vs_one(num); 00245 else 00246 SG_ERROR("unknown multiclass type\n"); 00247 00248 return 0; 00249 } 00250 00251 float64_t CMultiClassSVM::classify_example_one_vs_rest(int32_t num) 00252 { 00253 ASSERT(m_num_svms>0); 00254 float64_t* outputs=new float64_t[m_num_svms]; 00255 int32_t winner=0; 00256 float64_t max_out=m_svms[0]->classify_example(num); 00257 00258 for (int32_t i=1; i<m_num_svms; i++) 00259 { 00260 outputs[i]=m_svms[i]->classify_example(num); 00261 if (outputs[i]>max_out) 00262 { 00263 winner=i; 00264 max_out=outputs[i]; 00265 } 00266 } 00267 delete[] outputs; 00268 00269 return winner; 00270 } 00271 00272 float64_t CMultiClassSVM::classify_example_one_vs_one(int32_t num) 00273 { 00274 ASSERT(m_num_svms>0); 00275 ASSERT(m_num_svms==m_num_classes*(m_num_classes-1)/2); 00276 00277 int32_t* votes=new int32_t[m_num_classes]; 00278 int32_t s=0; 00279 00280 for (int32_t i=0; i<m_num_classes; i++) 00281 { 00282 for (int32_t j=i+1; j<m_num_classes; j++) 00283 { 00284 if (m_svms[s++]->classify_example(num)>0) 00285 votes[i]++; 00286 else 00287 votes[j]++; 00288 } 00289 } 00290 00291 int32_t winner=0; 00292 int32_t max_votes=votes[0]; 00293 00294 for (int32_t i=1; i<m_num_classes; i++) 00295 { 00296 if (votes[i]>max_votes) 00297 { 00298 max_votes=votes[i]; 00299 winner=i; 00300 } 00301 } 00302 00303 delete[] votes; 00304 00305 return winner; 00306 } 00307 00308 bool CMultiClassSVM::load(FILE* modelfl) 00309 { 00310 bool result=true; 00311 char char_buffer[1024]; 00312 int32_t int_buffer; 00313 float64_t double_buffer; 00314 int32_t line_number=1; 00315 int32_t svm_idx=-1; 00316 00317 SG_SET_LOCALE_C; 00318 00319 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF) 00320 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00321 else 00322 { 00323 char_buffer[15]='\0'; 00324 if (strcmp("%MultiClassSVM", char_buffer)!=0) 00325 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number); 00326 00327 line_number++; 00328 } 00329 00330 int_buffer=0; 00331 if (fscanf(modelfl," multiclass_type=%d; \n", &int_buffer) != 1) 00332 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00333 00334 if (!feof(modelfl)) 00335 line_number++; 00336 00337 if (int_buffer != multiclass_type) 00338 SG_ERROR("multiclass type does not match %ld vs. %ld\n", int_buffer, multiclass_type); 00339 00340 int_buffer=0; 00341 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1) 00342 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00343 00344 if (!feof(modelfl)) 00345 line_number++; 00346 00347 if (int_buffer < 2) 00348 SG_ERROR("less than 2 classes - how is this multiclass?\n"); 00349 00350 create_multiclass_svm(int_buffer); 00351 00352 int_buffer=0; 00353 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1) 00354 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00355 00356 if (!feof(modelfl)) 00357 line_number++; 00358 00359 if (m_num_svms != int_buffer) 00360 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_num_svms, int_buffer); 00361 00362 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1) 00363 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00364 00365 if (!feof(modelfl)) 00366 line_number++; 00367 00368 for (int32_t n=0; n<m_num_svms; n++) 00369 { 00370 svm_idx=-1; 00371 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF) 00372 { 00373 result=false; 00374 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00375 } 00376 else 00377 { 00378 char_buffer[4]='\0'; 00379 if (strncmp("%SVM", char_buffer, 4)!=0) 00380 { 00381 result=false; 00382 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00383 } 00384 00385 if (svm_idx != n) 00386 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00387 00388 line_number++; 00389 } 00390 00391 int_buffer=0; 00392 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2) 00393 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00394 00395 if (svm_idx != n) 00396 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00397 00398 if (!feof(modelfl)) 00399 line_number++; 00400 00401 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx); 00402 CSVM* svm=new CSVM(int_buffer); 00403 00404 double_buffer=0; 00405 00406 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2) 00407 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00408 00409 if (svm_idx != n) 00410 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00411 00412 if (!feof(modelfl)) 00413 line_number++; 00414 00415 svm->set_bias(double_buffer); 00416 00417 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1) 00418 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00419 00420 if (svm_idx != n) 00421 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00422 00423 if (!feof(modelfl)) 00424 line_number++; 00425 00426 for (int32_t i=0; i<svm->get_num_support_vectors(); i++) 00427 { 00428 double_buffer=0; 00429 int_buffer=0; 00430 00431 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2) 00432 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00433 00434 if (!feof(modelfl)) 00435 line_number++; 00436 00437 svm->set_support_vector(i, int_buffer); 00438 svm->set_alpha(i, double_buffer); 00439 } 00440 00441 if (fscanf(modelfl,"%2s", char_buffer) == EOF) 00442 { 00443 result=false; 00444 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00445 } 00446 else 00447 { 00448 char_buffer[3]='\0'; 00449 if (strcmp("];", char_buffer)!=0) 00450 { 00451 result=false; 00452 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00453 } 00454 line_number++; 00455 } 00456 00457 set_svm(n, svm); 00458 } 00459 00460 svm_loaded=result; 00461 00462 SG_RESET_LOCALE; 00463 return result; 00464 } 00465 00466 bool CMultiClassSVM::save(FILE* modelfl) 00467 { 00468 SG_SET_LOCALE_C; 00469 00470 if (!kernel) 00471 SG_ERROR("Kernel not defined!\n"); 00472 00473 if (!m_svms || m_num_svms<1 || m_num_classes <=2) 00474 SG_ERROR("Multiclass SVM not trained!\n"); 00475 00476 SG_INFO( "Writing model file..."); 00477 fprintf(modelfl,"%%MultiClassSVM\n"); 00478 fprintf(modelfl,"multiclass_type=%d;\n", multiclass_type); 00479 fprintf(modelfl,"num_classes=%d;\n", m_num_classes); 00480 fprintf(modelfl,"num_svms=%d;\n", m_num_svms); 00481 fprintf(modelfl,"kernel='%s';\n", kernel->get_name()); 00482 00483 for (int32_t i=0; i<m_num_svms; i++) 00484 { 00485 CSVM* svm=m_svms[i]; 00486 ASSERT(svm); 00487 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_num_svms-1); 00488 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors()); 00489 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias()); 00490 00491 fprintf(modelfl, "alphas%d=[\n", i); 00492 00493 for(int32_t j=0; j<svm->get_num_support_vectors(); j++) 00494 { 00495 fprintf(modelfl,"\t[%+10.16e,%d];\n", 00496 svm->get_alpha(j), svm->get_support_vector(j)); 00497 } 00498 00499 fprintf(modelfl, "];\n"); 00500 } 00501 00502 SG_RESET_LOCALE; 00503 SG_DONE(); 00504 return true ; 00505 }