[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_nodeproxy.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_RANDOM_FOREST_NP_HXX
00037 #define VIGRA_RANDOM_FOREST_NP_HXX
00038 
00039 #include <algorithm>
00040 #include <map>
00041 #include <numeric>
00042 #include "vigra/mathutil.hxx"
00043 #include "vigra/array_vector.hxx"
00044 #include "vigra/sized_int.hxx"
00045 #include "vigra/matrix.hxx"
00046 #include "vigra/random.hxx"
00047 #include "vigra/functorexpression.hxx"
00048 
00049 
00050 namespace vigra
00051 {
00052 
00053 
00054 
/** Node type ids and related marker constants.
 *
 *  i_* values are interior (split) node types; e_* values are exterior
 *  (leaf) node types and always carry the LeafNodeTag bit.
 */
enum NodeTags
{
    // --- markers -------------------------------------------------------
    UnFilledNode        = 42,           // typeID given to a freshly appended node
    AllColumns          = 0x0,          // stored column count meaning "all features"
    ToBePrunedTag       = 0x80000000,   // not referenced in this header
    LeafNodeTag         = 0x40000000,   // flag bit shared by all e_* types

    // --- interior (split) nodes ---------------------------------------
    i_ThresholdNode     = 0x0,
    i_HyperplaneNode    = 0x1,
    i_HypersphereNode   = 0x2,

    // --- exterior (leaf) nodes ----------------------------------------
    e_ConstProbNode     = LeafNodeTag | 0x0,
    e_LogRegProbNode    = LeafNodeTag | 0x1
};
00068 
00069 /** NodeBase class.
00070 
00071     \ingroup DecicionTree
00072 
00073     This class implements common features of all nodes.
00074     Memory Structure:
00075         Int32   Array:  TypeID, ParameterAddr, Child0, Child1, [ColumnData]0_
00076         double  Array:  NodeWeight, [Parameters]1_
00077 
00078         TODO: Throw away the crappy iterators and use vigra::ArrayVectorView
00079              it is not like anybody else is going to use this NodeBase class
00080              is it?
00081 
00082         TODO: use the RF_Traits::ProblemSpec_t to specify the external 
00083              parameters instead of the options.
00084 */
00085 
00086 
00087 class NodeBase
00088 {
00089   public:
00090     typedef Int32                               INT;
00091     typedef ArrayVector<INT>                    T_Container_type;
00092     typedef ArrayVector<double>                 P_Container_type;
00093     typedef T_Container_type::iterator          Topology_type;
00094     typedef P_Container_type::iterator          Parameter_type;
00095 
00096 
00097     mutable Topology_type                       topology_;
00098     int                                         topology_size_;
00099 
00100     mutable Parameter_type                      parameters_;
00101     int                                         parameter_size_ ;
00102 
00103     /** if numColumns = 0 then xrange is used as split axis
00104     **/
00105     static T_Container_type                     xrange;
00106 
00107         // Tree Parameters
00108     int                                         featureCount_;
00109     int                                         classCount_;
00110 
00111         // Node Parameters
00112     bool                                        hasData_;
00113 
00114 
00115 
00116 
00117     /** get Node Weight
00118      */
00119     double &      weights()
00120     {
00121             return parameters_begin()[0];
00122     }
00123 
00124     double const &      weights() const
00125     {
00126             return parameters_begin()[0];
00127     }
00128 
00129     /** has the data been set?
00130      * todo: throw this out - bad design
00131      */
00132     bool          data() const
00133     {
00134         return hasData_;
00135     }
00136 
00137     /** get the node type id
00138      * \sa NodeTags
00139      */
00140     INT&          typeID()
00141     {
00142         return topology_[0];
00143     }
00144 
00145     INT const &          typeID() const
00146     {
00147         return topology_[0];
00148     }
00149 
00150     /** Where in the parameter_ array are the weights?
00151      */
00152     INT &          parameter_addr()
00153     {
00154         return topology_[1];
00155     }
00156 
00157     INT const &    parameter_addr() const
00158     {
00159         return topology_[1];
00160     }
00161 
00162     /** Column Range **/
00163     Topology_type  column_data() const
00164     {
00165         return topology_ + 4 ;
00166     }
00167 
00168     /** get the start iterator to the columns
00169      *  - once again - throw out - static members are crap.
00170      */
00171     Topology_type columns_begin() const
00172     {
00173             return column_data()+1;
00174     }
00175 
00176     /** how many columns?
00177      */
00178     int      columns_size() const
00179     {
00180         if(*column_data() == AllColumns)
00181             return featureCount_;
00182         else
00183             return *column_data();;
00184     }
00185 
00186     /** end iterator to the columns
00187      */
00188     Topology_type  columns_end() const
00189     {
00190         return columns_begin() + columns_size();
00191     }
00192 
00193     /** Topology Range - gives access to the raw Topo memory
00194      * the size_ member was added as a result of premature 
00195      * optimisation.
00196      */ 
00197     Topology_type   topology_begin() const
00198     {
00199         return topology_;
00200     }
00201     Topology_type   topology_end() const
00202     {
00203         return topology_begin() + topology_size();
00204     }
00205     int          topology_size() const
00206     {
00207         return topology_size_;
00208     }
00209 
00210     /** Parameter Range **/
00211     Parameter_type  parameters_begin() const
00212     {
00213         return parameters_;
00214     }
00215     Parameter_type  parameters_end() const
00216     {
00217         return parameters_begin() + parameters_size();
00218     }
00219 
00220     int          parameters_size() const
00221     {
00222         return parameter_size_;
00223     }
00224 
00225 
00226     /** where are the child nodes?
00227      */
00228     INT &           child(Int32 l)
00229     {
00230         return topology_begin()[2+l];
00231     }
00232 
00233     /** where are the child nodes?
00234      */
00235     INT const  &           child(Int32 l) const
00236     {
00237         return topology_begin()[2+l];
00238     }
00239 
00240     /** Default Constructor**/
00241     NodeBase()
00242     :
00243                     hasData_(false)
00244     {}
00245     void copy(const NodeBase& o)
00246     {
00247         vigra_precondition(topology_size_==o.topology_size_,"Cannot copy nodes of different sizes");
00248         vigra_precondition(featureCount_==o.featureCount_,"Cannot copy nodes with different feature count");
00249         vigra_precondition(classCount_==o.classCount_,"Cannot copy nodes with different class counts");
00250         vigra_precondition(parameters_size() ==o.parameters_size(),"Cannot copy nodes with different paremater sizes");
00251         std::copy(o.topology_begin(), o.topology_end(), topology_);
00252         std::copy(o.parameters_begin(),o.parameters_end(), parameters_);
00253     }
00254 
00255     /** create ReadOnly Base Node at position n (actual length is unknown)
00256      * only common features i.e. children etc are accessible.
00257      */
00258     NodeBase(   T_Container_type const   &  topology,
00259                 P_Container_type const   &  parameter,
00260                 INT                         n)
00261     :
00262                     topology_   (const_cast<Topology_type>(topology.begin()+ n)),
00263                     topology_size_(4),
00264                     parameters_  (const_cast<Parameter_type>(parameter.begin() + parameter_addr())),
00265                     parameter_size_(1),
00266                     featureCount_(topology[0]),
00267                     classCount_(topology[1]),
00268                     hasData_(true)
00269     {
00270         /*while((int)xrange.size() <  featureCount_)
00271             xrange.push_back(xrange.size());*/
00272     }
00273 
00274     /** create ReadOnly node with known length (the parameter range is valid)
00275      */
00276     NodeBase(   int                      tLen,
00277                 int                      pLen,
00278                 T_Container_type const & topology,
00279                 P_Container_type const & parameter,
00280                 INT                         n)
00281     :
00282                     topology_   (const_cast<Topology_type>(topology.begin()+ n)),
00283                     topology_size_(tLen),
00284                     parameters_  (const_cast<Parameter_type>(parameter.begin() + parameter_addr())),
00285                     parameter_size_(pLen),
00286                     featureCount_(topology[0]),
00287                     classCount_(topology[1]),
00288                     hasData_(true)
00289     {
00290         /*while((int)xrange.size() <  featureCount_)
00291             xrange.push_back(xrange.size());*/
00292     }
00293     /** create ReadOnly node with known length 
00294      * from existing Node
00295      */
00296     NodeBase(   int                      tLen,
00297                 int                      pLen,
00298                 NodeBase &               node)
00299     :
00300                     topology_   (node.topology_),
00301                     topology_size_(tLen),
00302                     parameters_  (node.parameters_),
00303                     parameter_size_(pLen),
00304                     featureCount_(node.featureCount_),
00305                     classCount_(node.classCount_),
00306                     hasData_(true)
00307     {
00308         /*while((int)xrange.size() <  featureCount_)
00309             xrange.push_back(xrange.size());*/
00310     }
00311 
00312 
00313    /** create new Node at end of vector
00314     * \param tLen number of integers needed in the topolog vector
00315     * \param pLen number of parameters needed (this includes the node
00316     *           weight)
00317     * \param topology reference to Topology array of decision tree.
00318     * \param parameter reference to Parameter array of decision tree.
00319     **/
00320     NodeBase(   int                      tLen,
00321                 int                      pLen,
00322                 T_Container_type   &        topology,
00323                 P_Container_type   &        parameter)
00324     :
00325                     topology_size_(tLen),
00326                     parameter_size_(pLen),
00327                     featureCount_(topology[0]),
00328                     classCount_(topology[1]),
00329                     hasData_(true)
00330     {
00331         /*while((int)xrange.size() <  featureCount_)
00332             xrange.push_back(xrange.size());*/
00333 
00334         int n = topology.size();
00335         for(int ii = 0; ii < tLen; ++ii)
00336             topology.push_back(0);
00337         //topology.resize (n  + tLen);
00338 
00339         topology_           =   topology.begin()+ n;
00340         typeID()            =   UnFilledNode;
00341 
00342         parameter_addr()    =   parameter.size();
00343 
00344         //parameter.resize(parameter.size() + pLen);
00345         for(int ii = 0; ii < pLen; ++ii)
00346             parameter.push_back(0);
00347 
00348         parameters_          =   parameter.begin()+ parameter_addr();
00349         weights() = 1;
00350     }
00351 
00352 
00353   /** PseudoCopy Constructor  - 
00354    *
00355    * Copy Node to the end of a container. 
00356    * Since each Node views on different data there can't be a real 
00357    * copy constructor (unless both objects should point to the 
00358    * same underlying data.                                  
00359    */
00360     NodeBase(   NodeBase      const  &    toCopy,
00361                 T_Container_type      &    topology,
00362                 P_Container_type     &    parameter)
00363     :
00364                     topology_size_(toCopy.topology_size()),
00365                     parameter_size_(toCopy.parameters_size()),
00366                     featureCount_(topology[0]),
00367                     classCount_(topology[1]),
00368                     hasData_(true)
00369     {
00370         /*while((int)xrange.size() <  featureCount_)
00371             xrange.push_back(xrange.size());*/
00372 
00373         int n            = topology.size();
00374         for(int ii = 0; ii < toCopy.topology_size(); ++ii)
00375             topology.push_back(toCopy.topology_begin()[ii]);
00376 //        topology.insert(topology.end(), toCopy.topology_begin(), toCopy.topology_end());
00377         topology_           =   topology.begin()+ n;
00378         parameter_addr()    =   parameter.size();
00379         for(int ii = 0; ii < toCopy.parameters_size(); ++ii)
00380             parameter.push_back(toCopy.parameters_begin()[ii]);
00381 //        parameter.insert(parameter.end(), toCopy.parameters_begin(), toCopy.parameters_end());
00382         parameters_          =   parameter.begin()+ parameter_addr();
00383     }
00384 };
00385 
// Definition of the static member NodeBase::xrange (an empty array).
// NOTE(review): this non-inline static data member is defined in a header.
// Including this header from more than one translation unit yields multiple
// definitions of NodeBase::xrange at link time (ODR violation). Consider
// moving the definition into a .cxx file, or dropping xrange entirely (it is
// only referenced from commented-out code in this header).
00386  NodeBase::T_Container_type NodeBase::xrange;
00387 
00388 
00389 
00390 template<NodeTags NodeType>
00391 class Node;
00392 
00393 template<>
00394 class Node<i_ThresholdNode>
00395 : public NodeBase
00396 {
00397 
00398 
00399     public:
00400     typedef NodeBase BT;
00401 
00402         /**constructors **/
00403 
00404     Node(   BT::T_Container_type &   topology,
00405             BT::P_Container_type &   param)
00406                 :   BT(5,2,topology, param)
00407     {
00408         BT::typeID() = i_ThresholdNode;
00409     }
00410 
00411     Node(   BT::T_Container_type const     &   topology,
00412             BT::P_Container_type const     &   param,
00413                     INT                   n             )
00414                 :   BT(5,2,topology, param, n)
00415     {}
00416 
00417     Node( BT & node_)
00418         :   BT(5, 2, node_) 
00419     {}
00420 
00421     double& threshold()
00422     {
00423         return BT::parameters_begin()[1];
00424     }
00425 
00426     double const & threshold() const
00427     {
00428         return BT::parameters_begin()[1];
00429     }
00430 
00431     BT::INT& column()
00432     {
00433         return BT::column_data()[0];
00434     }
00435     BT::INT const & column() const
00436     {
00437         return BT::column_data()[0];
00438     }
00439 
00440     template<class U, class C>
00441     BT::INT  next(MultiArrayView<2,U,C> const & feature) const
00442     {
00443         return (feature(0, column()) < threshold())? child(0):child(1);
00444     }
00445 };
00446 
00447 
00448 template<>
00449 class Node<i_HyperplaneNode>
00450 : public NodeBase
00451 {
00452     public:
00453 
00454     typedef NodeBase BT;
00455 
00456         /**constructors **/
00457 
00458     Node(           int                      nCol,
00459                     BT::T_Container_type    &   topology,
00460                     BT::P_Container_type    &   split_param)
00461                 :   BT(nCol + 5,nCol + 2,topology, split_param)
00462     {
00463         BT::typeID() = i_HyperplaneNode;
00464     }
00465 
00466     Node(           BT::T_Container_type  const  &   topology,
00467                     BT::P_Container_type  const  &   split_param,
00468                     int                  n             )
00469                 :   NodeBase(5 , 2,topology, split_param, n)
00470     {
00471         //TODO : is there a more elegant way to do this?
00472         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00473                                         0
00474                                     :   BT::column_data()[0];
00475         BT::parameter_size_ += BT::columns_size();
00476     }
00477 
00478     Node( BT & node_)
00479         :   BT(5, 2, node_) 
00480     {
00481         //TODO : is there a more elegant way to do this?
00482         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00483                                         0
00484                                     :   BT::column_data()[0];
00485         BT::parameter_size_ += BT::columns_size();
00486     }
00487 
00488 
00489     double const & intercept() const
00490     {
00491         return BT::parameters_begin()[1];
00492     }
00493     double& intercept()
00494     {
00495         return BT::parameters_begin()[1];
00496     }
00497 
00498     BT::Parameter_type weights() const
00499     {
00500         return BT::parameters_begin()+2;
00501     }
00502 
00503     BT::Parameter_type weights()
00504     {
00505         return BT::parameters_begin()+2;
00506     }
00507 
00508 
00509     template<class U, class C>
00510     BT::INT next(MultiArrayView<2,U,C> const & feature) const
00511     {
00512         double result = -1 * intercept();
00513         if(*(BT::column_data()) == AllColumns)
00514         {
00515             for(int ii = 0; ii < BT::columns_size(); ++ii)
00516             {
00517                 result +=feature[ii] * weights()[ii];
00518             }
00519         }
00520         else
00521         {
00522             for(int ii = 0; ii < BT::columns_size(); ++ii)
00523             {
00524                 result +=feature[BT::columns_begin()[ii]] * weights()[ii];
00525             }
00526         }
00527         return result < 0 ? BT::child(0)
00528                           : BT::child(1);
00529     }
00530 };
00531 
00532 
00533 
00534 template<>
00535 class Node<i_HypersphereNode>
00536 : public NodeBase
00537 {
00538     public:
00539 
00540     typedef NodeBase BT;
00541 
00542         /**constructors **/
00543 
00544     Node(           int                      nCol,
00545                     BT::T_Container_type    &   topology,
00546                     BT::P_Container_type    &   param)
00547                 :   NodeBase(nCol + 5,nCol + 1,topology, param)
00548     {
00549         BT::typeID() = i_HypersphereNode;
00550     }
00551 
00552     Node(           BT::T_Container_type  const  &   topology,
00553                     BT::P_Container_type  const  &  param,
00554                     int                  n             )
00555                 :   NodeBase(5, 1,topology, param, n)
00556     {
00557         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00558                                         0
00559                                     :   BT::column_data()[0];
00560         BT::parameter_size_ += BT::columns_size();
00561     }
00562 
00563     Node( BT & node_)
00564         :   BT(5, 1, node_) 
00565     {
00566         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
00567                                         0
00568                                     :   BT::column_data()[0];
00569         BT::parameter_size_ += BT::columns_size();
00570 
00571     }
00572 
00573     double const & squaredRadius() const
00574     {
00575         return BT::parameters_begin()[1];
00576     }
00577 
00578     double& squaredRadius()
00579     {
00580         return BT::parameters_begin()[1];
00581     }
00582 
00583     BT::Parameter_type center() const
00584     {
00585         return BT::parameters_begin()+2;
00586     }
00587 
00588     BT::Parameter_type center()
00589     {
00590         return BT::parameters_begin()+2;
00591     }
00592 
00593     template<class U, class C>
00594     BT::INT next(MultiArrayView<2,U,C> const & feature) const
00595     {
00596         double result = -1 * squaredRadius();
00597         if(*(BT::column_data()) == AllColumns)
00598         {
00599             for(int ii = 0; ii < BT::columns_size(); ++ii)
00600             {
00601                 result += (feature[ii] - center()[ii])*
00602                           (feature[ii] - center()[ii]);
00603             }
00604         }
00605         else
00606         {
00607             for(int ii = 0; ii < BT::columns_size(); ++ii)
00608             {
00609                 result += (feature[BT::columns_begin()[ii]] - center()[ii])*
00610                           (feature[BT::columns_begin()[ii]] - center()[ii]);
00611             }
00612         }
00613         return result < 0 ? BT::child(0)
00614                           : BT::child(1);
00615     }
00616 };
00617 
00618 
00619 /** ExteriorNodeBase class.
00620 
00621     \ingroup DecicionTree
00622 
00623     This class implements common features of all interior nodes.
00624     All interior nodes are derived classes of ExteriorNodeBase.
00625 */
00626 
00627 
00628 
00629 
00630 
00631 
00632 template<>
00633 class Node<e_ConstProbNode>
00634 : public NodeBase
00635 {
00636     public:
00637 
00638     typedef     NodeBase    BT;
00639 
00640     Node(           BT::T_Container_type    &   topology,
00641                     BT::P_Container_type    &   param)
00642                     :
00643                 BT(2,topology[1]+1, topology, param)
00644 
00645     {
00646         BT::typeID() = e_ConstProbNode;
00647     }
00648 
00649 
00650     Node(           BT::T_Container_type const &   topology,
00651                     BT::P_Container_type const &   param,
00652                     int                  n             )
00653                 :   BT(2, topology[1]+1,topology, param, n)
00654     { }
00655 
00656 
00657     Node( BT & node_)
00658         :   BT(2, node_.classCount_ +1, node_) 
00659     {}
00660     BT::Parameter_type  prob_begin() const
00661     {
00662         return BT::parameters_begin()+1;
00663     }
00664     BT::Parameter_type  prob_end() const
00665     {
00666         return prob_begin() + prob_size();
00667     }
00668     int prob_size() const
00669     {
00670         return BT::classCount_;
00671     }
00672 };
00673 
00674 template<>
00675 class Node<e_LogRegProbNode>;
00676 
00677 } // namespace vigra
00678 
00679 #endif //RF_nodeproxy

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)