SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2006-2009 Soeren Sonnenburg 00008 * Copyright (C) 2006-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include "classifier/svm/SVMLin.h" 00012 #include "features/Labels.h" 00013 #include "lib/Mathematics.h" 00014 #include "classifier/svm/ssl.h" 00015 #include "classifier/LinearClassifier.h" 00016 #include "features/DotFeatures.h" 00017 #include "features/Labels.h" 00018 00019 using namespace shogun; 00020 00021 CSVMLin::CSVMLin() 00022 : CLinearClassifier(), C1(1), C2(1), epsilon(1e-5), use_bias(true) 00023 { 00024 } 00025 00026 CSVMLin::CSVMLin( 00027 float64_t C, CDotFeatures* traindat, CLabels* trainlab) 00028 : CLinearClassifier(), C1(C), C2(C), epsilon(1e-5), use_bias(true) 00029 { 00030 set_features(traindat); 00031 set_labels(trainlab); 00032 } 00033 00034 00035 CSVMLin::~CSVMLin() 00036 { 00037 } 00038 00039 bool CSVMLin::train(CFeatures* data) 00040 { 00041 ASSERT(labels); 00042 00043 if (data) 00044 { 00045 if (!data->has_property(FP_DOT)) 00046 SG_ERROR("Specified features are not of type CDotFeatures\n"); 00047 set_features((CDotFeatures*) data); 00048 } 00049 00050 ASSERT(features); 00051 00052 int32_t num_train_labels=0; 00053 float64_t* train_labels=labels->get_labels(num_train_labels); 00054 int32_t num_feat=features->get_dim_feature_space(); 00055 int32_t num_vec=features->get_num_vectors(); 00056 00057 ASSERT(num_vec==num_train_labels); 00058 delete[] w; 00059 00060 struct options Options; 00061 struct data Data; 00062 struct vector_double Weights; 00063 struct vector_double Outputs; 00064 00065 Data.l=num_vec; 00066 Data.m=num_vec; 00067 Data.u=0; 00068 Data.n=num_feat+1; 00069 Data.nz=num_feat+1; 00070 Data.Y=train_labels; 00071 Data.features=features; 00072 Data.C = new float64_t[Data.l]; 00073 00074 Options.algo = SVM; 00075 Options.lambda=1/(2*get_C1()); 00076 Options.lambda_u=1/(2*get_C1()); 00077 Options.S=10000; 00078 Options.R=0.5; 00079 Options.epsilon = get_epsilon(); 00080 Options.cgitermax=10000; 00081 Options.mfnitermax=50; 00082 Options.Cp = get_C2()/get_C1(); 00083 Options.Cn = 1; 00084 00085 if (use_bias) 00086 Options.bias=1.0; 00087 else 00088 Options.bias=0.0; 00089 00090 for (int32_t i=0;i<num_vec;i++) 00091 { 00092 if(train_labels[i]>0) 00093 Data.C[i]=Options.Cp; 00094 else 00095 Data.C[i]=Options.Cn; 00096 } 00097 ssl_train(&Data, &Options, &Weights, &Outputs); 00098 ASSERT(Weights.vec && Weights.d==num_feat+1); 00099 00100 float64_t sgn=train_labels[0]; 00101 for (int32_t i=0; i<num_feat+1; i++) 00102 Weights.vec[i]*=sgn; 00103 00104 set_w(Weights.vec, num_feat); 00105 set_bias(Weights.vec[num_feat]); 00106 00107 delete[] Weights.vec; 00108 delete[] Data.C; 00109 delete[] train_labels; 00110 delete[] Outputs.vec; 00111 return true; 00112 }