[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_preprocessing.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 #ifndef VIGRA_RF_PREPROCESSING_HXX
37 #define VIGRA_RF_PREPROCESSING_HXX
38 
39 #include <limits>
40 #include "rf_common.hxx"
41 
42 namespace vigra
43 {
44 
45 /** Preprocessor used by the Random Forest learn function.
46  *
47  * This class is used internally by the Random Forest learn function
48  * (currently only during learning). Different split functors may need to
49  * process the data in different ways (e.g., regression labels that should
50  * not be touched versus classification labels that must be converted into
51  * an integral format).
52  *
53  * This class only exists in specialized versions, where the Tag class is
54  * fixed.
55  *
56  * The Tag class is determined by Splitfunctor::Preprocessor_t. Currently
57  * it can be either ClassificationTag or RegressionTag. Look at the
58  * RegressionTag specialisation for the basic interface if you need to
59  * write a new preprocessor (soft labels or whatever).
60  */
61 template<class Tag, class LabelType, class T1, class C1, class T2, class C2>
62 class Processor;
63 
64 namespace detail
65 {
66 
67  /* Common helper function used in all Processors.
68  * This function analyses the options struct and calculates the real
69  * values needed for the current problem (data)
70  */
71  template<class T>
72  void fill_external_parameters(RandomForestOptions const & options,
73  ProblemSpec<T> & ext_param)
74  {
75  // set correct value for mtry
76  switch(options.mtry_switch_)
77  {
78  case RF_SQRT:
79  ext_param.actual_mtry_ =
80  int(std::floor(
81  std::sqrt(double(ext_param.column_count_))
82  + 0.5));
83  break;
84  case RF_LOG:
85  // this is in Breimans original paper
86  ext_param.actual_mtry_ =
87  int(1+(std::log(double(ext_param.column_count_))
88  /std::log(2.0)));
89  break;
90  case RF_FUNCTION:
91  ext_param.actual_mtry_ =
92  options.mtry_func_(ext_param.column_count_);
93  break;
94  case RF_ALL:
95  ext_param.actual_mtry_ = ext_param.column_count_;
96  break;
97  default:
98  ext_param.actual_mtry_ =
99  options.mtry_;
100  }
101  // set correct value for msample
102  switch(options.training_set_calc_switch_)
103  {
104  case RF_CONST:
105  ext_param.actual_msample_ =
106  options.training_set_size_;
107  break;
108  case RF_PROPORTIONAL:
109  ext_param.actual_msample_ =
110  (int)std::ceil( options.training_set_proportion_ *
111  ext_param.row_count_);
112  break;
113  case RF_FUNCTION:
114  ext_param.actual_msample_ =
115  options.training_set_func_(ext_param.row_count_);
116  break;
117  default:
118  vigra_precondition(1!= 1, "unexpected error");
119 
120  }
121 
122  }
123 
124  /* Returns true if MultiArray contains NaNs
125  */
126  template<unsigned int N, class T, class C>
127  bool contains_nan(MultiArrayView<N, T, C> const & in)
128  {
129  for(int ii = 0; ii < in.size(); ++ii)
130  if(in[ii] != in[ii])
131  return true;
132  return false;
133  }
134 
135  /* Returns true if MultiArray contains Infs
136  */
137  template<unsigned int N, class T, class C>
138  bool contains_inf(MultiArrayView<N, T, C> const & in)
139  {
140  if(!std::numeric_limits<T>::has_infinity)
141  return false;
142  for(int ii = 0; ii < in.size(); ++ii)
143  if(in[ii] == std::numeric_limits<T>::infinity())
144  return true;
145  return false;
146  }
147 } // namespace detail
148 
149 
150 
151 /** Preprocessor used during Classification
152  *
153  * This class converts the labels int Integral labels which are used by the
154  * standard split functor to address memory in the node objects.
155  */
156 template<class LabelType, class T1, class C1, class T2, class C2>
157 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
158 {
159  public:
160  typedef Int32 LabelInt;
164  MultiArrayView<2, T1, C1>const & features_;
165  MultiArray<2, LabelInt> intLabels_;
167 
168  template<class T>
169  Processor(MultiArrayView<2, T1, C1>const & features,
170  MultiArrayView<2, T2, C2>const & response,
171  RandomForestOptions &options,
172  ProblemSpec<T> &ext_param)
173  :
174  features_( features) // do not touch the features.
175  {
176  vigra_precondition(!detail::contains_nan(features), "Processor(): Feature Matrix "
177  "Contains NaNs");
178  vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
179  "Contains NaNs");
180  vigra_precondition(!detail::contains_inf(features), "Processor(): Feature Matrix "
181  "Contains inf");
182  vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
183  "Contains inf");
184  // set some of the problem specific parameters
185  ext_param.column_count_ = features.shape(1);
186  ext_param.row_count_ = features.shape(0);
187  ext_param.problem_type_ = CLASSIFICATION;
188  ext_param.used_ = true;
189  intLabels_.reshape(response.shape());
190 
191  //get the class labels
192  if(ext_param.class_count_ == 0)
193  {
194  // fill up a map with the current labels and then create the
195  // integral labels.
196  std::set<T2> labelToInt;
197  for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
198  labelToInt.insert(response(k,0));
199  std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
200  ext_param.classes_(tmp_.begin(), tmp_.end());
201  }
202  for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
203  {
204  if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
205  {
206  throw std::runtime_error("unknown label type");
207  }
208  else
209  intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
210  - ext_param.classes.begin();
211  }
212  // set class weights
213  if(ext_param.class_weights_.size() == 0)
214  {
216  tmp((std::size_t)ext_param.class_count_,
217  NumericTraits<T2>::one());
218  ext_param.class_weights(tmp.begin(), tmp.end());
219  }
220 
221  // set mtry and msample
222  detail::fill_external_parameters(options, ext_param);
223 
224  // set strata
225  strata_ = intLabels_;
226 
227  }
228 
229  /** Access the processed features
230  */
231  MultiArrayView<2, T1, C1>const & features()
232  {
233  return features_;
234  }
235 
236  /** Access processed labels
237  */
239  {
240  return intLabels_;
241  }
242 
243  /** Access processed strata
244  */
246  {
247  return ArrayVectorView<LabelInt>(intLabels_.size(), intLabels_.data());
248  }
249 
250  /** Access strata fraction sized - not used currently
251  */
253  {
254  return ArrayVectorView< double>();
255  }
256 };
257 
258 
259 
260 /** Regression Preprocessor - This basically does not do anything with the
261  * data.
262  */
263 template<class LabelType, class T1, class C1, class T2, class C2>
264 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
265 {
266 public:
267  // only views are created - no data copied.
268  MultiArrayView<2, T1, C1> features_;
269  MultiArrayView<2, T2, C2> response_;
270  RandomForestOptions const & options_;
271  ProblemSpec<LabelType> const &
272  ext_param_;
273  // will only be filled if needed
274  MultiArray<2, int> strata_;
275  bool strata_filled;
276 
277  // copy the views.
278  template<class T>
280  MultiArrayView<2, T2, C2> response,
281  RandomForestOptions const & options,
282  ProblemSpec<T>& ext_param)
283  :
284  features_(features),
285  response_(response),
286  options_(options),
287  ext_param_(ext_param)
288  {
289  // set some of the problem specific parameters
290  ext_param.column_count_ = features.shape(1);
291  ext_param.row_count_ = features.shape(0);
292  ext_param.problem_type_ = REGRESSION;
293  ext_param.used_ = true;
294  detail::fill_external_parameters(options, ext_param);
295  vigra_precondition(!detail::contains_nan(features), "Processor(): Feature Matrix "
296  "Contains NaNs");
297  vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
298  "Contains NaNs");
299  vigra_precondition(!detail::contains_inf(features), "Processor(): Feature Matrix "
300  "Contains inf");
301  vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
302  "Contains inf");
303  strata_ = MultiArray<2, int> (MultiArrayShape<2>::type(response_.shape(0), 1));
304  ext_param.response_size_ = response.shape(1);
305  ext_param.class_count_ = response_.shape(1);
306  std::vector<T2> tmp_(ext_param.class_count_, 0);
307  ext_param.classes_(tmp_.begin(), tmp_.end());
308  }
309 
310  /** access preprocessed features
311  */
313  {
314  return features_;
315  }
316 
317  /** access preprocessed response
318  */
320  {
321  return response_;
322  }
323 
324  /** access strata - this is not used currently
325  */
327  {
328  return strata_;
329  }
330 };
331 }
332 #endif //VIGRA_RF_PREPROCESSING_HXX
333 
334 
335 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Wed Feb 27 2013)