SHOGUN
v1.1.0
Main Page
Related Pages
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Pages
src
shogun
classifier
svm
LibSVM.cpp
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 1999-2009 Soeren Sonnenburg
8
* Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9
*/
10
11
#include <
shogun/classifier/svm/LibSVM.h
>
12
#include <
shogun/io/SGIO.h
>
13
14
using namespace
shogun;
15
16
/** Create a LibSVM classifier that uses the given libsvm solver type.
 *
 * @param st solver type; stored in solver_type and handed to libsvm as
 *           param.svm_type when train_machine() runs
 */
CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st)
: CSVM(), model(NULL), solver_type(st)
{
	// Value-initialise the libsvm problem (zero count, NULL buffers), exactly
	// as the (C, kernel, labels) constructor does, so both constructors leave
	// the object in the same well-defined state.
	problem = svm_problem();
}
20
21
/** Create a LibSVM classifier using the LIBSVM_C_SVC solver.
 *
 * @param C   C parameter, forwarded to the CSVM constructor
 * @param k   kernel, forwarded to the CSVM constructor
 * @param lab training labels, forwarded to the CSVM constructor
 */
CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab)
	: CSVM(C, k, lab),
	  model(NULL),
	  solver_type(LIBSVM_C_SVC)
{
	// start from a value-initialised (all-zero) libsvm problem
	svm_problem empty_problem = svm_problem();
	problem = empty_problem;
}
26
27
/** Destructor.
 *
 * Intentionally empty: after a successful train_machine() the libsvm model
 * has already been destroyed and `model` reset to NULL, so this object owns
 * no libsvm model at destruction time.
 */
CLibSVM::~CLibSVM()
{
	/* nothing to release */
}
30
31
32
/** Train the two-class SVM with libsvm.
 *
 * Every training example i is passed to libsvm as a single svm_node whose
 * index is i (terminated by a node with index -1). libsvm is given the shogun
 * kernel via param.kernel, so examples are addressed by index rather than by
 * explicit feature values. After training, the objective, bias, support
 * vectors and alphas are copied into this CSVM and the libsvm model is
 * destroyed.
 *
 * All temporary buffers (problem.x/y/pv/C and the node array) are released
 * on every exit path that returns or reports a libsvm parameter error.
 *
 * @param data training features. When non-NULL, the kernel is initialised
 *             with (data, data) and the number of vectors must equal the
 *             number of labels. When NULL, the kernel's current features
 *             are used.
 * @return true if libsvm produced a model, false if svm_train returned NULL
 */
bool CLibSVM::train_machine(CFeatures* data)
{
	ASSERT(labels && labels->get_num_labels());
	ASSERT(labels->is_two_class_labeling());
	// Check the kernel before it is dereferenced below. Previously it was
	// only checked after kernel->init() had already been called.
	ASSERT(kernel);

	if (data)
	{
		if (labels->get_num_labels() != data->get_num_vectors())
			SG_ERROR("Number of training vectors does not match number of labels\n");
		kernel->init(data, data);
	}

	problem.l=labels->get_num_labels();
	SG_INFO("%d trainlabels\n", problem.l);

	// Validate the kernel before anything is allocated, so a failing ASSERT
	// cannot leak the problem buffers.
	ASSERT(kernel->has_features());
	ASSERT(kernel->get_num_vec_lhs()==problem.l);

	// set linear term
	if (m_linear_term.vlen>0)
	{
		if (problem.l!=m_linear_term.vlen)
			SG_ERROR("Number of training vectors does not match length of linear term\n");

		// Set with the linear term from the base class. It is released with
		// SG_FREE below, so this presumably returns a freshly allocated copy
		// -- TODO confirm.
		problem.pv=get_linear_term_array();
	}
	else
	{
		// fill with minus ones
		problem.pv=SG_MALLOC(float64_t, problem.l);
		for (int32_t i=0; i<problem.l; i++)
			problem.pv[i]=-1.0;
	}

	problem.y=SG_MALLOC(float64_t, problem.l);
	problem.x=SG_MALLOC(struct svm_node*, problem.l);
	// NOTE(review): problem.C is allocated but never filled in this function.
	// Confirm that the selected libsvm solver does not read it.
	problem.C=SG_MALLOC(float64_t, problem.l);

	// Two nodes per example: {index=i} followed by the {index=-1} terminator.
	struct svm_node* x_space=SG_MALLOC(struct svm_node, 2*problem.l);

	for (int32_t i=0; i<problem.l; i++)
	{
		problem.y[i]=labels->get_label(i);
		problem.x[i]=&x_space[2*i];
		x_space[2*i].index=i;
		x_space[2*i+1].index=-1;
	}

	// Per-label weights. Per the libsvm README, the penalty for a class is
	// param.C * weight, i.e. C1 for label -1 and C1*(C2/C1) for label +1.
	int32_t weights_label[2]={-1,+1};
	float64_t weights[2]={1.0, get_C2()/get_C1()};

	param.svm_type=solver_type; // C SVM or NU_SVM
	param.kernel_type = LINEAR;
	param.degree = 3;
	param.gamma = 0; // 1/k
	param.coef0 = 0;
	param.nu = get_nu();
	param.kernel=kernel;
	param.cache_size = kernel->get_cache_size();
	param.max_train_time = max_train_time;
	param.C = get_C1();
	param.eps = epsilon;
	param.p = 0.1;
	param.shrinking = 1;
	param.nr_weight = 2;
	param.weight_label = weights_label;
	param.weight = weights;
	param.use_bias = get_bias_enabled();

	bool success=false;
	const char* error_msg=svm_check_parameter(&problem, &param);

	if (!error_msg)
	{
		model=svm_train(&problem, &param);

		if (model)
		{
			ASSERT(model->nr_class==2);
			ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));

			int32_t num_sv=model->l;

			create_new_model(num_sv);
			CSVM::set_objective(model->objective);

			// libsvm orients its decision function towards model->label[0].
			// Multiplying by that label (+1 or -1) expresses bias and alphas
			// with respect to the +1 class.
			float64_t sgn=model->label[0];

			set_bias(-sgn*model->rho[0]);

			for (int32_t i=0; i<num_sv; i++)
			{
				// each SV node carries the index of the training example
				set_support_vector(i, (model->SV[i])->index);
				set_alpha(i, sgn*model->sv_coef[0][i]);
			}

			svm_destroy_model(model);
			model=NULL;
			success=true;
		}
	}

	// Release the training buffers on every path. Previously they leaked
	// when svm_check_parameter failed or svm_train returned NULL. Clear the
	// member pointers so no dangling pointers remain in `problem`.
	SG_FREE(problem.x);
	problem.x=NULL;
	SG_FREE(problem.y);
	problem.y=NULL;
	SG_FREE(problem.pv);
	problem.pv=NULL;
	SG_FREE(problem.C);
	problem.C=NULL;
	SG_FREE(x_space);

	// error_msg is presumably a string literal owned by libsvm (as in
	// upstream libsvm's svm_check_parameter), so it stays valid after the
	// frees above.
	if (error_msg)
		SG_ERROR("Error: %s\n", error_msg);

	return success;
}
SHOGUN
Machine Learning Toolbox - Documentation