[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_common.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 
00037 #ifndef VIGRA_RF_COMMON_HXX
00038 #define VIGRA_RF_COMMON_HXX
00039 
00040 namespace vigra
00041 {
00042 
// FORWARD DECLARATIONS
// TODO : DECIDE WHETHER THIS IS A GOOD IDEA

// tag type: selects the classification variant of the random forest
struct                  ClassificationTag{};

// tag type: selects the regression variant of the random forest
struct                  RegressionTag{};

// split criterion based on the Gini impurity (defined elsewhere)
class GiniCriterion;

// per-column split search using criterion T (name-based; defined elsewhere)
template<class T>
class BestGiniOfColumn;

// thresholding split functor; U selects classification vs. regression
template<class T, class U = ClassificationTag>
class ThresholdSplit;

// the default split functor (see RF_Traits::Default_Split_t below)
typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit;

namespace rf
{
    // default visitor; used as the do-nothing default (see RF_Traits)
    class                   StopVisiting;
}

// presumably the out-of-bag error visitor -- defined elsewhere
class                   OOB_Visitor;
class                   RandomForestOptions;

// problem specification; T is the label type (defined below)
template<class T= double>
class                   ProblemSpec;

template<class LabelT = double, class Tag = ClassificationTag>
class               RandomForest;


// standard early stopping predicate (defined below)
class                   EarlyStoppStd;

namespace detail
{
    class               RF_DEFAULT;     // singleton tag: "use the default" (defined below)
    class               DecisionTree;
}

// factory returning the RF_DEFAULT singleton (defined below)
detail::RF_DEFAULT&     rf_default();

// stack entry used while building a DecisionTree (defined elsewhere)
template <class T>
class               DT_StackEntry;
00086 
00087 /**\brief Traits Class for the Random Forest
00088  *
00089  * refer to the typedefs in this class when using the default
00090  * objects as the names may change in future.
00091  */
00092 class RF_Traits
00093 {
00094     public:
00095     typedef RandomForestOptions     Options_t;
00096     typedef detail::DecisionTree    DecisionTree_t;
00097     typedef ClassificationTag       Preprocessor_t;
00098     typedef GiniSplit               Default_Split_t;
00099     typedef EarlyStoppStd           Default_Stop_t;
00100     typedef rf::StopVisiting
00101                                     Default_Visitor_t;
00102     typedef rf::StopVisiting            StopVisiting_t;
00103 
00104 };
00105 
00106 
00107 /**\brief Standard early stopping criterion
00108  *
00109  * Stop if region.size() < min_split_node_size_;
00110  */
00111 class EarlyStoppStd
00112 {
00113     public:
00114     int min_split_node_size_;
00115 
00116     template<class Opt>
00117     EarlyStoppStd(Opt opt)
00118     :   min_split_node_size_(opt.min_split_node_size_)
00119     {}
00120 
00121     template<class T>
00122     void set_external_parameters(ProblemSpec<T>const  &, int /* tree_count */ = 0, bool /* is_weighted */ = false)
00123     {}
00124 
00125     template<class Region>
00126     bool operator()(Region& region)
00127     {
00128         return region.size() < min_split_node_size_;
00129     }
00130 
00131     template<class WeightIter, class T, class C>
00132     bool after_prediction(WeightIter,  int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
00133     {
00134         return false; 
00135     }
00136 };
00137 
00138 
00139 
00140 
00141 namespace detail
00142 {
00143 
/**\brief singleton default tag class -
 *
 *  use the rf_default() factory function to use the tag.
 *  \sa RandomForest<>::learn();
 */
class RF_DEFAULT
{
    private:
        // private constructor: the only instance is the function-local
        // static inside ::vigra::rf_default()
        RF_DEFAULT()
        {}
    public:
        // rf_default() is the only function allowed to construct the
        // singleton
        friend RF_DEFAULT& ::vigra::rf_default();

        /** ok workaround for automatic choice of the decisiontree
         * stackentry.
         */
        typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
                    StackEntry_t;
};
00163 
00164 /**\brief chooses between default type and type supplied
00165  * 
00166  * This is an internal class and you shouldn't really care about it.
00167  * Just pass on used in RandomForest.learn()
00168  * Usage:
00169  *\code
00170  *      // example: use container type supplied by user or ArrayVector if 
00171  *      //          rf_default() was specified as argument;
00172  *      template<class Container_t>
00173  *      void do_some_foo(Container_t in)
00174  *      {
00175  *          typedef ArrayVector<int>    Default_Container_t;
00176  *          Default_Container_t         default_value;
00177  *          Value_Chooser<Container_t,  Default_Container_t> 
00178  *                      choose(in, default_value);
00179  *
00180  *          // if the user didn't care and the in was of type 
00181  *          // RF_DEFAULT then default_value is used.
00182  *          do_some_more_foo(choose.value());
00183  *      }
00184  *      Value_Chooser choose_val<Type, Default_Type>
00185  *\endcode
00186  */
00187 template<class T, class C>
00188 class Value_Chooser
00189 {
00190 public:
00191     typedef T type;
00192     static T & choose(T & t, C &)
00193     {
00194         return t; 
00195     }
00196 };
00197 
00198 template<class C>
00199 class Value_Chooser<detail::RF_DEFAULT, C>
00200 {
00201 public:
00202     typedef C type;
00203     
00204     static C & choose(detail::RF_DEFAULT &, C & c)
00205     {
00206         return c; 
00207     }
00208 };
00209 
00210 
00211 
00212 
00213 } //namespace detail
00214 
00215 
00216 /**\brief factory function to return a RF_DEFAULT tag
00217  * \sa RandomForest<>::learn()
00218  */
00219 detail::RF_DEFAULT& rf_default()
00220 {
00221     static detail::RF_DEFAULT result;
00222     return result;
00223 }
00224 
/** tags used with the RandomForestOptions class
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag   { RF_EQUAL,         // stratification: equal number of samples per class
                      RF_PROPORTIONAL,  // stratification/sampling proportional to class frequency
                      RF_EXTERNAL,      // use externally supplied strata weights
                      RF_NONE,          // no stratification (default)
                      RF_FUNCTION,      // value computed by a user supplied function pointer
                      RF_LOG,           // built-in log mapping for mtry (see features_per_node)
                      RF_SQRT,          // built-in sqrt mapping for mtry (default)
                      RF_CONST,         // use a constant, user supplied value
                      RF_ALL};          // use all columns for mtry
00237 
00238 
00239 /**\brief Options object for the random forest
00240  *
00241  * usage:
00242  * RandomForestOptions a =  RandomForestOptions()
00243  *                              .param1(value1)
00244  *                              .param2(value2)
00245  *                              ...
00246  *
00247  * This class only contains options/parameters that are not problem
00248  * dependent. The ProblemSpec class contains methods to set class weights
00249  * if necessary.
00250  *
00251  * Note that the return value of all methods is *this which makes
00252  * concatenating of options as above possible.
00253  */
00254 class RandomForestOptions
00255 {
00256   public:
00257     /**\name sampling options*/
00258     /*\{*/
00259     // look at the member access functions for documentation
00260     double  training_set_proportion_;
00261     int     training_set_size_;
00262     int (*training_set_func_)(int);
00263     RF_OptionTag
00264         training_set_calc_switch_;
00265 
00266     bool    sample_with_replacement_;
00267     RF_OptionTag
00268             stratification_method_;
00269 
00270 
00271     /**\name general random forest options
00272      *
00273      * these usually will be used by most split functors and
00274      * stopping predicates
00275      */
00276     /*\{*/
00277     RF_OptionTag    mtry_switch_;
00278     int     mtry_;
00279     int (*mtry_func_)(int) ;
00280 
00281     bool predict_weighted_; 
00282     int tree_count_;
00283     int min_split_node_size_;
00284     bool prepare_online_learning_;
00285     /*\}*/
00286 
00287     int serialized_size() const
00288     {
00289         return 12;
00290     }
00291     
00292 
00293     bool operator==(RandomForestOptions & rhs) const
00294     {
00295         bool result = true;
00296         #define COMPARE(field) result = result && (this->field == rhs.field); 
00297         COMPARE(training_set_proportion_);
00298         COMPARE(training_set_size_);
00299         COMPARE(training_set_calc_switch_);
00300         COMPARE(sample_with_replacement_);
00301         COMPARE(stratification_method_);
00302         COMPARE(mtry_switch_);
00303         COMPARE(mtry_);
00304         COMPARE(tree_count_);
00305         COMPARE(min_split_node_size_);
00306         COMPARE(predict_weighted_);
00307         #undef COMPARE
00308 
00309         return result;
00310     }
00311     bool operator!=(RandomForestOptions & rhs_) const
00312     {
00313         return !(*this == rhs_);
00314     }
00315     template<class Iter>
00316     void unserialize(Iter const & begin, Iter const & end)
00317     {
00318         Iter iter = begin;
00319         vigra_precondition(static_cast<size_t>(end - begin) == serialized_size(), 
00320                            "RandomForestOptions::unserialize():"
00321                            "wrong number of parameters");
00322         #define PULL(item_, type_) item_ = type_(*iter); ++iter;
00323         PULL(training_set_proportion_, double);
00324         PULL(training_set_size_, int);
00325         ++iter; //PULL(training_set_func_, double);
00326         PULL(training_set_calc_switch_, (RF_OptionTag)int);
00327         PULL(sample_with_replacement_, 0 != );
00328         PULL(stratification_method_, (RF_OptionTag)int);
00329         PULL(mtry_switch_, (RF_OptionTag)int);
00330         PULL(mtry_, int);
00331         ++iter; //PULL(mtry_func_, double);
00332         PULL(tree_count_, int);
00333         PULL(min_split_node_size_, int);
00334         PULL(predict_weighted_, 0 !=);
00335         #undef PULL
00336     }
00337     template<class Iter>
00338     void serialize(Iter const &  begin, Iter const & end) const
00339     {
00340         Iter iter = begin;
00341         vigra_precondition(static_cast<size_t>(end - begin) == serialized_size(), 
00342                            "RandomForestOptions::serialize():"
00343                            "wrong number of parameters");
00344         #define PUSH(item_) *iter = double(item_); ++iter;
00345         PUSH(training_set_proportion_);
00346         PUSH(training_set_size_);
00347         if(training_set_func_ != 0)
00348         {
00349             PUSH(1);
00350         }
00351         else
00352         {
00353             PUSH(0);
00354         }
00355         PUSH(training_set_calc_switch_);
00356         PUSH(sample_with_replacement_);
00357         PUSH(stratification_method_);
00358         PUSH(mtry_switch_);
00359         PUSH(mtry_);
00360         if(mtry_func_ != 0)
00361         {
00362             PUSH(1);
00363         }
00364         else
00365         {
00366             PUSH(0);
00367         }
00368         PUSH(tree_count_);
00369         PUSH(min_split_node_size_);
00370         PUSH(predict_weighted_);
00371         #undef PUSH
00372     }
00373 
00374 
00375     /**\brief create a RandomForestOptions object with default initialisation.
00376      *
00377      * look at the other member functions for more information on default
00378      * values
00379      */
00380     RandomForestOptions()
00381     :
00382         training_set_proportion_(1.0),
00383         training_set_size_(0),
00384         training_set_func_(0),
00385         training_set_calc_switch_(RF_PROPORTIONAL),
00386         sample_with_replacement_(true),
00387         stratification_method_(RF_NONE),
00388         mtry_switch_(RF_SQRT),
00389         mtry_(0),
00390         mtry_func_(0),
00391         predict_weighted_(false),
00392         tree_count_(256),
00393         min_split_node_size_(1),
00394         prepare_online_learning_(false)
00395     {}
00396 
00397     /**\brief specify stratification strategy
00398      *
00399      * default: RF_NONE
00400      * possible values: RF_EQUAL, RF_PROPORTIONAL,
00401      *                  RF_EXTERNAL, RF_NONE
00402      * RF_EQUAL:        get equal amount of samples per class.
00403      * RF_PROPORTIONAL: sample proportional to fraction of class samples
00404      *                  in population
00405      * RF_EXTERNAL:     strata_weights_ field of the ProblemSpec_t object
00406      *                  has been set externally. (defunct)
00407      */
00408     RandomForestOptions & use_stratification(RF_OptionTag in)
00409     {
00410         vigra_precondition(in == RF_EQUAL ||
00411                            in == RF_PROPORTIONAL ||
00412                            in == RF_EXTERNAL ||
00413                            in == RF_NONE,
00414                            "RandomForestOptions::use_stratification()"
00415                            "input must be RF_EQUAL, RF_PROPORTIONAL,"
00416                            "RF_EXTERNAL or RF_NONE");
00417         stratification_method_ = in;
00418         return *this;
00419     }
00420 
00421     RandomForestOptions & prepare_online_learning(bool in)
00422     {
00423         prepare_online_learning_=in;
00424         return *this;
00425     }
00426 
00427     /**\brief sample from training population with or without replacement?
00428      *
00429      * <br> Default: true
00430      */
00431     RandomForestOptions & sample_with_replacement(bool in)
00432     {
00433         sample_with_replacement_ = in;
00434         return *this;
00435     }
00436 
00437     /**\brief  specify the fraction of the total number of samples 
00438      * used per tree for learning. 
00439      *
00440      * This value should be in [0.0 1.0] if sampling without
00441      * replacement has been specified.
00442      *
00443      * <br> default : 1.0
00444      */
00445     RandomForestOptions & samples_per_tree(double in)
00446     {
00447         training_set_proportion_ = in;
00448         training_set_calc_switch_ = RF_PROPORTIONAL;
00449         return *this;
00450     }
00451 
00452     /**\brief directly specify the number of samples per tree
00453      */
00454     RandomForestOptions & samples_per_tree(int in)
00455     {
00456         training_set_size_ = in;
00457         training_set_calc_switch_ = RF_CONST;
00458         return *this;
00459     }
00460 
00461     /**\brief use external function to calculate the number of samples each
00462      *        tree should be learnt with.
00463      *
00464      * \param in function pointer that takes the number of rows in the
00465      *           learning data and outputs the number samples per tree.
00466      */
00467     RandomForestOptions & samples_per_tree(int (*in)(int))
00468     {
00469         training_set_func_ = in;
00470         training_set_calc_switch_ = RF_FUNCTION;
00471         return *this;
00472     }
00473     
00474     /**\brief weight each tree with number of samples in that node
00475      */
00476     RandomForestOptions & predict_weighted()
00477     {
00478         predict_weighted_ = true;
00479         return *this;
00480     }
00481 
00482     /**\brief use built in mapping to calculate mtry
00483      *
00484      * Use one of the built in mappings to calculate mtry from the number
00485      * of columns in the input feature data.
00486      * \param in possible values: RF_LOG, RF_SQRT or RF_ALL
00487      *           <br> default: RF_SQRT.
00488      */
00489     RandomForestOptions & features_per_node(RF_OptionTag in)
00490     {
00491         vigra_precondition(in == RF_LOG ||
00492                            in == RF_SQRT||
00493                            in == RF_ALL,
00494                            "RandomForestOptions()::features_per_node():"
00495                            "input must be of type RF_LOG or RF_SQRT");
00496         mtry_switch_ = in;
00497         return *this;
00498     }
00499 
00500     /**\brief Set mtry to a constant value
00501      *
00502      * mtry is the number of columns/variates/variables randomly choosen
00503      * to select the best split from.
00504      *
00505      */
00506     RandomForestOptions & features_per_node(int in)
00507     {
00508         mtry_ = in;
00509         mtry_switch_ = RF_CONST;
00510         return *this;
00511     }
00512 
00513     /**\brief use a external function to calculate mtry
00514      *
00515      * \param in function pointer that takes int (number of columns
00516      *           of the and outputs int (mtry)
00517      */
00518     RandomForestOptions & features_per_node(int(*in)(int))
00519     {
00520         mtry_func_ = in;
00521         mtry_switch_ = RF_FUNCTION;
00522         return *this;
00523     }
00524 
00525     /** How many trees to create?
00526      *
00527      * <br> Default: 255.
00528      */
00529     RandomForestOptions & tree_count(int in)
00530     {
00531         tree_count_ = in;
00532         return *this;
00533     }
00534 
00535     /**\brief Number of examples required for a node to be split.
00536      *
00537      *  When the number of examples in a node is below this number,
00538      *  the node is not split even if class separation is not yet perfect.
00539      *  Instead, the node returns the proportion of each class
00540      *  (among the remaining examples) during the prediction phase.
00541      *  <br> Default: 1 (complete growing)
00542      */
00543     RandomForestOptions & min_split_node_size(int in)
00544     {
00545         min_split_node_size_ = in;
00546         return *this;
00547     }
00548 };
00549 
00550 
/** \brief problem types 
 *
 * CHECKLATER means "not yet determined": it is the default set by
 * ProblemSpec's constructor and by ProblemSpec::clear().
 */
enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};
00554 
00555 
00556 /** \brief problem specification class for the random forest.
00557  *
00558  * This class contains all the problem specific parameters the random
00559  * forest needs for learning. Specification of an instance of this class
00560  * is optional as all necessary fields will be computed prior to learning
00561  * if not specified.
00562  *
00563  * if needed usage is similar to that of RandomForestOptions
00564  */
00565 
00566 template<class LabelType>
00567 class ProblemSpec
00568 {
00569 
00570 
00571 public:
00572 
00573     /** \brief  problem class
00574      */
00575 
00576     typedef LabelType       Label_t;
00577     ArrayVector<Label_t>    classes;
00578 
00579     int                     column_count_;
00580     int                     class_count_;
00581     int                     row_count_;
00582 
00583     int                     actual_mtry_;
00584     int                     actual_msample_;
00585 
00586     Problem_t               problem_type_;
00587     
00588     int used_;
00589     ArrayVector<double>     class_weights_;
00590     int                     is_weighted;
00591     double                  precision_;
00592     
00593         
00594     template<class T> 
00595     void to_classlabel(int index, T & out) const
00596     {
00597         out = T(classes[index]);
00598     }
00599     template<class T> 
00600     int to_classIndex(T index) const
00601     {
00602         return std::find(classes.begin(), classes.end(), index) - classes.begin();
00603     }
00604 
00605     #define EQUALS(field) field(rhs.field)
00606     ProblemSpec(ProblemSpec const & rhs)
00607     : 
00608         EQUALS(column_count_),
00609         EQUALS(class_count_),
00610         EQUALS(row_count_),
00611         EQUALS(actual_mtry_),
00612         EQUALS(actual_msample_),
00613         EQUALS(problem_type_),
00614         EQUALS(used_),
00615         EQUALS(class_weights_),
00616         EQUALS(is_weighted),
00617         EQUALS(precision_)
00618     {
00619         std::back_insert_iterator<ArrayVector<Label_t> >
00620                         iter(classes);
00621         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00622     }
00623     #undef EQUALS
00624     #define EQUALS(field) field(rhs.field)
00625     template<class T>
00626     ProblemSpec(ProblemSpec<T> const & rhs)
00627     : 
00628         EQUALS(column_count_),
00629         EQUALS(class_count_),
00630         EQUALS(row_count_),
00631         EQUALS(actual_mtry_),
00632         EQUALS(actual_msample_),
00633         EQUALS(problem_type_),
00634         EQUALS(used_),
00635         EQUALS(class_weights_),
00636         EQUALS(is_weighted),
00637         EQUALS(precision_)
00638     {
00639         std::back_insert_iterator<ArrayVector<Label_t> >
00640                         iter(classes);
00641         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00642     }
00643     #undef EQUALS
00644 
00645     // for some reason the function below does not match
00646     // the default copy constructor
00647     #define EQUALS(field) (this->field = rhs.field);
00648     ProblemSpec & operator=(ProblemSpec const & rhs)
00649     {
00650         EQUALS(column_count_);
00651         EQUALS(class_count_);
00652         EQUALS(row_count_);
00653         EQUALS(actual_mtry_);
00654         EQUALS(actual_msample_);
00655         EQUALS(problem_type_);
00656         EQUALS(used_);
00657         EQUALS(is_weighted);
00658         EQUALS(precision_);
00659         class_weights_.clear();
00660         std::back_insert_iterator<ArrayVector<double> >
00661                         iter2(class_weights_);
00662         std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 
00663         classes.clear();
00664         std::back_insert_iterator<ArrayVector<Label_t> >
00665                         iter(classes);
00666         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00667         return *this;
00668     }
00669 
00670     template<class T>
00671     ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
00672     {
00673         EQUALS(column_count_);
00674         EQUALS(class_count_);
00675         EQUALS(row_count_);
00676         EQUALS(actual_mtry_);
00677         EQUALS(actual_msample_);
00678         EQUALS(problem_type_);
00679         EQUALS(used_);
00680         EQUALS(is_weighted);
00681         EQUALS(precision_);
00682         class_weights_.clear();
00683         std::back_insert_iterator<ArrayVector<double> >
00684                         iter2(class_weights_);
00685         std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 
00686         classes.clear();
00687         std::back_insert_iterator<ArrayVector<Label_t> >
00688                         iter(classes);
00689         std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 
00690         return *this;
00691     }
00692     #undef EQUALS
00693 
00694     template<class T>
00695     bool operator==(ProblemSpec<T> const & rhs)
00696     {
00697         bool result = true;
00698         #define COMPARE(field) result = result && (this->field == rhs.field);
00699         COMPARE(column_count_);
00700         COMPARE(class_count_);
00701         COMPARE(row_count_);
00702         COMPARE(actual_mtry_);
00703         COMPARE(actual_msample_);
00704         COMPARE(problem_type_);
00705         COMPARE(is_weighted);
00706         COMPARE(precision_);
00707         COMPARE(used_);
00708         COMPARE(class_weights_);
00709         COMPARE(classes);
00710         #undef COMPARE
00711         return result;
00712     }
00713 
00714     bool operator!=(ProblemSpec & rhs)
00715     {
00716         return !(*this == rhs);
00717     }
00718 
00719 
00720     size_t serialized_size() const
00721     {
00722         return 9 + class_count_ *int(is_weighted+1);
00723     }
00724 
00725 
00726     template<class Iter>
00727     void unserialize(Iter const & begin, Iter const & end)
00728     {
00729         Iter iter = begin;
00730         vigra_precondition(end - begin >= 9, 
00731                            "ProblemSpec::unserialize():"
00732                            "wrong number of parameters");
00733         #define PULL(item_, type_) item_ = type_(*iter); ++iter;
00734         PULL(column_count_,int);
00735         PULL(class_count_, int);
00736 
00737         vigra_precondition(end - begin >= 9 + class_count_, 
00738                            "ProblemSpec::unserialize(): 1");
00739         PULL(row_count_, int);
00740         PULL(actual_mtry_,int);
00741         PULL(actual_msample_, int);
00742         PULL(problem_type_, Problem_t);
00743         PULL(is_weighted, int);
00744         PULL(used_, int);
00745         PULL(precision_, double);
00746         if(is_weighted)
00747         {
00748             vigra_precondition(end - begin == 9 + 2*class_count_, 
00749                                "ProblemSpec::unserialize(): 2");
00750             class_weights_.insert(class_weights_.end(),
00751                                   iter, 
00752                                   iter + class_count_);
00753             iter += class_count_; 
00754         }
00755         classes.insert(classes.end(), iter, end);
00756         #undef PULL
00757     }
00758 
00759 
00760     template<class Iter>
00761     void serialize(Iter const & begin, Iter const & end) const
00762     {
00763         Iter iter = begin;
00764         vigra_precondition(end - begin == serialized_size(), 
00765                            "RandomForestOptions::serialize():"
00766                            "wrong number of parameters");
00767         #define PUSH(item_) *iter = double(item_); ++iter;
00768         PUSH(column_count_);
00769         PUSH(class_count_)
00770         PUSH(row_count_);
00771         PUSH(actual_mtry_);
00772         PUSH(actual_msample_);
00773         PUSH(problem_type_);
00774         PUSH(is_weighted);
00775         PUSH(used_);
00776         PUSH(precision_);
00777         if(is_weighted)
00778         {
00779             std::copy(class_weights_.begin(),
00780                       class_weights_.end(),
00781                       iter);
00782             iter += class_count_; 
00783         }
00784         std::copy(classes.begin(),
00785                   classes.end(),
00786                   iter);
00787         #undef PUSH
00788     }
00789 
00790     void make_from_map(std::map<std::string, ArrayVector<double> > & in)
00791     {
00792         typedef MultiArrayShape<2>::type Shp; 
00793         #define PULL(item_, type_) item_ = type_(in[#item_][0]); 
00794         PULL(column_count_,int);
00795         PULL(class_count_, int);
00796         PULL(row_count_, int);
00797         PULL(actual_mtry_,int);
00798         PULL(actual_msample_, int);
00799         PULL(problem_type_, (Problem_t)int);
00800         PULL(is_weighted, int);
00801         PULL(used_, int);
00802         PULL(precision_, double);
00803         class_weights_ = in["class_weights_"];
00804         #undef PUSH
00805     }
00806     void make_map(std::map<std::string, ArrayVector<double> > & in) const
00807     {
00808         typedef MultiArrayShape<2>::type Shp; 
00809         #define PUSH(item_) in[#item_] = ArrayVector<double>(1, double(item_)); 
00810         PUSH(column_count_);
00811         PUSH(class_count_)
00812         PUSH(row_count_);
00813         PUSH(actual_mtry_);
00814         PUSH(actual_msample_);
00815         PUSH(problem_type_);
00816         PUSH(is_weighted);
00817         PUSH(used_);
00818         PUSH(precision_);
00819         in["class_weights_"] = class_weights_;
00820         #undef PUSH
00821     }
00822     
00823     /**\brief set default values (-> values not set)
00824      */
00825     ProblemSpec()
00826     :   column_count_(0),
00827         class_count_(0),
00828         row_count_(0),
00829         actual_mtry_(0),
00830         actual_msample_(0),
00831         problem_type_(CHECKLATER),
00832         used_(false),
00833         is_weighted(false),
00834         precision_(0.0)
00835     {}
00836 
00837 
00838     ProblemSpec & column_count(int in)
00839     {
00840         column_count_ = in;
00841         return *this;
00842     }
00843 
00844     /**\brief supply with class labels -
00845      * 
00846      * the preprocessor will not calculate the labels needed in this case.
00847      */
00848     template<class C_Iter>
00849     ProblemSpec & classes_(C_Iter begin, C_Iter end)
00850     {
00851         int size = end-begin;
00852         for(int k=0; k<size; ++k, ++begin)
00853             classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
00854         class_count_ = size;
00855         return *this;
00856     }
00857 
00858     /** \brief supply with class weights  -
00859      *
00860      * this is the only case where you would really have to 
00861      * create a ProblemSpec object.
00862      */
00863     template<class W_Iter>
00864     ProblemSpec & class_weights(W_Iter begin, W_Iter end)
00865     {
00866         class_weights_.insert(class_weights_.end(), begin, end);
00867         is_weighted = true;
00868         return *this;
00869     }
00870 
00871 
00872 
00873     void clear()
00874     {
00875         used_ = false; 
00876         classes.clear();
00877         class_weights_.clear();
00878         column_count_ = 0 ;
00879         class_count_ = 0;
00880         actual_mtry_ = 0;
00881         actual_msample_ = 0;
00882         problem_type_ = CHECKLATER;
00883         is_weighted = false;
00884         precision_   = 0.0;
00885 
00886     }
00887 
00888     bool used() const
00889     {
00890         return used_ != 0;
00891     }
00892 };
00893 
00894 
00895 
00896 } // namespace vigra
00897 
00898 #endif //VIGRA_RF_COMMON_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)