[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. 
*/ 00033 /* */ 00034 /************************************************************************/ 00035 00036 #ifndef VIGRA_RANDOM_FOREST_DT_HXX 00037 #define VIGRA_RANDOM_FOREST_DT_HXX 00038 00039 #include <algorithm> 00040 #include <map> 00041 #include <numeric> 00042 #include "vigra/multi_array.hxx" 00043 #include "vigra/mathutil.hxx" 00044 #include "vigra/array_vector.hxx" 00045 #include "vigra/sized_int.hxx" 00046 #include "vigra/matrix.hxx" 00047 #include "vigra/random.hxx" 00048 #include "vigra/functorexpression.hxx" 00049 #include <vector> 00050 00051 #include "rf_common.hxx" 00052 #include "rf_nodeproxy.hxx" 00053 namespace vigra 00054 { 00055 00056 namespace detail 00057 { 00058 // todo FINALLY DECIDE TO USE CAMEL CASE OR UNDERSCORES !!!!!! 00059 /** decisiontree classifier. 00060 * 00061 * This class is actually meant to be used in conjunction with the 00062 * Random Forest Classifier 00063 * - My suggestion would be to use the RandomForest classifier with 00064 * following parameters instead of directly using this 00065 * class (Preprocessing default values etc is handled in there): 00066 * 00067 * \code 00068 * RandomForest decisionTree(RF_Traits::Options_t() 00069 * .features_per_node(RF_ALL) 00070 * .tree_count(1) ); 00071 * \endcode 00072 * 00073 * \todo remove the classCount and featurecount from the topology 00074 * array. Pass ext_param_ to the nodes! 00075 * \todo Use relative addressing of nodes? 00076 */ 00077 class DecisionTree 00078 { 00079 /**\todo make private?*/ 00080 public: 00081 00082 /** value type of container array. 
use whenever referencing it 00083 */ 00084 typedef Int32 TreeInt; 00085 00086 ArrayVector<TreeInt> topology_; 00087 ArrayVector<double> parameters_; 00088 00089 ProblemSpec<> ext_param_; 00090 unsigned int classCount_; 00091 00092 00093 public: 00094 /** \brief Create tree with parameters */ 00095 template<class T> 00096 DecisionTree(ProblemSpec<T> ext_param) 00097 : 00098 ext_param_(ext_param), 00099 classCount_(ext_param.class_count_) 00100 {} 00101 00102 /**clears all memory used. 00103 */ 00104 void reset(unsigned int classCount = 0) 00105 { 00106 if(classCount) 00107 classCount_ = classCount; 00108 topology_.clear(); 00109 parameters_.clear(); 00110 } 00111 00112 00113 /** learn a Tree 00114 * 00115 * \tparam StackEntry_t The Stackentry containing Node/StackEntry_t 00116 * Information used during learing. Each Split functor has a 00117 * Stack entry associated with it (Split_t::StackEntry_t) 00118 * \sa RandomForest::learn() 00119 */ 00120 template < class U, class C, 00121 class U2, class C2, 00122 class StackEntry_t, 00123 class Stop_t, 00124 class Split_t, 00125 class Visitor_t, 00126 class Random_t > 00127 void learn( MultiArrayView<2, U, C> const & features, 00128 MultiArrayView<2, U2, C2> const & labels, 00129 StackEntry_t const & stack_entry, 00130 Split_t split, 00131 Stop_t stop, 00132 Visitor_t & visitor, 00133 Random_t & randint); 00134 template < class U, class C, 00135 class U2, class C2, 00136 class StackEntry_t, 00137 class Stop_t, 00138 class Split_t, 00139 class Visitor_t, 00140 class Random_t> 00141 void continueLearn( MultiArrayView<2, U, C> const & features, 00142 MultiArrayView<2, U2, C2> const & labels, 00143 StackEntry_t const & stack_entry, 00144 Split_t split, 00145 Stop_t stop, 00146 Visitor_t & visitor, 00147 Random_t & randint, 00148 //an index to which the last created exterior node will be moved (because it is not used anymore) 00149 int garbaged_child=-1); 00150 00151 /** is a node a Leaf Node? 
*/ 00152 inline bool isLeafNode(TreeInt in) const 00153 { 00154 return (in & LeafNodeTag) == LeafNodeTag; 00155 } 00156 00157 /** data driven traversal from root to leaf 00158 * 00159 * traverse through tree with data given in features. Use Visitors to 00160 * collect statistics along the way. 00161 */ 00162 template<class U, class C, class Visitor_t> 00163 TreeInt getToLeaf(MultiArrayView<2, U, C> const & features, 00164 Visitor_t & visitor) const 00165 { 00166 TreeInt index = 2; 00167 while(!isLeafNode(topology_[index])) 00168 { 00169 visitor.visit_internal_node(*this, index, topology_[index],features); 00170 switch(topology_[index]) 00171 { 00172 case i_ThresholdNode: 00173 { 00174 Node<i_ThresholdNode> 00175 node(topology_, parameters_, index); 00176 index = node.next(features); 00177 break; 00178 } 00179 case i_HyperplaneNode: 00180 { 00181 Node<i_HyperplaneNode> 00182 node(topology_, parameters_, index); 00183 index = node.next(features); 00184 break; 00185 } 00186 case i_HypersphereNode: 00187 { 00188 Node<i_HypersphereNode> 00189 node(topology_, parameters_, index); 00190 index = node.next(features); 00191 break; 00192 } 00193 #if 0 00194 // for quick prototyping! has to be implemented. 00195 case i_VirtualNode: 00196 { 00197 Node<i_VirtualNode> 00198 node(topology_, parameters, index); 00199 index = node.next(features); 00200 } 00201 #endif 00202 default: 00203 vigra_fail("DecisionTree::getToLeaf():" 00204 "encountered unknown internal Node Type"); 00205 } 00206 } 00207 visitor.visit_external_node(*this, index, topology_[index],features); 00208 return index; 00209 } 00210 /** traverse tree to get statistics 00211 * 00212 * Tree is traversed in order the Nodes are in memory (i.e. 
if no 00213 * relearning//pruning scheme is utilized this will be pre order) 00214 */ 00215 template<class Visitor_t> 00216 void traverse_mem_order(Visitor_t visitor) const 00217 { 00218 TreeInt index = 2; 00219 Int32 ii = 0; 00220 while(index < topology_.size()) 00221 { 00222 if(isLeafNode(topology_[index])) 00223 { 00224 visitor 00225 .visit_external_node(*this, index, topology_[index]); 00226 } 00227 else 00228 { 00229 visitor 00230 ._internal_node(*this, index, topology_[index]); 00231 } 00232 } 00233 } 00234 00235 template<class Visitor_t> 00236 void traverse_post_order(Visitor_t visitor, TreeInt start = 2) const 00237 { 00238 typedef TinyVector<double, 2> Entry; 00239 std::vector<Entry > stack; 00240 std::vector<double> result_stack; 00241 stack.push_back(Entry(2, 0)); 00242 int addr; 00243 while(!stack.empty()) 00244 { 00245 addr = stack.back()[0]; 00246 NodeBase node(topology_, parameters_, stack.back()[0]); 00247 if(stack.back()[1] == 1) 00248 { 00249 stack.pop_back(); 00250 double leftRes = result_stack.back(); 00251 double rightRes = result_stack.back(); 00252 result_stack.pop_back(); 00253 result_stack.pop_back(); 00254 result_stack.push_back(rightRes+ leftRes); 00255 visitor.visit_internal_node(*this, 00256 addr, 00257 node.typeID(), 00258 rightRes+leftRes); 00259 } 00260 else 00261 { 00262 if(isLeafNode(node.typeID())) 00263 { 00264 visitor.visit_external_node(*this, 00265 addr, 00266 node.typeID(), 00267 node.weights()); 00268 stack.pop_back(); 00269 result_stack.push_back(node.weights()); 00270 } 00271 else 00272 { 00273 stack.back()[1] = 1; 00274 stack.push_back(Entry(node.child(0), 0)); 00275 stack.push_back(Entry(node.child(1), 0)); 00276 } 00277 00278 } 00279 } 00280 } 00281 00282 /** same thing as above, without any visitors */ 00283 template<class U, class C> 00284 TreeInt getToLeaf(MultiArrayView<2, U, C> const & features) const 00285 { 00286 RF_Traits::StopVisiting_t stop; 00287 return getToLeaf(features, stop); 00288 } 00289 00290 00291 
template <class U, class C> 00292 ArrayVector<double>::iterator 00293 predict(MultiArrayView<2, U, C> const & features) const 00294 { 00295 TreeInt nodeindex = getToLeaf(features); 00296 switch(topology_[nodeindex]) 00297 { 00298 case e_ConstProbNode: 00299 return Node<e_ConstProbNode>(topology_, 00300 parameters_, 00301 nodeindex).prob_begin(); 00302 break; 00303 #if 0 00304 //first make the Logistic regression stuff... 00305 case e_LogRegProbNode: 00306 return Node<e_LogRegProbNode>(topology_, 00307 parameters_, 00308 nodeindex).prob_begin(); 00309 #endif 00310 default: 00311 vigra_fail("DecisionTree::predict() :" 00312 " encountered unknown external Node Type"); 00313 } 00314 return ArrayVector<double>::iterator(); 00315 } 00316 00317 00318 00319 template <class U, class C> 00320 Int32 predictLabel(MultiArrayView<2, U, C> const & features) const 00321 { 00322 ArrayVector<double>::const_iterator weights = predict(features); 00323 return argMax(weights, weights+classCount_) - weights; 00324 } 00325 00326 }; 00327 00328 00329 template < class U, class C, 00330 class U2, class C2, 00331 class StackEntry_t, 00332 class Stop_t, 00333 class Split_t, 00334 class Visitor_t, 00335 class Random_t> 00336 void DecisionTree::learn( MultiArrayView<2, U, C> const & features, 00337 MultiArrayView<2, U2, C2> const & labels, 00338 StackEntry_t const & stack_entry, 00339 Split_t split, 00340 Stop_t stop, 00341 Visitor_t & visitor, 00342 Random_t & randint) 00343 { 00344 this->reset(); 00345 topology_.reserve(256); 00346 parameters_.reserve(256); 00347 topology_.push_back(features.shape(1)); 00348 topology_.push_back(classCount_); 00349 continueLearn(features,labels,stack_entry,split,stop,visitor,randint); 00350 } 00351 00352 template < class U, class C, 00353 class U2, class C2, 00354 class StackEntry_t, 00355 class Stop_t, 00356 class Split_t, 00357 class Visitor_t, 00358 class Random_t> 00359 void DecisionTree::continueLearn( MultiArrayView<2, U, C> const & features, 00360 
MultiArrayView<2, U2, C2> const & labels, 00361 StackEntry_t const & stack_entry, 00362 Split_t split, 00363 Stop_t stop, 00364 Visitor_t & visitor, 00365 Random_t & randint, 00366 //an index to which the last created exterior node will be moved (because it is not used anymore) 00367 int garbaged_child) 00368 { 00369 std::vector<StackEntry_t> stack; 00370 stack.reserve(128); 00371 ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry); 00372 stack.push_back(stack_entry); 00373 size_t last_node_pos = 0; 00374 StackEntry_t top=stack.back(); 00375 00376 while(!stack.empty()) 00377 { 00378 00379 // Take an element of the stack. Obvious ain't it? 00380 top = stack.back(); 00381 stack.pop_back(); 00382 00383 // Make sure no data from the last round has remained in Pipeline; 00384 child_stack_entry[0].reset(); 00385 child_stack_entry[1].reset(); 00386 split.reset(); 00387 00388 00389 //Either the Stopping criterion decides that the split should 00390 //produce a Terminal Node or the Split itself decides what 00391 //kind of node to make 00392 TreeInt NodeID; 00393 if(stop(top)) 00394 NodeID = split.makeTerminalNode(features, 00395 labels, 00396 top, 00397 randint); 00398 else 00399 NodeID = split.findBestSplit(features, 00400 labels, 00401 top, 00402 child_stack_entry, 00403 randint); 00404 00405 // do some visiting yawn - just added this comment as eye candy 00406 // (looks odd otherwise with my syntax highlighting.... 00407 visitor.visit_after_split(*this, split, top, 00408 child_stack_entry[0], 00409 child_stack_entry[1], 00410 features, 00411 labels); 00412 00413 00414 // Update the Child entries of the parent 00415 // Using InteriorNodeBase because exact parameter form not needed. 00416 // look at the Node base before getting scared. 
00417 last_node_pos = topology_.size(); 00418 if(top.leftParent != StackEntry_t::DecisionTreeNoParent) 00419 { 00420 NodeBase(topology_, 00421 parameters_, 00422 top.leftParent).child(0) = last_node_pos; 00423 } 00424 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent) 00425 { 00426 NodeBase(topology_, 00427 parameters_, 00428 top.rightParent).child(1) = last_node_pos; 00429 } 00430 00431 00432 // Supply the split functor with the Node type it requires. 00433 // set the address to which the children of this node should point 00434 // to and push back children onto stack 00435 if(!isLeafNode(NodeID)) 00436 { 00437 child_stack_entry[0].leftParent = topology_.size(); 00438 child_stack_entry[1].rightParent = topology_.size(); 00439 child_stack_entry[0].rightParent = -1; 00440 child_stack_entry[1].leftParent = -1; 00441 stack.push_back(child_stack_entry[0]); 00442 stack.push_back(child_stack_entry[1]); 00443 } 00444 00445 //copy the newly created node form the split functor to the 00446 //decision tree. 00447 NodeBase(split.createNode(), topology_, parameters_ ); 00448 } 00449 if(garbaged_child!=-1) 00450 { 00451 Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos)); 00452 00453 int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size(); 00454 topology_.resize(last_node_pos); 00455 parameters_.resize(parameters_.size() - last_parameter_size); 00456 00457 if(top.leftParent != StackEntry_t::DecisionTreeNoParent) 00458 NodeBase(topology_, 00459 parameters_, 00460 top.leftParent).child(0) = garbaged_child; 00461 else if(top.rightParent != StackEntry_t::DecisionTreeNoParent) 00462 NodeBase(topology_, 00463 parameters_, 00464 top.rightParent).child(1) = garbaged_child; 00465 } 00466 } 00467 00468 } //namespace detail 00469 00470 } //namespace vigra 00471 00472 #endif //VIGRA_RANDOM_FOREST_DT_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
HTML generated using doxygen and Python