SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Written (W) 2009 Marius Kloft 00009 * Copyright (C) 2009 TU Berlin and Max-Planck-Society 00010 */ 00011 00012 00013 #include "kernel/Kernel.h" 00014 #include "classifier/svm/ScatterSVM.h" 00015 #include "kernel/ScatterKernelNormalizer.h" 00016 #include "lib/io.h" 00017 00018 using namespace shogun; 00019 00020 CScatterSVM::CScatterSVM(void) 00021 : CMultiClassSVM(ONE_VS_REST), scatter_type(NO_BIAS_LIBSVM), 00022 model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00023 { 00024 SG_UNSTABLE("CScatterSVM::CScatterSVM(void)", "\n"); 00025 } 00026 00027 CScatterSVM::CScatterSVM(SCATTER_TYPE type) 00028 : CMultiClassSVM(ONE_VS_REST), scatter_type(type), model(NULL), 00029 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00030 { 00031 } 00032 00033 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab) 00034 : CMultiClassSVM(ONE_VS_REST, C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL), 00035 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00036 { 00037 } 00038 00039 CScatterSVM::~CScatterSVM() 00040 { 00041 delete[] norm_wc; 00042 delete[] norm_wcw; 00043 } 00044 00045 bool CScatterSVM::train(CFeatures* data) 00046 { 00047 ASSERT(labels && labels->get_num_labels()); 00048 m_num_classes = labels->get_num_classes(); 00049 int32_t num_vectors = labels->get_num_labels(); 00050 00051 if (data) 00052 { 00053 if (labels->get_num_labels() != data->get_num_vectors()) 00054 SG_ERROR("Number of training vectors does not match number of labels\n"); 00055 kernel->init(data, data); 00056 } 00057 00058 int32_t* numc=new int32_t[m_num_classes]; 00059 CMath::fill_vector(numc, m_num_classes, 0); 00060 00061 for (int32_t i=0; i<num_vectors; i++) 00062 numc[(int32_t) labels->get_int_label(i)]++; 00063 00064 int32_t Nc=0; 00065 int32_t Nmin=num_vectors; 00066 for (int32_t i=0; i<m_num_classes; i++) 00067 { 00068 if (numc[i]>0) 00069 { 00070 Nc++; 00071 Nmin=CMath::min(Nmin, numc[i]); 00072 } 00073 00074 } 00075 delete[] numc; 00076 m_num_classes=m_num_classes; 00077 00078 bool result=false; 00079 00080 if (scatter_type==NO_BIAS_LIBSVM) 00081 { 00082 result=train_no_bias_libsvm(); 00083 } 00084 00085 else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2) 00086 { 00087 float64_t nu_min=((float64_t) Nc)/num_vectors; 00088 float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors; 00089 00090 SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max); 00091 00092 if (get_nu()<nu_min || get_nu()>nu_max) 00093 SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max); 00094 00095 result=train_testrule12(); 00096 } 00097 else 00098 SG_ERROR("Unknown Scatter type\n"); 00099 00100 return result; 00101 } 00102 00103 bool CScatterSVM::train_no_bias_libsvm() 00104 { 00105 struct svm_node* x_space; 00106 00107 problem.l=labels->get_num_labels(); 00108 SG_INFO( "%d trainlabels\n", problem.l); 00109 00110 problem.y=new float64_t[problem.l]; 00111 problem.x=new struct svm_node*[problem.l]; 00112 x_space=new struct svm_node[2*problem.l]; 00113 00114 for (int32_t i=0; i<problem.l; i++) 00115 { 00116 problem.y[i]=+1; 00117 problem.x[i]=&x_space[2*i]; 00118 x_space[2*i].index=i; 00119 x_space[2*i+1].index=-1; 00120 } 00121 00122 int32_t weights_label[2]={-1,+1}; 00123 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00124 00125 ASSERT(kernel && kernel->has_features()); 00126 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00127 00128 param.svm_type=C_SVC; // Nu MC SVM 00129 param.kernel_type = LINEAR; 00130 param.degree = 3; 00131 param.gamma = 0; // 1/k 00132 param.coef0 = 0; 00133 param.nu = get_nu(); // Nu 00134 CKernelNormalizer* prev_normalizer=kernel->get_normalizer(); 00135 kernel->set_normalizer(new CScatterKernelNormalizer( 00136 m_num_classes-1, -1, labels, prev_normalizer)); 00137 param.kernel=kernel; 00138 param.cache_size = kernel->get_cache_size(); 00139 param.C = 0; 00140 param.eps = epsilon; 00141 param.p = 0.1; 00142 param.shrinking = 0; 00143 param.nr_weight = 2; 00144 param.weight_label = weights_label; 00145 param.weight = weights; 00146 param.nr_class=m_num_classes; 00147 param.use_bias = get_bias_enabled(); 00148 00149 const char* error_msg = svm_check_parameter(&problem,¶m); 00150 00151 if(error_msg) 00152 SG_ERROR("Error: %s\n",error_msg); 00153 00154 model = svm_train(&problem, ¶m); 00155 kernel->set_normalizer(prev_normalizer); 00156 SG_UNREF(prev_normalizer); 00157 00158 if (model) 00159 { 00160 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)); 00161 00162 ASSERT(model->nr_class==m_num_classes); 00163 create_multiclass_svm(m_num_classes); 00164 00165 rho=model->rho[0]; 00166 00167 delete[] norm_wcw; 00168 norm_wcw = new float64_t[m_num_svms]; 00169 00170 for (int32_t i=0; i<m_num_classes; i++) 00171 { 00172 int32_t num_sv=model->nSV[i]; 00173 00174 CSVM* svm=new CSVM(num_sv); 00175 svm->set_bias(model->rho[i+1]); 00176 norm_wcw[i]=model->normwcw[i]; 00177 00178 00179 for (int32_t j=0; j<num_sv; j++) 00180 { 00181 svm->set_alpha(j, model->sv_coef[i][j]); 00182 svm->set_support_vector(j, model->SV[i][j].index); 00183 } 00184 00185 set_svm(i, svm); 00186 } 00187 00188 delete[] problem.x; 00189 delete[] problem.y; 00190 delete[] x_space; 00191 for (int32_t i=0; i<m_num_classes; i++) 00192 { 00193 free(model->SV[i]); 00194 model->SV[i]=NULL; 00195 } 00196 svm_destroy_model(model); 00197 00198 if (scatter_type==TEST_RULE2) 00199 compute_norm_wc(); 00200 00201 model=NULL; 00202 return true; 00203 } 00204 else 00205 return false; 00206 } 00207 00208 00209 00210 bool CScatterSVM::train_testrule12() 00211 { 00212 struct svm_node* x_space; 00213 problem.l=labels->get_num_labels(); 00214 SG_INFO( "%d trainlabels\n", problem.l); 00215 00216 problem.y=new float64_t[problem.l]; 00217 problem.x=new struct svm_node*[problem.l]; 00218 x_space=new struct svm_node[2*problem.l]; 00219 00220 for (int32_t i=0; i<problem.l; i++) 00221 { 00222 problem.y[i]=labels->get_label(i); 00223 problem.x[i]=&x_space[2*i]; 00224 x_space[2*i].index=i; 00225 x_space[2*i+1].index=-1; 00226 } 00227 00228 int32_t weights_label[2]={-1,+1}; 00229 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00230 00231 ASSERT(kernel && kernel->has_features()); 00232 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00233 00234 param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM 00235 param.kernel_type = LINEAR; 00236 param.degree = 3; 00237 param.gamma = 0; // 1/k 00238 param.coef0 = 0; 00239 param.nu = get_nu(); // Nu 00240 param.kernel=kernel; 00241 param.cache_size = kernel->get_cache_size(); 00242 param.C = 0; 00243 param.eps = epsilon; 00244 param.p = 0.1; 00245 param.shrinking = 0; 00246 param.nr_weight = 2; 00247 param.weight_label = weights_label; 00248 param.weight = weights; 00249 param.nr_class=m_num_classes; 00250 param.use_bias = get_bias_enabled(); 00251 00252 const char* error_msg = svm_check_parameter(&problem,¶m); 00253 00254 if(error_msg) 00255 SG_ERROR("Error: %s\n",error_msg); 00256 00257 model = svm_train(&problem, ¶m); 00258 00259 if (model) 00260 { 00261 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)); 00262 00263 ASSERT(model->nr_class==m_num_classes); 00264 create_multiclass_svm(m_num_classes); 00265 00266 rho=model->rho[0]; 00267 00268 delete[] norm_wcw; 00269 norm_wcw = new float64_t[m_num_svms]; 00270 00271 for (int32_t i=0; i<m_num_classes; i++) 00272 { 00273 int32_t num_sv=model->nSV[i]; 00274 00275 CSVM* svm=new CSVM(num_sv); 00276 svm->set_bias(model->rho[i+1]); 00277 norm_wcw[i]=model->normwcw[i]; 00278 00279 00280 for (int32_t j=0; j<num_sv; j++) 00281 { 00282 svm->set_alpha(j, model->sv_coef[i][j]); 00283 svm->set_support_vector(j, model->SV[i][j].index); 00284 } 00285 00286 set_svm(i, svm); 00287 } 00288 00289 delete[] problem.x; 00290 delete[] problem.y; 00291 delete[] x_space; 00292 for (int32_t i=0; i<m_num_classes; i++) 00293 { 00294 free(model->SV[i]); 00295 model->SV[i]=NULL; 00296 } 00297 svm_destroy_model(model); 00298 00299 if (scatter_type==TEST_RULE2) 00300 compute_norm_wc(); 00301 00302 model=NULL; 00303 return true; 00304 } 00305 else 00306 return false; 00307 } 00308 00309 void CScatterSVM::compute_norm_wc() 00310 { 00311 delete[] norm_wc; 00312 norm_wc = new float64_t[m_num_svms]; 00313 for (int32_t i=0; i<m_num_svms; i++) 00314 norm_wc[i]=0; 00315 00316 00317 for (int c=0; c<m_num_svms; c++) 00318 { 00319 CSVM* svm=m_svms[c]; 00320 int32_t num_sv = svm->get_num_support_vectors(); 00321 00322 for (int32_t i=0; i<num_sv; i++) 00323 { 00324 int32_t ii=svm->get_support_vector(i); 00325 for (int32_t j=0; j<num_sv; j++) 00326 { 00327 int32_t jj=svm->get_support_vector(j); 00328 norm_wc[c]+=svm->get_alpha(i)*kernel->kernel(ii,jj)*svm->get_alpha(j); 00329 } 00330 } 00331 } 00332 00333 for (int32_t i=0; i<m_num_svms; i++) 00334 norm_wc[i]=CMath::sqrt(norm_wc[i]); 00335 00336 CMath::display_vector(norm_wc, m_num_svms, "norm_wc"); 00337 } 00338 00339 CLabels* CScatterSVM::classify_one_vs_rest() 00340 { 00341 CLabels* output=NULL; 00342 if (!kernel) 00343 { 00344 SG_ERROR( "SVM can not proceed without kernel!\n"); 00345 return false ; 00346 } 00347 00348 if ( kernel && kernel->get_num_vec_lhs() && kernel->get_num_vec_rhs()) 00349 { 00350 int32_t num_vectors=kernel->get_num_vec_rhs(); 00351 00352 output=new CLabels(num_vectors); 00353 SG_REF(output); 00354 00355 if (scatter_type == TEST_RULE1) 00356 { 00357 ASSERT(m_num_svms>0); 00358 for (int32_t i=0; i<num_vectors; i++) 00359 output->set_label(i, classify_example(i)); 00360 } 00361 #ifdef USE_SVMLIGHT 00362 else if (scatter_type == NO_BIAS_SVMLIGHT) 00363 { 00364 float64_t* outputs=new float64_t[num_vectors*m_num_classes]; 00365 CMath::fill_vector(outputs,num_vectors*m_num_classes,0.0); 00366 00367 for (int32_t i=0; i<num_vectors; i++) 00368 { 00369 for (int32_t j=0; j<get_num_support_vectors(); j++) 00370 { 00371 float64_t score=kernel->kernel(get_support_vector(j), i)*get_alpha(j); 00372 int32_t label=labels->get_int_label(get_support_vector(j)); 00373 for (int32_t c=0; c<m_num_classes; c++) 00374 { 00375 float64_t s= (label==c) ? (m_num_classes-1) : (-1); 00376 outputs[c+i*m_num_classes]+=s*score; 00377 } 00378 } 00379 } 00380 00381 for (int32_t i=0; i<num_vectors; i++) 00382 { 00383 int32_t winner=0; 00384 float64_t max_out=outputs[i*m_num_classes+0]; 00385 00386 for (int32_t j=1; j<m_num_classes; j++) 00387 { 00388 float64_t out=outputs[i*m_num_classes+j]; 00389 00390 if (out>max_out) 00391 { 00392 winner=j; 00393 max_out=out; 00394 } 00395 } 00396 00397 output->set_label(i, winner); 00398 } 00399 00400 delete[] outputs; 00401 } 00402 #endif //USE_SVMLIGHT 00403 else 00404 { 00405 ASSERT(m_num_svms>0); 00406 ASSERT(num_vectors==output->get_num_labels()); 00407 CLabels** outputs=new CLabels*[m_num_svms]; 00408 00409 for (int32_t i=0; i<m_num_svms; i++) 00410 { 00411 //SG_PRINT("svm %d\n", i); 00412 ASSERT(m_svms[i]); 00413 m_svms[i]->set_kernel(kernel); 00414 m_svms[i]->set_labels(labels); 00415 outputs[i]=m_svms[i]->classify(); 00416 } 00417 00418 for (int32_t i=0; i<num_vectors; i++) 00419 { 00420 int32_t winner=0; 00421 float64_t max_out=outputs[0]->get_label(i)/norm_wc[0]; 00422 00423 for (int32_t j=1; j<m_num_svms; j++) 00424 { 00425 float64_t out=outputs[j]->get_label(i)/norm_wc[j]; 00426 00427 if (out>max_out) 00428 { 00429 winner=j; 00430 max_out=out; 00431 } 00432 } 00433 00434 output->set_label(i, winner); 00435 } 00436 00437 for (int32_t i=0; i<m_num_svms; i++) 00438 SG_UNREF(outputs[i]); 00439 00440 delete[] outputs; 00441 } 00442 } 00443 00444 return output; 00445 } 00446 00447 float64_t CScatterSVM::classify_example(int32_t num) 00448 { 00449 ASSERT(m_num_svms>0); 00450 float64_t* outputs=new float64_t[m_num_svms]; 00451 int32_t winner=0; 00452 00453 if (scatter_type == TEST_RULE1) 00454 { 00455 for (int32_t c=0; c<m_num_svms; c++) 00456 outputs[c]=m_svms[c]->get_bias()-rho; 00457 00458 for (int32_t c=0; c<m_num_svms; c++) 00459 { 00460 float64_t v=0; 00461 00462 for (int32_t i=0; i<m_svms[c]->get_num_support_vectors(); i++) 00463 { 00464 float64_t alpha=m_svms[c]->get_alpha(i); 00465 int32_t svidx=m_svms[c]->get_support_vector(i); 00466 v += alpha*kernel->kernel(svidx, num); 00467 } 00468 00469 outputs[c] += v; 00470 for (int32_t j=0; j<m_num_svms; j++) 00471 outputs[j] -= v/m_num_svms; 00472 } 00473 00474 for (int32_t j=0; j<m_num_svms; j++) 00475 outputs[j]/=norm_wcw[j]; 00476 00477 float64_t max_out=outputs[0]; 00478 for (int32_t j=0; j<m_num_svms; j++) 00479 { 00480 if (outputs[j]>max_out) 00481 { 00482 max_out=outputs[j]; 00483 winner=j; 00484 } 00485 } 00486 } 00487 #ifdef USE_SVMLIGHT 00488 else if (scatter_type == NO_BIAS_SVMLIGHT) 00489 { 00490 SG_ERROR("Use classify...\n"); 00491 } 00492 #endif //USE_SVMLIGHT 00493 else 00494 { 00495 float64_t max_out=m_svms[0]->classify_example(num)/norm_wc[0]; 00496 00497 for (int32_t i=1; i<m_num_svms; i++) 00498 { 00499 outputs[i]=m_svms[i]->classify_example(num)/norm_wc[i]; 00500 if (outputs[i]>max_out) 00501 { 00502 winner=i; 00503 max_out=outputs[i]; 00504 } 00505 } 00506 } 00507 00508 delete[] outputs; 00509 return winner; 00510 }