SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "classifier/svm/CPLEXSVM.h" 00012 #include "lib/common.h" 00013 00014 #ifdef USE_CPLEX 00015 #include "lib/io.h" 00016 #include "lib/Mathematics.h" 00017 #include "lib/Cplex.h" 00018 #include "features/Labels.h" 00019 00020 using namespace shogun; 00021 00022 CCPLEXSVM::CCPLEXSVM() 00023 : CSVM() 00024 { 00025 } 00026 00027 CCPLEXSVM::~CCPLEXSVM() 00028 { 00029 } 00030 00031 bool CCPLEXSVM::train(CFeatures* data) 00032 { 00033 bool result = false; 00034 CCplex cplex; 00035 00036 if (data) 00037 { 00038 if (labels->get_num_labels() != data->get_num_vectors()) 00039 SG_ERROR("Number of training vectors does not match number of labels\n"); 00040 kernel->init(data, data); 00041 } 00042 00043 if (cplex.init(E_QP)) 00044 { 00045 int32_t n,m; 00046 int32_t num_label=0; 00047 float64_t* y = labels->get_labels(num_label); 00048 float64_t* H = kernel->get_kernel_matrix<float64_t>(m, n, NULL); 00049 ASSERT(n>0 && n==m && n==num_label); 00050 float64_t* alphas=new float64_t[n]; 00051 float64_t* lb=new float64_t[n]; 00052 float64_t* ub=new float64_t[n]; 00053 00054 //hessian y'y.*K 00055 for (int32_t i=0; i<n; i++) 00056 { 00057 lb[i]=0; 00058 ub[i]=get_C1(); 00059 00060 for (int32_t j=0; j<n; j++) 00061 H[i*n+j]*=y[j]*y[i]; 00062 } 00063 00064 //feed qp to cplex 00065 00066 00067 int32_t j=0; 00068 for (int32_t i=0; i<n; i++) 00069 { 00070 if (alphas[i]>0) 00071 { 00072 //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]); 00073 set_alpha(j, alphas[i]*labels->get_label(i)); 00074 set_support_vector(j, i); 00075 j++; 00076 } 00077 } 00078 //compute_objective(); 00079 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias()); 00080 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors()); 00081 00082 delete[] alphas; 00083 delete[] lb; 00084 delete[] ub; 00085 delete[] H; 00086 delete[] y; 00087 00088 result = true; 00089 } 00090 00091 if (!result) 00092 SG_ERROR( "cplex svm failed"); 00093 00094 return result; 00095 } 00096 #endif