[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_decisionTree.hxx
1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 #ifndef VIGRA_RANDOM_FOREST_DT_HXX
37 #define VIGRA_RANDOM_FOREST_DT_HXX
38 
39 #include <algorithm>
40 #include <map>
41 #include <numeric>
42 #include "vigra/multi_array.hxx"
43 #include "vigra/mathutil.hxx"
44 #include "vigra/array_vector.hxx"
45 #include "vigra/sized_int.hxx"
46 #include "vigra/matrix.hxx"
47 #include "vigra/random.hxx"
48 #include "vigra/functorexpression.hxx"
49 #include <vector>
50 
51 #include "rf_common.hxx"
52 #include "rf_visitors.hxx"
53 #include "rf_nodeproxy.hxx"
54 namespace vigra
55 {
56 
57 namespace detail
58 {
59  // todo FINALLY DECIDE TO USE CAMEL CASE OR UNDERSCORES !!!!!!
60 /* decisiontree classifier.
61  *
62  * This class is actually meant to be used in conjunction with the
63  * Random Forest Classifier
64  * - My suggestion would be to use the RandomForest classifier with
65  * following parameters instead of directly using this
66  * class (Preprocessing default values etc is handled in there):
67  *
68  * \code
69  * RandomForest decisionTree(RF_Traits::Options_t()
70  * .features_per_node(RF_ALL)
71  * .tree_count(1) );
72  * \endcode
73  *
74  * \todo remove the classCount and featurecount from the topology
75  * array. Pass ext_param_ to the nodes!
76  * \todo Use relative addressing of nodes?
77  */
78 class DecisionTree
79 {
80  /* \todo make private?*/
81  public:
82 
83  /* value type of container array. use whenever referencing it
84  */
85  typedef Int32 TreeInt;
86 
87  ArrayVector<TreeInt> topology_;
88  ArrayVector<double> parameters_;
89 
90  ProblemSpec<> ext_param_;
91  unsigned int classCount_;
92 
93 
94  public:
95  /* \brief Create tree with parameters */
96  template<class T>
97  DecisionTree(ProblemSpec<T> ext_param)
98  :
99  ext_param_(ext_param),
100  classCount_(ext_param.class_count_)
101  {}
102 
103  /* clears all memory used.
104  */
105  void reset(unsigned int classCount = 0)
106  {
107  if(classCount)
108  classCount_ = classCount;
109  topology_.clear();
110  parameters_.clear();
111  }
112 
113 
114  /* learn a Tree
115  *
116  * \tparam StackEntry_t The Stackentry containing Node/StackEntry_t
117  * Information used during learning. Each Split functor has a
118  * Stack entry associated with it (Split_t::StackEntry_t)
119  * \sa RandomForest::learn()
120  */
121  template < class U, class C,
122  class U2, class C2,
123  class StackEntry_t,
124  class Stop_t,
125  class Split_t,
126  class Visitor_t,
127  class Random_t >
128  void learn( MultiArrayView<2, U, C> const & features,
129  MultiArrayView<2, U2, C2> const & labels,
130  StackEntry_t const & stack_entry,
131  Split_t split,
132  Stop_t stop,
133  Visitor_t & visitor,
134  Random_t & randint);
135  template < class U, class C,
136  class U2, class C2,
137  class StackEntry_t,
138  class Stop_t,
139  class Split_t,
140  class Visitor_t,
141  class Random_t>
142  void continueLearn( MultiArrayView<2, U, C> const & features,
143  MultiArrayView<2, U2, C2> const & labels,
144  StackEntry_t const & stack_entry,
145  Split_t split,
146  Stop_t stop,
147  Visitor_t & visitor,
148  Random_t & randint,
149  //an index to which the last created exterior node will be moved (because it is not used anymore)
150  int garbaged_child=-1);
151 
152  /* is a node a Leaf Node? */
153  inline bool isLeafNode(TreeInt in) const
154  {
155  return (in & LeafNodeTag) == LeafNodeTag;
156  }
157 
158  /* data driven traversal from root to leaf
159  *
160  * traverse through tree with data given in features. Use Visitors to
161  * collect statistics along the way.
162  */
163  template<class U, class C, class Visitor_t>
164  TreeInt getToLeaf(MultiArrayView<2, U, C> const & features,
165  Visitor_t & visitor) const
166  {
167  TreeInt index = 2;
168  while(!isLeafNode(topology_[index]))
169  {
170  visitor.visit_internal_node(*this, index, topology_[index],features);
171  switch(topology_[index])
172  {
173  case i_ThresholdNode:
174  {
175  Node<i_ThresholdNode>
176  node(topology_, parameters_, index);
177  index = node.next(features);
178  break;
179  }
180  case i_HyperplaneNode:
181  {
182  Node<i_HyperplaneNode>
183  node(topology_, parameters_, index);
184  index = node.next(features);
185  break;
186  }
187  case i_HypersphereNode:
188  {
189  Node<i_HypersphereNode>
190  node(topology_, parameters_, index);
191  index = node.next(features);
192  break;
193  }
194 #if 0
195  // for quick prototyping! has to be implemented.
196  case i_VirtualNode:
197  {
198  Node<i_VirtualNode>
199  node(topology_, parameters, index);
200  index = node.next(features);
201  }
202 #endif
203  default:
204  vigra_fail("DecisionTree::getToLeaf():"
205  "encountered unknown internal Node Type");
206  }
207  }
208  visitor.visit_external_node(*this, index, topology_[index],features);
209  return index;
210  }
211  /* traverse tree to get statistics
212  *
213  * Tree is traversed in order the Nodes are in memory (i.e. if no
214  * relearning//pruning scheme is utilized this will be pre order)
215  */
216  template<class Visitor_t>
217  void traverse_mem_order(Visitor_t visitor) const
218  {
219  TreeInt index = 2;
220  Int32 ii = 0;
221  while(index < topology_.size())
222  {
223  if(isLeafNode(topology_[index]))
224  {
225  visitor
226  .visit_external_node(*this, index, topology_[index]);
227  }
228  else
229  {
230  visitor
231  ._internal_node(*this, index, topology_[index]);
232  }
233  }
234  }
235 
236  template<class Visitor_t>
237  void traverse_post_order(Visitor_t visitor, TreeInt start = 2) const
238  {
239  typedef TinyVector<double, 2> Entry;
240  std::vector<Entry > stack;
241  std::vector<double> result_stack;
242  stack.push_back(Entry(2, 0));
243  int addr;
244  while(!stack.empty())
245  {
246  addr = stack.back()[0];
247  NodeBase node(topology_, parameters_, stack.back()[0]);
248  if(stack.back()[1] == 1)
249  {
250  stack.pop_back();
251  double leftRes = result_stack.back();
252  double rightRes = result_stack.back();
253  result_stack.pop_back();
254  result_stack.pop_back();
255  result_stack.push_back(rightRes+ leftRes);
256  visitor.visit_internal_node(*this,
257  addr,
258  node.typeID(),
259  rightRes+leftRes);
260  }
261  else
262  {
263  if(isLeafNode(node.typeID()))
264  {
265  visitor.visit_external_node(*this,
266  addr,
267  node.typeID(),
268  node.weights());
269  stack.pop_back();
270  result_stack.push_back(node.weights());
271  }
272  else
273  {
274  stack.back()[1] = 1;
275  stack.push_back(Entry(node.child(0), 0));
276  stack.push_back(Entry(node.child(1), 0));
277  }
278 
279  }
280  }
281  }
282 
283  /* same thing as above, without any visitors */
284  template<class U, class C>
285  TreeInt getToLeaf(MultiArrayView<2, U, C> const & features) const
286  {
288  return getToLeaf(features, stop);
289  }
290 
291 
292  template <class U, class C>
293  ArrayVector<double>::iterator
294  predict(MultiArrayView<2, U, C> const & features) const
295  {
296  TreeInt nodeindex = getToLeaf(features);
297  switch(topology_[nodeindex])
298  {
299  case e_ConstProbNode:
300  return Node<e_ConstProbNode>(topology_,
301  parameters_,
302  nodeindex).prob_begin();
303  break;
304 #if 0
305  //first make the Logistic regression stuff...
306  case e_LogRegProbNode:
307  return Node<e_LogRegProbNode>(topology_,
308  parameters_,
309  nodeindex).prob_begin();
310 #endif
311  default:
312  vigra_fail("DecisionTree::predict() :"
313  " encountered unknown external Node Type");
314  }
315  return ArrayVector<double>::iterator();
316  }
317 
318 
319 
320  template <class U, class C>
321  Int32 predictLabel(MultiArrayView<2, U, C> const & features) const
322  {
323  ArrayVector<double>::const_iterator weights = predict(features);
324  return argMax(weights, weights+classCount_) - weights;
325  }
326 
327 };
328 
329 
330 template < class U, class C,
331  class U2, class C2,
332  class StackEntry_t,
333  class Stop_t,
334  class Split_t,
335  class Visitor_t,
336  class Random_t>
337 void DecisionTree::learn( MultiArrayView<2, U, C> const & features,
338  MultiArrayView<2, U2, C2> const & labels,
339  StackEntry_t const & stack_entry,
340  Split_t split,
341  Stop_t stop,
342  Visitor_t & visitor,
343  Random_t & randint)
344 {
345  this->reset();
346  topology_.reserve(256);
347  parameters_.reserve(256);
348  topology_.push_back(features.shape(1));
349  topology_.push_back(classCount_);
350  continueLearn(features,labels,stack_entry,split,stop,visitor,randint);
351 }
352 
/* grow (or continue growing) the tree from a given stack entry.
 *
 * Iterative (stack-based) depth-first construction: each round pops a
 * region of the training data, asks the stop functor whether to make a
 * terminal node, otherwise lets the split functor find the best split,
 * back-patches the parent's child pointer, and pushes the two children.
 *
 * garbaged_child: if != -1, the last exterior node created is moved over
 * the node at this address and the arrays are shrunk accordingly (used
 * when relearning a subtree; the node at garbaged_child is obsolete).
 */
template <  class U, class C,
            class U2, class C2,
            class StackEntry_t,
            class Stop_t,
            class Split_t,
            class Visitor_t,
            class Random_t>
void DecisionTree::continueLearn(   MultiArrayView<2, U, C> const       & features,
                                    MultiArrayView<2, U2, C2> const     & labels,
                                    StackEntry_t const                  & stack_entry,
                                    Split_t                               split,
                                    Stop_t                                stop,
                                    Visitor_t                           & visitor,
                                    Random_t                            & randint,
                                    //an index to which the last created exterior node will be moved (because it is not used anymore)
                                    int garbaged_child)
{
    // work stack of regions still to be processed (DFS order)
    std::vector<StackEntry_t> stack;
    stack.reserve(128);
    // scratch entries the split functor fills with the two child regions
    ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry);
    stack.push_back(stack_entry);
    // address of the most recently appended node in topology_
    size_t last_node_pos = 0;
    // declared outside the loop: the garbaged_child fix-up below needs
    // the entry of the LAST node processed.
    StackEntry_t top=stack.back();

    while(!stack.empty())
    {

        // Take an element of the stack. Obvious ain't it?
        top = stack.back();
        stack.pop_back();

        // Make sure no data from the last round has remained in Pipeline;
        child_stack_entry[0].reset();
        child_stack_entry[1].reset();
        split.reset();


        //Either the Stopping criterion decides that the split should
        //produce a Terminal Node or the Split itself decides what
        //kind of node to make
        TreeInt NodeID;

        if(stop(top))
            NodeID = split.makeTerminalNode(features,
                                            labels,
                                            top,
                                            randint);
        else
        {
            NodeID = split.findBestSplit(features,
                                         labels,
                                         top,
                                         child_stack_entry,
                                         randint);
        }

        // do some visiting yawn - just added this comment as eye candy
        // (looks odd otherwise with my syntax highlighting....
        visitor.visit_after_split(*this, split, top,
                                  child_stack_entry[0],
                                  child_stack_entry[1],
                                  features,
                                  labels);


        // Update the Child entries of the parent
        // Using InteriorNodeBase because exact parameter form not needed.
        // look at the Node base before getting scared.
        last_node_pos = topology_.size();
        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
        {
            // this region was its parent's left child
            NodeBase(topology_,
                     parameters_,
                     top.leftParent).child(0) = last_node_pos;
        }
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
        {
            // ... or its right child
            NodeBase(topology_,
                     parameters_,
                     top.rightParent).child(1) = last_node_pos;
        }


        // Supply the split functor with the Node type it requires.
        // set the address to which the children of this node should point
        // to and push back children onto stack
        if(!isLeafNode(NodeID))
        {
            // each child remembers which slot of this (future) parent
            // node it must patch; the unused slot is marked NoParent.
            child_stack_entry[0].leftParent = topology_.size();
            child_stack_entry[1].rightParent = topology_.size();
            child_stack_entry[0].rightParent = -1;
            child_stack_entry[1].leftParent = -1;
            stack.push_back(child_stack_entry[0]);
            stack.push_back(child_stack_entry[1]);
        }

        //copy the newly created node form the split functor to the
        //decision tree. (NodeBase copy-ctor appends to topology_ /
        //parameters_ as a side effect.)
        NodeBase node(split.createNode(), topology_, parameters_ );
    }
    if(garbaged_child!=-1)
    {
        // relearning case: overwrite the obsolete node at garbaged_child
        // with the last node created, then shrink both arrays ...
        Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos));

        int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size();
        topology_.resize(last_node_pos);
        parameters_.resize(parameters_.size() - last_parameter_size);

        // ... and re-point the last processed node's parent at the
        // relocated copy.
        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_,
                     parameters_,
                     top.leftParent).child(0) = garbaged_child;
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_,
                     parameters_,
                     top.rightParent).child(1) = garbaged_child;
    }
}
473 
474 } //namespace detail
475 
476 } //namespace vigra
477 
478 #endif //VIGRA_RANDOM_FOREST_DT_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.9.0 (Tue Oct 22 2013)