00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "lib/io.h"
00012 #include "classifier/svm/GMNPSVM.h"
00013 #include "classifier/svm/gmnplib.h"
00014
/* Flat offset of element (ROW,COL) when consecutive columns are DIM entries apart. */
#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))
/* Integer sentinels standing in for -infinity / +infinity. */
#define MINUS_INF INT_MIN
#define PLUS_INF INT_MAX
/* Kronecker delta: 1 if A equals B, 0 otherwise. Arguments are parenthesised
 * so that expressions such as `a&b` or `c?x:y` compare as a whole. */
#define KDELTA(A,B) ((A)==(B))
/* 1 if at least two of the four arguments are equal, 0 otherwise. */
#define KDELTA4(A1,A2,A3,A4) \
	(((A1)==(A2))||((A1)==(A3))||((A1)==(A4))||((A2)==(A3))||((A2)==(A4))||((A3)==(A4)))
00020
00021 using namespace shogun;
00022
00023 CGMNPSVM::CGMNPSVM()
00024 : CMultiClassSVM(ONE_VS_REST)
00025 {
00026 }
00027
00028 CGMNPSVM::CGMNPSVM(float64_t C, CKernel* k, CLabels* lab)
00029 : CMultiClassSVM(ONE_VS_REST, C, k, lab)
00030 {
00031 }
00032
00033 CGMNPSVM::~CGMNPSVM()
00034 {
00035 }
00036
00037 bool CGMNPSVM::train(CFeatures* data)
00038 {
00039 ASSERT(kernel);
00040 ASSERT(labels && labels->get_num_labels());
00041
00042 if (data)
00043 {
00044 if (data->get_num_vectors() != labels->get_num_labels())
00045 {
00046 SG_ERROR("Numbert of vectors (%d) does not match number of labels (%d)\n",
00047 data->get_num_vectors(), labels->get_num_labels());
00048 }
00049 kernel->init(data, data);
00050 }
00051
00052 int32_t num_data = labels->get_num_labels();
00053 int32_t num_classes = labels->get_num_classes();
00054 int32_t num_virtual_data= num_data*(num_classes-1);
00055
00056 SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
00057
00058 float64_t* vector_y = new float64_t[num_data];
00059 for (int32_t i=0; i<num_data; i++)
00060 {
00061 vector_y[i]= labels->get_label(i)+1;
00062
00063 }
00064
00065 float64_t C = get_C1();
00066 int32_t tmax = 1000000000;
00067 float64_t tolabs = 0;
00068 float64_t tolrel = epsilon;
00069
00070 float64_t reg_const=0;
00071 if( C!=0 )
00072 reg_const = 1/(2*C);
00073
00074
00075 float64_t* alpha = new float64_t[num_virtual_data];
00076 float64_t* vector_c = new float64_t[num_virtual_data];
00077 memset(vector_c, 0, num_virtual_data*sizeof(float64_t));
00078
00079 float64_t thlb = 10000000000.0;
00080 int32_t t = 0;
00081 float64_t* History = NULL;
00082 int32_t verb = 0;
00083
00084 CGMNPLib mnp(vector_y,kernel,num_data, num_virtual_data, num_classes, reg_const);
00085
00086 mnp.gmnp_imdm(vector_c, num_virtual_data, tmax,
00087 tolabs, tolrel, thlb, alpha, &t, &History, verb );
00088
00089
00090 float64_t* all_alphas= new float64_t[num_classes*num_data];
00091 memset(all_alphas,0,num_classes*num_data*sizeof(float64_t));
00092
00093
00094 float64_t* all_bs=new float64_t[num_classes];
00095 memset(all_bs,0,num_classes*sizeof(float64_t));
00096
00097
00098 for(int32_t i=0; i < num_classes; i++ )
00099 {
00100 for(int32_t j=0; j < num_virtual_data; j++ )
00101 {
00102 int32_t inx1=0;
00103 int32_t inx2=0;
00104
00105 mnp.get_indices2( &inx1, &inx2, j );
00106
00107 all_alphas[(inx1*num_classes)+i] +=
00108 alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
00109 all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
00110 }
00111 }
00112
00113 create_multiclass_svm(num_classes);
00114
00115 for (int32_t i=0; i<num_classes; i++)
00116 {
00117 int32_t num_sv=0;
00118 for (int32_t j=0; j<num_data; j++)
00119 {
00120 if (all_alphas[j*num_classes+i] != 0)
00121 num_sv++;
00122 }
00123 ASSERT(num_sv>0);
00124 SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]);
00125
00126 CSVM* svm=new CSVM(num_sv);
00127
00128 int32_t k=0;
00129 for (int32_t j=0; j<num_data; j++)
00130 {
00131 if (all_alphas[j*num_classes+i] != 0)
00132 {
00133 svm->set_alpha(k, all_alphas[j*num_classes+i]);
00134 svm->set_support_vector(k, j);
00135 k++;
00136 }
00137 }
00138
00139 svm->set_bias(all_bs[i]);
00140 set_svm(i, svm);
00141 }
00142
00143 m_basealphas.resize(num_classes, ::std::vector<float64_t>(num_data,0));
00144 for(int j=0; j < num_virtual_data; j++ )
00145 {
00146 int inx1=0;
00147 int inx2=0;
00148
00149 mnp.get_indices2( &inx1, &inx2, j );
00150 m_basealphas[inx2-1][inx1]=alpha[j];
00151 }
00152
00153 delete[] vector_c;
00154 delete[] alpha;
00155 delete[] all_alphas;
00156 delete[] all_bs;
00157 delete[] vector_y;
00158 delete[] History;
00159
00160 return true;
00161 }
00162
00163 void CGMNPSVM::getbasealphas(::std::vector< ::std::vector<float64_t> > & basealphas)
00164 {
00165 basealphas=m_basealphas;
00166 }