SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "classifier/svm/LibSVMMultiClass.h" 00012 #include "lib/io.h" 00013 00014 using namespace shogun; 00015 00016 CLibSVMMultiClass::CLibSVMMultiClass(LIBSVM_SOLVER_TYPE st) 00017 : CMultiClassSVM(ONE_VS_ONE), model(NULL), solver_type(st) 00018 { 00019 } 00020 00021 CLibSVMMultiClass::CLibSVMMultiClass(float64_t C, CKernel* k, CLabels* lab) 00022 : CMultiClassSVM(ONE_VS_ONE, C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC) 00023 { 00024 } 00025 00026 CLibSVMMultiClass::~CLibSVMMultiClass() 00027 { 00028 //SG_PRINT("deleting LibSVM\n"); 00029 } 00030 00031 bool CLibSVMMultiClass::train(CFeatures* data) 00032 { 00033 struct svm_node* x_space; 00034 00035 problem = svm_problem(); 00036 00037 ASSERT(labels && labels->get_num_labels()); 00038 int32_t num_classes = labels->get_num_classes(); 00039 problem.l=labels->get_num_labels(); 00040 SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes); 00041 00042 if (data) 00043 { 00044 if (labels->get_num_labels() != data->get_num_vectors()) 00045 SG_ERROR("Number of training vectors does not match number of labels\n"); 00046 kernel->init(data, data); 00047 } 00048 00049 problem.y=new float64_t[problem.l]; 00050 problem.x=new struct svm_node*[problem.l]; 00051 problem.pv=new float64_t[problem.l]; 00052 problem.C=new float64_t[problem.l]; 00053 00054 x_space=new struct svm_node[2*problem.l]; 00055 00056 for (int32_t i=0; i<problem.l; i++) 00057 { 00058 problem.pv[i]=-1.0; 00059 problem.y[i]=labels->get_label(i); 00060 problem.x[i]=&x_space[2*i]; 00061 x_space[2*i].index=i; 00062 x_space[2*i+1].index=-1; 00063 } 00064 00065 ASSERT(kernel); 00066 00067 param.svm_type=solver_type; // C SVM or NU_SVM 00068 param.kernel_type = LINEAR; 00069 param.degree = 3; 00070 param.gamma = 0; // 1/k 00071 param.coef0 = 0; 00072 param.nu = get_nu(); // Nu 00073 param.kernel=kernel; 00074 param.cache_size = kernel->get_cache_size(); 00075 param.max_train_time = max_train_time; 00076 param.C = get_C1(); 00077 param.eps = epsilon; 00078 param.p = 0.1; 00079 param.shrinking = 1; 00080 param.nr_weight = 0; 00081 param.weight_label = NULL; 00082 param.weight = NULL; 00083 param.use_bias = get_bias_enabled(); 00084 00085 const char* error_msg = svm_check_parameter(&problem,¶m); 00086 00087 if(error_msg) 00088 SG_ERROR("Error: %s\n",error_msg); 00089 00090 model = svm_train(&problem, ¶m); 00091 00092 if (model) 00093 { 00094 if (model->nr_class!=num_classes) 00095 { 00096 SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n", 00097 model->nr_class, num_classes); 00098 } 00099 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef)); 00100 create_multiclass_svm(num_classes); 00101 00102 int32_t* offsets=new int32_t[num_classes]; 00103 offsets[0]=0; 00104 00105 for (int32_t i=1; i<num_classes; i++) 00106 offsets[i] = offsets[i-1]+model->nSV[i-1]; 00107 00108 int32_t s=0; 00109 for (int32_t i=0; i<num_classes; i++) 00110 { 00111 for (int32_t j=i+1; j<num_classes; j++) 00112 { 00113 int32_t k, l; 00114 00115 float64_t sgn=1; 00116 if (model->label[i]>model->label[j]) 00117 sgn=-1; 00118 00119 int32_t num_sv=model->nSV[i]+model->nSV[j]; 00120 float64_t bias=-model->rho[s]; 00121 00122 ASSERT(num_sv>0); 00123 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]); 00124 00125 CSVM* svm=new CSVM(num_sv); 00126 00127 svm->set_bias(sgn*bias); 00128 00129 int32_t sv_idx=0; 00130 for (k=0; k<model->nSV[i]; k++) 00131 { 00132 svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index); 00133 svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]); 00134 sv_idx++; 00135 } 00136 00137 for (k=0; k<model->nSV[j]; k++) 00138 { 00139 svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index); 00140 svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]); 00141 sv_idx++; 00142 } 00143 00144 int32_t idx=0; 00145 00146 if (sgn>0) 00147 { 00148 for (k=0; k<model->label[i]; k++) 00149 idx+=num_classes-k-1; 00150 00151 for (l=model->label[i]+1; l<model->label[j]; l++) 00152 idx++; 00153 } 00154 else 00155 { 00156 for (k=0; k<model->label[j]; k++) 00157 idx+=num_classes-k-1; 00158 00159 for (l=model->label[j]+1; l<model->label[i]; l++) 00160 idx++; 00161 } 00162 00163 00164 // if (sgn>0) 00165 // idx=((num_classes-1)*model->label[i]+model->label[j])/2; 00166 // else 00167 // idx=((num_classes-1)*model->label[j]+model->label[i])/2; 00168 // 00169 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f label:(%d,%d) -> svm[%d]\n", s, num_sv, model->l, bias, model->label[i], model->label[j], idx); 00170 00171 set_svm(idx, svm); 00172 s++; 00173 } 00174 } 00175 00176 CSVM::set_objective(model->objective); 00177 00178 delete[] offsets; 00179 delete[] problem.x; 00180 delete[] problem.y; 00181 delete[] x_space; 00182 00183 svm_destroy_model(model); 00184 model=NULL; 00185 00186 return true; 00187 } 00188 else 00189 return false; 00190 } 00191