SHOGUN  v1.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
DomainAdaptationSVMLinear.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2007-2011 Christian Widmer
8  * Copyright (C) 2007-2011 Max-Planck-Society
9  */
10 
11 #include <shogun/lib/config.h>
12 
13 #ifdef HAVE_LAPACK
14 
16 #include <shogun/io/SGIO.h>
17 #include <shogun/base/Parameter.h>
18 #include <iostream>
19 #include <vector>
20 
21 using namespace shogun;
22 
23 
25 {
    // Sets the object up with no pre-trained SVM (NULL) and B = 0.0 by
    // delegating to init().
    // NOTE(review): the signature line for this body (listing line 24) is
    // absent from this extract; presumably the default constructor - confirm.
26  init(NULL, 0.0);
27 }
28 
29 
31 {
    // Forwards pre_svm and B_param to init(), which stores them and
    // registers the members for serialization.
    // NOTE(review): the signature line (listing line 30), and any base-class
    // initializer it carried, is absent from this extract; pre_svm and
    // B_param are presumably its parameters - confirm.
32  init(pre_svm, B_param);
33 
34 }
35 
36 
38 {
39 
    // NOTE(review): listing line 40 is missing from this extract, so any
    // cleanup done before the debug message is not visible here. init()
    // takes a reference on the pre-SVM via SG_REF; verify that the matching
    // release happens in the original source.
41  SG_DEBUG("deleting DomainAdaptationSVMLinear\n");
42 }
43 
44 
/**
 * Shared initialisation used by the constructors.
 *
 * Stores the pre-trained machine and the regularisation weight, resets
 * train_factor to 1.0 and registers the three members with m_parameters.
 *
 * @param pre_svm  machine to regularise against; may be NULL. When non-NULL
 *                 its reference count is increased and its bias is
 *                 overwritten with 0.0 (the caller's object is modified).
 * @param B_param  value stored in member B.
 */
45 void CDomainAdaptationSVMLinear::init(CLinearMachine* pre_svm, float64_t B_param)
46 {
47 
48  if (pre_svm)
49  {
50  // increase reference counts
51  SG_REF(pre_svm);
52 
53  // set bias of parent svm to zero
54  pre_svm->set_bias(0.0);
55  }
56 
    // Members are assigned unconditionally, so presvm is NULL when no
    // machine was supplied.
57  this->presvm = pre_svm;
58  this->B = B_param;
59  this->train_factor = 1.0;
60 
    // NOTE(review): listing lines 61 and 64 are missing from this extract;
    // the statement the "invoke sanity check" comment refers to is therefore
    // not visible here.
62 
63  // invoke sanity check
65 
66  // serialization code
    // NOTE(review): the description for "B" contains the typo "strenth".
    // It is runtime text, so it is left untouched here.
67  m_parameters->add((CSGObject**) &presvm, "presvm", "SVM to regularize against");
68  m_parameters->add(&B, "B", "Regularization strenth B.");
69  m_parameters->add(&train_factor, "train_factor", "train_factor");
70 
71 }
72 
73 
75 {
    // Consistency check on the stored pre-SVM. Warns when presvm is NULL,
    // otherwise reports an error when its bias is non-zero. In the visible
    // code the function always ends with "return true".
    // NOTE(review): the signature line (listing line 74) is absent from this
    // extract; the function name is inferred, not shown.
76 
77  if (!presvm) {
78 
79  SG_WARNING("presvm is null");
80 
81  } else {
82 
    // init() zeroes the bias of any pre-SVM it is given; this re-checks it.
83  if (presvm->get_bias() != 0) {
84  SG_ERROR("presvm bias not set to zero");
85  }
86 
    // NOTE(review): listing line 87 (the condition guarding the next error)
    // is missing from this extract, which is why the braces below look
    // unbalanced. Do not read the SG_ERROR as unconditional.
88  SG_ERROR("feature types do not agree");
89  }
90  }
91 
92  return true;
93 
94 }
95 
96 
98 {
    // Trains the linear SVM. When a pre-SVM is present, a per-example linear
    // term derived from the pre-SVM's outputs is installed first; the actual
    // training is then delegated to CLibLinear::train_machine.
    // Returns whatever CLibLinear::train_machine returns.
    // NOTE(review): the signature line (listing line 97) is absent from this
    // extract; train_data is presumably its only parameter - confirm.
99 
100  CDotFeatures* tmp_data;
101 
    // Pick the data the pre-SVM is evaluated on: the explicitly passed
    // training data if given, otherwise the member `features`.
102  if (train_data)
103  {
104  if (labels->get_num_labels() != train_data->get_num_vectors())
105  SG_ERROR("Number of training vectors does not match number of labels\n");
106  tmp_data = train_data;
107 
108  } else {
109 
110  tmp_data = features;
111  }
112 
    // NOTE(review): `labels` is accessed directly above but through
    // get_labels() here; if get_labels() hands out a counted reference it is
    // never released in the visible code - confirm.
113  int32_t num_training_points = get_labels()->get_num_labels();
114 
    // One slot per training label; only filled and used when presvm is set.
115  std::vector<float64_t> lin_term = std::vector<float64_t>(num_training_points);
116 
117  if (presvm)
118  {
119  ASSERT(presvm->get_bias() == 0.0);
120 
121  // bias of parent SVM was set to zero in constructor, already contains B
    // NOTE(review): parent_svm_out is not released anywhere in the visible
    // code; confirm who owns the object returned by apply().
122  CLabels* parent_svm_out = presvm->apply(tmp_data);
123 
124  SG_DEBUG("pre-computing linear term from presvm\n");
125 
126  // pre-compute linear term
    // lin_term[i] = train_factor * B * label_i * presvm_output_i - 1.0
127  for (int32_t i=0; i!=num_training_points; i++)
128  {
129  lin_term[i] = train_factor * B * get_label(i) * parent_svm_out->get_label(i) - 1.0;
130  }
131 
132  // set linear term for QP
    // The SGVector is built from a pointer into lin_term's storage plus its
    // length. NOTE(review): lin_term is a local that dies on return, so this
    // relies on set_linear_term copying the data - confirm.
133  this->set_linear_term(
134  SGVector<float64_t>(&lin_term[0], lin_term.size()));
135 
136  }
137 
    // The block below is disabled code (kept commented out in the original).
138  /*
139  // warm-start liblinear
140  //TODO test this code, measure speed-ups
141  //presvm w stored in presvm
142  float64_t* tmp_w;
143  presvm->get_w(tmp_w, w_dim);
144 
145  //copy vector
146  float64_t* tmp_w_copy = SG_MALLOC(float64_t, w_dim);
147  std::copy(tmp_w, tmp_w + w_dim, tmp_w_copy);
148 
149  for (int32_t i=0; i!=w_dim; i++)
150  {
151  tmp_w_copy[i] = B * tmp_w_copy[i];
152  }
153 
154  //set w (copied in setter)
155  set_w(tmp_w_copy, w_dim);
156  SG_FREE(tmp_w_copy);
157  */
158 
159  bool success = false;
160 
161  //train SVM
    // Mirror the caller's choice: pass train_data through only if it was given.
162  if (train_data)
163  {
164  success = CLibLinear::train_machine(train_data);
165  } else {
166  success = CLibLinear::train_machine();
167  }
168 
169  //ASSERT(presvm)
170 
171  return success;
172 
173 }
174 
175 
177 {
    // Returns the stored pre-SVM pointer as is; it is NULL when init() was
    // given no machine. No reference count is taken in the visible code.
    // NOTE(review): the signature line (listing line 176) is absent from
    // this extract.
178  return presvm;
179 }
180 
181 
183 {
    // Returns member B, the value stored by init() from its B_param argument.
    // NOTE(review): the signature line (listing line 182) is absent from
    // this extract.
184  return B;
185 }
186 
187 
189 {
    // Returns member train_factor (init() sets it to 1.0).
    // NOTE(review): the signature line (listing line 188) is absent from
    // this extract.
190  return train_factor;
191 }
192 
193 
195 {
    // Overwrites member train_factor, which scales the pre-SVM contribution
    // in the linear term built during training.
    // NOTE(review): the signature line (listing line 194) is absent from
    // this extract; `factor` is presumably its parameter - confirm.
196  train_factor = factor;
197 }
198 
199 
201 {
    // Classifies `data` with this machine and, when a pre-SVM is present,
    // adds B times the pre-SVM's output to every prediction. Returns the
    // label object produced by CLibLinear::apply, updated in place.
    // NOTE(review): the signature line (listing line 200) is absent from
    // this extract; `data` is presumably its only parameter - confirm.
202 
    // Fix: presvm may legitimately be NULL (init(NULL, 0.0) is used by the
    // parameterless setup, and the code below guards with `if (presvm)`), so
    // the bias check must not dereference it unconditionally.
203  ASSERT(!presvm || presvm->get_bias()==0.0);
204 
205  int32_t num_examples = data->get_num_vectors();
206 
207  CLabels* out_current = CLibLinear::apply(data);
208 
209  if (presvm)
210  {
211 
212  // recursive call if used on DomainAdaptationSVM object
    // NOTE(review): out_presvm is not released in the visible code; confirm
    // who owns the object returned by apply().
213  CLabels* out_presvm = presvm->apply(data);
214 
215 
216  // combine outputs
217  for (int32_t i=0; i!=num_examples; i++)
218  {
219  float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i);
220  out_current->set_label(i, out_combined);
221  }
222 
223  }
224 
225 
226  return out_current;
227 
228 }
229 
230 #endif //HAVE_LAPACK
231 

SHOGUN Machine Learning Toolbox - Documentation