SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "classifier/svm/MPDSVM.h" 00012 #include "lib/io.h" 00013 #include "lib/common.h" 00014 #include "lib/Mathematics.h" 00015 00016 using namespace shogun; 00017 00018 CMPDSVM::CMPDSVM() 00019 : CSVM() 00020 { 00021 } 00022 00023 CMPDSVM::CMPDSVM(float64_t C, CKernel* k, CLabels* lab) 00024 : CSVM(C, k, lab) 00025 { 00026 } 00027 00028 CMPDSVM::~CMPDSVM() 00029 { 00030 } 00031 00032 bool CMPDSVM::train(CFeatures* data) 00033 { 00034 ASSERT(labels); 00035 ASSERT(kernel); 00036 00037 if (data) 00038 { 00039 if (labels->get_num_labels() != data->get_num_vectors()) 00040 SG_ERROR("Number of training vectors does not match number of labels\n"); 00041 kernel->init(data, data); 00042 } 00043 ASSERT(kernel->has_features()); 00044 00045 //const float64_t nu=0.32; 00046 const float64_t alpha_eps=1e-12; 00047 const float64_t eps=get_epsilon(); 00048 const int64_t maxiter = 1L<<30; 00049 //const bool nustop=false; 00050 //const int32_t k=2; 00051 const int32_t n=labels->get_num_labels(); 00052 ASSERT(n>0); 00053 //const float64_t d = 1.0/n/nu; //NUSVC 00054 const float64_t d = get_C1(); //CSVC 00055 const float64_t primaleps=eps; 00056 const float64_t dualeps=eps*n; //heuristic 00057 int64_t niter=0; 00058 00059 kernel_cache = new CCache<KERNELCACHE_ELEM>(kernel->get_cache_size(), n, n); 00060 float64_t* alphas=new float64_t[n]; 00061 float64_t* dalphas=new float64_t[n]; 00062 //float64_t* hessres=new float64_t[2*n]; 00063 float64_t* hessres=new float64_t[n]; 00064 //float64_t* F=new float64_t[2*n]; 00065 float64_t* F=new float64_t[n]; 00066 00067 //float64_t hessest[2]={0,0}; 00068 //float64_t hstep[2]; 00069 //float64_t etas[2]={0,0}; 00070 //float64_t detas[2]={0,1}; //NUSVC 00071 float64_t etas=0; 00072 float64_t detas=0; //CSVC 00073 float64_t hessest=0; 00074 float64_t hstep; 00075 00076 const float64_t stopfac = 1; 00077 00078 bool primalcool; 00079 bool dualcool; 00080 00081 //if (nustop) 00082 //etas[1] = 1; 00083 00084 for (int32_t i=0; i<n; i++) 00085 { 00086 alphas[i]=0; 00087 F[i]=labels->get_label(i); 00088 //F[i+n]=-1; 00089 hessres[i]=labels->get_label(i); 00090 //hessres[i+n]=-1; 00091 //dalphas[i]=F[i+n]*etas[1]; //NUSVC 00092 dalphas[i]=-1; //CSVC 00093 } 00094 00095 // go ... 00096 while (niter++ < maxiter) 00097 { 00098 int32_t maxpidx=-1; 00099 float64_t maxpviol = -1; 00100 //float64_t maxdviol = CMath::abs(detas[0]); 00101 float64_t maxdviol = CMath::abs(detas); 00102 bool free_alpha=false; 00103 00104 //if (CMath::abs(detas[1])> maxdviol) 00105 //maxdviol=CMath::abs(detas[1]); 00106 00107 // compute kkt violations with correct sign ... 00108 for (int32_t i=0; i<n; i++) 00109 { 00110 float64_t v=CMath::abs(dalphas[i]); 00111 00112 if (alphas[i] > 0 && alphas[i] < d) 00113 free_alpha=true; 00114 00115 if ( (dalphas[i]==0) || 00116 (alphas[i]==0 && dalphas[i] >0) || 00117 (alphas[i]==d && dalphas[i] <0) 00118 ) 00119 v=0; 00120 00121 if (v > maxpviol) 00122 { 00123 maxpviol=v; 00124 maxpidx=i; 00125 } // if we cannot improve on maxpviol, we can still improve by choosing a cached element 00126 else if (v == maxpviol) 00127 { 00128 if (kernel_cache->is_cached(i)) 00129 maxpidx=i; 00130 } 00131 } 00132 00133 if (maxpidx<0 || maxdviol<0) 00134 SG_ERROR( "no violation no convergence, should not happen!\n"); 00135 00136 // ... and evaluate stopping conditions 00137 //if (nustop) 00138 //stopfac = CMath::max(etas[1], 1e-10); 00139 //else 00140 //stopfac = 1; 00141 00142 if (niter%10000 == 0) 00143 { 00144 float64_t obj=0; 00145 00146 for (int32_t i=0; i<n; i++) 00147 { 00148 obj-=alphas[i]; 00149 for (int32_t j=0; j<n; j++) 00150 obj+=0.5*labels->get_label(i)*labels->get_label(j)*alphas[i]*alphas[j]*kernel->kernel(i,j); 00151 } 00152 00153 SG_DEBUG( "obj:%f pviol:%f dviol:%f maxpidx:%d iter:%d\n", obj, maxpviol, maxdviol, maxpidx, niter); 00154 } 00155 00156 //for (int32_t i=0; i<n; i++) 00157 // SG_DEBUG( "alphas:%f dalphas:%f\n", alphas[i], dalphas[i]); 00158 00159 primalcool = (maxpviol < primaleps*stopfac); 00160 dualcool = (maxdviol < dualeps*stopfac) || (!free_alpha); 00161 00162 // done? 00163 if (primalcool && dualcool) 00164 { 00165 if (!free_alpha) 00166 SG_INFO( " no free alpha, stopping! #iter=%d\n", niter); 00167 else 00168 SG_INFO( " done! #iter=%d\n", niter); 00169 break; 00170 } 00171 00172 00173 ASSERT(maxpidx>=0 && maxpidx<n); 00174 // hessian updates 00175 hstep=-hessres[maxpidx]/compute_H(maxpidx,maxpidx); 00176 //hstep[0]=-hessres[maxpidx]/(compute_H(maxpidx,maxpidx)+hessreg); 00177 //hstep[1]=-hessres[maxpidx+n]/(compute_H(maxpidx,maxpidx)+hessreg); 00178 00179 hessest-=F[maxpidx]*hstep; 00180 //hessest[0]-=F[maxpidx]*hstep[0]; 00181 //hessest[1]-=F[maxpidx+n]*hstep[1]; 00182 00183 // do primal updates .. 00184 float64_t tmpalpha = alphas[maxpidx] - dalphas[maxpidx]/compute_H(maxpidx,maxpidx); 00185 00186 if (tmpalpha > d-alpha_eps) 00187 tmpalpha = d; 00188 00189 if (tmpalpha < 0+alpha_eps) 00190 tmpalpha = 0; 00191 00192 // update alphas & dalphas & detas ... 00193 float64_t alphachange = tmpalpha - alphas[maxpidx]; 00194 alphas[maxpidx] = tmpalpha; 00195 00196 KERNELCACHE_ELEM* h=lock_kernel_row(maxpidx); 00197 for (int32_t i=0; i<n; i++) 00198 { 00199 hessres[i]+=h[i]*hstep; 00200 //hessres[i]+=h[i]*hstep[0]; 00201 //hessres[i+n]+=h[i]*hstep[1]; 00202 dalphas[i] +=h[i]*alphachange; 00203 } 00204 unlock_kernel_row(maxpidx); 00205 00206 detas+=F[maxpidx]*alphachange; 00207 //detas[0]+=F[maxpidx]*alphachange; 00208 //detas[1]+=F[maxpidx+n]*alphachange; 00209 00210 // if at primal minimum, do eta update ... 00211 if (primalcool) 00212 { 00213 //float64_t etachange[2] = { detas[0]/hessest[0] , detas[1]/hessest[1] }; 00214 float64_t etachange = detas/hessest; 00215 00216 etas+=etachange; 00217 //etas[0]+=etachange[0]; 00218 //etas[1]+=etachange[1]; 00219 00220 // update dalphas 00221 for (int32_t i=0; i<n; i++) 00222 dalphas[i]+= F[i] * etachange; 00223 //dalphas[i]+= F[i] * etachange[0] + F[i+n] * etachange[1]; 00224 } 00225 } 00226 00227 if (niter >= maxiter) 00228 SG_WARNING( "increase maxiter ... \n"); 00229 00230 00231 int32_t nsv=0; 00232 for (int32_t i=0; i<n; i++) 00233 { 00234 if (alphas[i]>0) 00235 nsv++; 00236 } 00237 00238 00239 create_new_model(nsv); 00240 //set_bias(etas[0]/etas[1]); 00241 set_bias(etas); 00242 00243 int32_t j=0; 00244 for (int32_t i=0; i<n; i++) 00245 { 00246 if (alphas[i]>0) 00247 { 00248 //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]); 00249 set_alpha(j, alphas[i]*labels->get_label(i)); 00250 set_support_vector(j, i); 00251 j++; 00252 } 00253 } 00254 compute_svm_dual_objective(); 00255 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias()); 00256 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors()); 00257 00258 delete[] alphas; 00259 delete[] dalphas; 00260 delete[] hessres; 00261 delete[] F; 00262 delete kernel_cache; 00263 00264 return true; 00265 }