[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. */ 00033 /* */ 00034 /************************************************************************/ 00035 00036 00037 #ifndef VIGRA_RF_COMMON_HXX 00038 #define VIGRA_RF_COMMON_HXX 00039 00040 namespace vigra 00041 { 00042 00043 // FORWARD DECLARATIONS 00044 // TODO : DECIDE WHETHER THIS IS A GOOD IDEA 00045 struct ClassificationTag{}; 00046 00047 struct RegressionTag{}; 00048 00049 class GiniCriterion; 00050 00051 template<class T> 00052 class BestGiniOfColumn; 00053 00054 template<class T, class U = ClassificationTag> 00055 class ThresholdSplit; 00056 00057 typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit; 00058 00059 namespace rf 00060 { 00061 class StopVisiting; 00062 } 00063 00064 class OOB_Visitor; 00065 class RandomForestOptions; 00066 00067 template<class T= double> 00068 class ProblemSpec; 00069 00070 template<class LabelT = double, class Tag = ClassificationTag> 00071 class RandomForest; 00072 00073 00074 class EarlyStoppStd; 00075 00076 namespace detail 00077 { 00078 class RF_DEFAULT; 00079 class DecisionTree; 00080 } 00081 00082 detail::RF_DEFAULT& rf_default(); 00083 00084 template <class T> 00085 class DT_StackEntry; 00086 00087 /**\brief Traits Class for the Random Forest 00088 * 00089 * refer to the typedefs in this class when using the default 00090 * objects as the names may change in future. 00091 */ 00092 class RF_Traits 00093 { 00094 public: 00095 typedef RandomForestOptions Options_t; 00096 typedef detail::DecisionTree DecisionTree_t; 00097 typedef ClassificationTag Preprocessor_t; 00098 typedef GiniSplit Default_Split_t; 00099 typedef EarlyStoppStd Default_Stop_t; 00100 typedef rf::StopVisiting 00101 Default_Visitor_t; 00102 typedef rf::StopVisiting StopVisiting_t; 00103 00104 }; 00105 00106 00107 /**\brief Standard early stopping criterion 00108 * 00109 * Stop if region.size() < min_split_node_size_; 00110 */ 00111 class EarlyStoppStd 00112 { 00113 public: 00114 int min_split_node_size_; 00115 00116 template<class Opt> 00117 EarlyStoppStd(Opt opt) 00118 : min_split_node_size_(opt.min_split_node_size_) 00119 {} 00120 00121 template<class T> 00122 void set_external_parameters(ProblemSpec<T>const &, int /* tree_count */ = 0, bool /* is_weighted */ = false) 00123 {} 00124 00125 template<class Region> 00126 bool operator()(Region& region) 00127 { 00128 return region.size() < min_split_node_size_; 00129 } 00130 00131 template<class WeightIter, class T, class C> 00132 bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */) 00133 { 00134 return false; 00135 } 00136 }; 00137 00138 00139 00140 00141 namespace detail 00142 { 00143 00144 /**\brief singleton default tag class - 00145 * 00146 * use the rf_default() factory function to use the tag. 00147 * \sa RandomForest<>::learn(); 00148 */ 00149 class RF_DEFAULT 00150 { 00151 private: 00152 RF_DEFAULT() 00153 {} 00154 public: 00155 friend RF_DEFAULT& ::vigra::rf_default(); 00156 00157 /** ok workaround for automatic choice of the decisiontree 00158 * stackentry. 00159 */ 00160 typedef DT_StackEntry<ArrayVectorView<Int32>::iterator> 00161 StackEntry_t; 00162 }; 00163 00164 /**\brief chooses between default type and type supplied 00165 * 00166 * This is an internal class and you shouldn't really care about it. 00167 * Just pass on used in RandomForest.learn() 00168 * Usage: 00169 *\code 00170 * // example: use container type supplied by user or ArrayVector if 00171 * // rf_default() was specified as argument; 00172 * template<class Container_t> 00173 * void do_some_foo(Container_t in) 00174 * { 00175 * typedef ArrayVector<int> Default_Container_t; 00176 * Default_Container_t default_value; 00177 * Value_Chooser<Container_t, Default_Container_t> 00178 * choose(in, default_value); 00179 * 00180 * // if the user didn't care and the in was of type 00181 * // RF_DEFAULT then default_value is used. 00182 * do_some_more_foo(choose.value()); 00183 * } 00184 * Value_Chooser choose_val<Type, Default_Type> 00185 *\endcode 00186 */ 00187 template<class T, class C> 00188 class Value_Chooser 00189 { 00190 public: 00191 typedef T type; 00192 static T & choose(T & t, C &) 00193 { 00194 return t; 00195 } 00196 }; 00197 00198 template<class C> 00199 class Value_Chooser<detail::RF_DEFAULT, C> 00200 { 00201 public: 00202 typedef C type; 00203 00204 static C & choose(detail::RF_DEFAULT &, C & c) 00205 { 00206 return c; 00207 } 00208 }; 00209 00210 00211 00212 00213 } //namespace detail 00214 00215 00216 /**\brief factory function to return a RF_DEFAULT tag 00217 * \sa RandomForest<>::learn() 00218 */ 00219 detail::RF_DEFAULT& rf_default() 00220 { 00221 static detail::RF_DEFAULT result; 00222 return result; 00223 } 00224 00225 /** tags used with the RandomForestOptions class 00226 * \sa RF_Traits::Option_t 00227 */ 00228 enum RF_OptionTag { RF_EQUAL, 00229 RF_PROPORTIONAL, 00230 RF_EXTERNAL, 00231 RF_NONE, 00232 RF_FUNCTION, 00233 RF_LOG, 00234 RF_SQRT, 00235 RF_CONST, 00236 RF_ALL}; 00237 00238 00239 /**\brief Options object for the random forest 00240 * 00241 * usage: 00242 * RandomForestOptions a = RandomForestOptions() 00243 * .param1(value1) 00244 * .param2(value2) 00245 * ... 00246 * 00247 * This class only contains options/parameters that are not problem 00248 * dependent. The ProblemSpec class contains methods to set class weights 00249 * if necessary. 00250 * 00251 * Note that the return value of all methods is *this which makes 00252 * concatenating of options as above possible. 00253 */ 00254 class RandomForestOptions 00255 { 00256 public: 00257 /**\name sampling options*/ 00258 /*\{*/ 00259 // look at the member access functions for documentation 00260 double training_set_proportion_; 00261 int training_set_size_; 00262 int (*training_set_func_)(int); 00263 RF_OptionTag 00264 training_set_calc_switch_; 00265 00266 bool sample_with_replacement_; 00267 RF_OptionTag 00268 stratification_method_; 00269 00270 00271 /**\name general random forest options 00272 * 00273 * these usually will be used by most split functors and 00274 * stopping predicates 00275 */ 00276 /*\{*/ 00277 RF_OptionTag mtry_switch_; 00278 int mtry_; 00279 int (*mtry_func_)(int) ; 00280 00281 bool predict_weighted_; 00282 int tree_count_; 00283 int min_split_node_size_; 00284 bool prepare_online_learning_; 00285 /*\}*/ 00286 00287 int serialized_size() const 00288 { 00289 return 12; 00290 } 00291 00292 00293 bool operator==(RandomForestOptions & rhs) const 00294 { 00295 bool result = true; 00296 #define COMPARE(field) result = result && (this->field == rhs.field); 00297 COMPARE(training_set_proportion_); 00298 COMPARE(training_set_size_); 00299 COMPARE(training_set_calc_switch_); 00300 COMPARE(sample_with_replacement_); 00301 COMPARE(stratification_method_); 00302 COMPARE(mtry_switch_); 00303 COMPARE(mtry_); 00304 COMPARE(tree_count_); 00305 COMPARE(min_split_node_size_); 00306 COMPARE(predict_weighted_); 00307 #undef COMPARE 00308 00309 return result; 00310 } 00311 bool operator!=(RandomForestOptions & rhs_) const 00312 { 00313 return !(*this == rhs_); 00314 } 00315 template<class Iter> 00316 void unserialize(Iter const & begin, Iter const & end) 00317 { 00318 Iter iter = begin; 00319 vigra_precondition(static_cast<size_t>(end - begin) == serialized_size(), 00320 "RandomForestOptions::unserialize():" 00321 "wrong number of parameters"); 00322 #define PULL(item_, type_) item_ = type_(*iter); ++iter; 00323 PULL(training_set_proportion_, double); 00324 PULL(training_set_size_, int); 00325 ++iter; //PULL(training_set_func_, double); 00326 PULL(training_set_calc_switch_, (RF_OptionTag)int); 00327 PULL(sample_with_replacement_, 0 != ); 00328 PULL(stratification_method_, (RF_OptionTag)int); 00329 PULL(mtry_switch_, (RF_OptionTag)int); 00330 PULL(mtry_, int); 00331 ++iter; //PULL(mtry_func_, double); 00332 PULL(tree_count_, int); 00333 PULL(min_split_node_size_, int); 00334 PULL(predict_weighted_, 0 !=); 00335 #undef PULL 00336 } 00337 template<class Iter> 00338 void serialize(Iter const & begin, Iter const & end) const 00339 { 00340 Iter iter = begin; 00341 vigra_precondition(static_cast<size_t>(end - begin) == serialized_size(), 00342 "RandomForestOptions::serialize():" 00343 "wrong number of parameters"); 00344 #define PUSH(item_) *iter = double(item_); ++iter; 00345 PUSH(training_set_proportion_); 00346 PUSH(training_set_size_); 00347 if(training_set_func_ != 0) 00348 { 00349 PUSH(1); 00350 } 00351 else 00352 { 00353 PUSH(0); 00354 } 00355 PUSH(training_set_calc_switch_); 00356 PUSH(sample_with_replacement_); 00357 PUSH(stratification_method_); 00358 PUSH(mtry_switch_); 00359 PUSH(mtry_); 00360 if(mtry_func_ != 0) 00361 { 00362 PUSH(1); 00363 } 00364 else 00365 { 00366 PUSH(0); 00367 } 00368 PUSH(tree_count_); 00369 PUSH(min_split_node_size_); 00370 PUSH(predict_weighted_); 00371 #undef PUSH 00372 } 00373 00374 00375 /**\brief create a RandomForestOptions object with default initialisation. 00376 * 00377 * look at the other member functions for more information on default 00378 * values 00379 */ 00380 RandomForestOptions() 00381 : 00382 training_set_proportion_(1.0), 00383 training_set_size_(0), 00384 training_set_func_(0), 00385 training_set_calc_switch_(RF_PROPORTIONAL), 00386 sample_with_replacement_(true), 00387 stratification_method_(RF_NONE), 00388 mtry_switch_(RF_SQRT), 00389 mtry_(0), 00390 mtry_func_(0), 00391 predict_weighted_(false), 00392 tree_count_(256), 00393 min_split_node_size_(1), 00394 prepare_online_learning_(false) 00395 {} 00396 00397 /**\brief specify stratification strategy 00398 * 00399 * default: RF_NONE 00400 * possible values: RF_EQUAL, RF_PROPORTIONAL, 00401 * RF_EXTERNAL, RF_NONE 00402 * RF_EQUAL: get equal amount of samples per class. 00403 * RF_PROPORTIONAL: sample proportional to fraction of class samples 00404 * in population 00405 * RF_EXTERNAL: strata_weights_ field of the ProblemSpec_t object 00406 * has been set externally. (defunct) 00407 */ 00408 RandomForestOptions & use_stratification(RF_OptionTag in) 00409 { 00410 vigra_precondition(in == RF_EQUAL || 00411 in == RF_PROPORTIONAL || 00412 in == RF_EXTERNAL || 00413 in == RF_NONE, 00414 "RandomForestOptions::use_stratification()" 00415 "input must be RF_EQUAL, RF_PROPORTIONAL," 00416 "RF_EXTERNAL or RF_NONE"); 00417 stratification_method_ = in; 00418 return *this; 00419 } 00420 00421 RandomForestOptions & prepare_online_learning(bool in) 00422 { 00423 prepare_online_learning_=in; 00424 return *this; 00425 } 00426 00427 /**\brief sample from training population with or without replacement? 00428 * 00429 * <br> Default: true 00430 */ 00431 RandomForestOptions & sample_with_replacement(bool in) 00432 { 00433 sample_with_replacement_ = in; 00434 return *this; 00435 } 00436 00437 /**\brief specify the fraction of the total number of samples 00438 * used per tree for learning. 00439 * 00440 * This value should be in [0.0 1.0] if sampling without 00441 * replacement has been specified. 00442 * 00443 * <br> default : 1.0 00444 */ 00445 RandomForestOptions & samples_per_tree(double in) 00446 { 00447 training_set_proportion_ = in; 00448 training_set_calc_switch_ = RF_PROPORTIONAL; 00449 return *this; 00450 } 00451 00452 /**\brief directly specify the number of samples per tree 00453 */ 00454 RandomForestOptions & samples_per_tree(int in) 00455 { 00456 training_set_size_ = in; 00457 training_set_calc_switch_ = RF_CONST; 00458 return *this; 00459 } 00460 00461 /**\brief use external function to calculate the number of samples each 00462 * tree should be learnt with. 00463 * 00464 * \param in function pointer that takes the number of rows in the 00465 * learning data and outputs the number samples per tree. 00466 */ 00467 RandomForestOptions & samples_per_tree(int (*in)(int)) 00468 { 00469 training_set_func_ = in; 00470 training_set_calc_switch_ = RF_FUNCTION; 00471 return *this; 00472 } 00473 00474 /**\brief weight each tree with number of samples in that node 00475 */ 00476 RandomForestOptions & predict_weighted() 00477 { 00478 predict_weighted_ = true; 00479 return *this; 00480 } 00481 00482 /**\brief use built in mapping to calculate mtry 00483 * 00484 * Use one of the built in mappings to calculate mtry from the number 00485 * of columns in the input feature data. 00486 * \param in possible values: RF_LOG, RF_SQRT or RF_ALL 00487 * <br> default: RF_SQRT. 00488 */ 00489 RandomForestOptions & features_per_node(RF_OptionTag in) 00490 { 00491 vigra_precondition(in == RF_LOG || 00492 in == RF_SQRT|| 00493 in == RF_ALL, 00494 "RandomForestOptions()::features_per_node():" 00495 "input must be of type RF_LOG or RF_SQRT"); 00496 mtry_switch_ = in; 00497 return *this; 00498 } 00499 00500 /**\brief Set mtry to a constant value 00501 * 00502 * mtry is the number of columns/variates/variables randomly choosen 00503 * to select the best split from. 00504 * 00505 */ 00506 RandomForestOptions & features_per_node(int in) 00507 { 00508 mtry_ = in; 00509 mtry_switch_ = RF_CONST; 00510 return *this; 00511 } 00512 00513 /**\brief use a external function to calculate mtry 00514 * 00515 * \param in function pointer that takes int (number of columns 00516 * of the and outputs int (mtry) 00517 */ 00518 RandomForestOptions & features_per_node(int(*in)(int)) 00519 { 00520 mtry_func_ = in; 00521 mtry_switch_ = RF_FUNCTION; 00522 return *this; 00523 } 00524 00525 /** How many trees to create? 00526 * 00527 * <br> Default: 255. 00528 */ 00529 RandomForestOptions & tree_count(int in) 00530 { 00531 tree_count_ = in; 00532 return *this; 00533 } 00534 00535 /**\brief Number of examples required for a node to be split. 00536 * 00537 * When the number of examples in a node is below this number, 00538 * the node is not split even if class separation is not yet perfect. 00539 * Instead, the node returns the proportion of each class 00540 * (among the remaining examples) during the prediction phase. 00541 * <br> Default: 1 (complete growing) 00542 */ 00543 RandomForestOptions & min_split_node_size(int in) 00544 { 00545 min_split_node_size_ = in; 00546 return *this; 00547 } 00548 }; 00549 00550 00551 /** \brief problem types 00552 */ 00553 enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER}; 00554 00555 00556 /** \brief problem specification class for the random forest. 00557 * 00558 * This class contains all the problem specific parameters the random 00559 * forest needs for learning. Specification of an instance of this class 00560 * is optional as all necessary fields will be computed prior to learning 00561 * if not specified. 00562 * 00563 * if needed usage is similar to that of RandomForestOptions 00564 */ 00565 00566 template<class LabelType> 00567 class ProblemSpec 00568 { 00569 00570 00571 public: 00572 00573 /** \brief problem class 00574 */ 00575 00576 typedef LabelType Label_t; 00577 ArrayVector<Label_t> classes; 00578 00579 int column_count_; 00580 int class_count_; 00581 int row_count_; 00582 00583 int actual_mtry_; 00584 int actual_msample_; 00585 00586 Problem_t problem_type_; 00587 00588 int used_; 00589 ArrayVector<double> class_weights_; 00590 int is_weighted; 00591 double precision_; 00592 00593 00594 template<class T> 00595 void to_classlabel(int index, T & out) const 00596 { 00597 out = T(classes[index]); 00598 } 00599 template<class T> 00600 int to_classIndex(T index) const 00601 { 00602 return std::find(classes.begin(), classes.end(), index) - classes.begin(); 00603 } 00604 00605 #define EQUALS(field) field(rhs.field) 00606 ProblemSpec(ProblemSpec const & rhs) 00607 : 00608 EQUALS(column_count_), 00609 EQUALS(class_count_), 00610 EQUALS(row_count_), 00611 EQUALS(actual_mtry_), 00612 EQUALS(actual_msample_), 00613 EQUALS(problem_type_), 00614 EQUALS(used_), 00615 EQUALS(class_weights_), 00616 EQUALS(is_weighted), 00617 EQUALS(precision_) 00618 { 00619 std::back_insert_iterator<ArrayVector<Label_t> > 00620 iter(classes); 00621 std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 00622 } 00623 #undef EQUALS 00624 #define EQUALS(field) field(rhs.field) 00625 template<class T> 00626 ProblemSpec(ProblemSpec<T> const & rhs) 00627 : 00628 EQUALS(column_count_), 00629 EQUALS(class_count_), 00630 EQUALS(row_count_), 00631 EQUALS(actual_mtry_), 00632 EQUALS(actual_msample_), 00633 EQUALS(problem_type_), 00634 EQUALS(used_), 00635 EQUALS(class_weights_), 00636 EQUALS(is_weighted), 00637 EQUALS(precision_) 00638 { 00639 std::back_insert_iterator<ArrayVector<Label_t> > 00640 iter(classes); 00641 std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 00642 } 00643 #undef EQUALS 00644 00645 // for some reason the function below does not match 00646 // the default copy constructor 00647 #define EQUALS(field) (this->field = rhs.field); 00648 ProblemSpec & operator=(ProblemSpec const & rhs) 00649 { 00650 EQUALS(column_count_); 00651 EQUALS(class_count_); 00652 EQUALS(row_count_); 00653 EQUALS(actual_mtry_); 00654 EQUALS(actual_msample_); 00655 EQUALS(problem_type_); 00656 EQUALS(used_); 00657 EQUALS(is_weighted); 00658 EQUALS(precision_); 00659 class_weights_.clear(); 00660 std::back_insert_iterator<ArrayVector<double> > 00661 iter2(class_weights_); 00662 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 00663 classes.clear(); 00664 std::back_insert_iterator<ArrayVector<Label_t> > 00665 iter(classes); 00666 std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 00667 return *this; 00668 } 00669 00670 template<class T> 00671 ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs) 00672 { 00673 EQUALS(column_count_); 00674 EQUALS(class_count_); 00675 EQUALS(row_count_); 00676 EQUALS(actual_mtry_); 00677 EQUALS(actual_msample_); 00678 EQUALS(problem_type_); 00679 EQUALS(used_); 00680 EQUALS(is_weighted); 00681 EQUALS(precision_); 00682 class_weights_.clear(); 00683 std::back_insert_iterator<ArrayVector<double> > 00684 iter2(class_weights_); 00685 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2); 00686 classes.clear(); 00687 std::back_insert_iterator<ArrayVector<Label_t> > 00688 iter(classes); 00689 std::copy(rhs.classes.begin(), rhs.classes.end(), iter); 00690 return *this; 00691 } 00692 #undef EQUALS 00693 00694 template<class T> 00695 bool operator==(ProblemSpec<T> const & rhs) 00696 { 00697 bool result = true; 00698 #define COMPARE(field) result = result && (this->field == rhs.field); 00699 COMPARE(column_count_); 00700 COMPARE(class_count_); 00701 COMPARE(row_count_); 00702 COMPARE(actual_mtry_); 00703 COMPARE(actual_msample_); 00704 COMPARE(problem_type_); 00705 COMPARE(is_weighted); 00706 COMPARE(precision_); 00707 COMPARE(used_); 00708 COMPARE(class_weights_); 00709 COMPARE(classes); 00710 #undef COMPARE 00711 return result; 00712 } 00713 00714 bool operator!=(ProblemSpec & rhs) 00715 { 00716 return !(*this == rhs); 00717 } 00718 00719 00720 size_t serialized_size() const 00721 { 00722 return 9 + class_count_ *int(is_weighted+1); 00723 } 00724 00725 00726 template<class Iter> 00727 void unserialize(Iter const & begin, Iter const & end) 00728 { 00729 Iter iter = begin; 00730 vigra_precondition(end - begin >= 9, 00731 "ProblemSpec::unserialize():" 00732 "wrong number of parameters"); 00733 #define PULL(item_, type_) item_ = type_(*iter); ++iter; 00734 PULL(column_count_,int); 00735 PULL(class_count_, int); 00736 00737 vigra_precondition(end - begin >= 9 + class_count_, 00738 "ProblemSpec::unserialize(): 1"); 00739 PULL(row_count_, int); 00740 PULL(actual_mtry_,int); 00741 PULL(actual_msample_, int); 00742 PULL(problem_type_, Problem_t); 00743 PULL(is_weighted, int); 00744 PULL(used_, int); 00745 PULL(precision_, double); 00746 if(is_weighted) 00747 { 00748 vigra_precondition(end - begin == 9 + 2*class_count_, 00749 "ProblemSpec::unserialize(): 2"); 00750 class_weights_.insert(class_weights_.end(), 00751 iter, 00752 iter + class_count_); 00753 iter += class_count_; 00754 } 00755 classes.insert(classes.end(), iter, end); 00756 #undef PULL 00757 } 00758 00759 00760 template<class Iter> 00761 void serialize(Iter const & begin, Iter const & end) const 00762 { 00763 Iter iter = begin; 00764 vigra_precondition(end - begin == serialized_size(), 00765 "RandomForestOptions::serialize():" 00766 "wrong number of parameters"); 00767 #define PUSH(item_) *iter = double(item_); ++iter; 00768 PUSH(column_count_); 00769 PUSH(class_count_) 00770 PUSH(row_count_); 00771 PUSH(actual_mtry_); 00772 PUSH(actual_msample_); 00773 PUSH(problem_type_); 00774 PUSH(is_weighted); 00775 PUSH(used_); 00776 PUSH(precision_); 00777 if(is_weighted) 00778 { 00779 std::copy(class_weights_.begin(), 00780 class_weights_.end(), 00781 iter); 00782 iter += class_count_; 00783 } 00784 std::copy(classes.begin(), 00785 classes.end(), 00786 iter); 00787 #undef PUSH 00788 } 00789 00790 void make_from_map(std::map<std::string, ArrayVector<double> > & in) 00791 { 00792 typedef MultiArrayShape<2>::type Shp; 00793 #define PULL(item_, type_) item_ = type_(in[#item_][0]); 00794 PULL(column_count_,int); 00795 PULL(class_count_, int); 00796 PULL(row_count_, int); 00797 PULL(actual_mtry_,int); 00798 PULL(actual_msample_, int); 00799 PULL(problem_type_, (Problem_t)int); 00800 PULL(is_weighted, int); 00801 PULL(used_, int); 00802 PULL(precision_, double); 00803 class_weights_ = in["class_weights_"]; 00804 #undef PUSH 00805 } 00806 void make_map(std::map<std::string, ArrayVector<double> > & in) const 00807 { 00808 typedef MultiArrayShape<2>::type Shp; 00809 #define PUSH(item_) in[#item_] = ArrayVector<double>(1, double(item_)); 00810 PUSH(column_count_); 00811 PUSH(class_count_) 00812 PUSH(row_count_); 00813 PUSH(actual_mtry_); 00814 PUSH(actual_msample_); 00815 PUSH(problem_type_); 00816 PUSH(is_weighted); 00817 PUSH(used_); 00818 PUSH(precision_); 00819 in["class_weights_"] = class_weights_; 00820 #undef PUSH 00821 } 00822 00823 /**\brief set default values (-> values not set) 00824 */ 00825 ProblemSpec() 00826 : column_count_(0), 00827 class_count_(0), 00828 row_count_(0), 00829 actual_mtry_(0), 00830 actual_msample_(0), 00831 problem_type_(CHECKLATER), 00832 used_(false), 00833 is_weighted(false), 00834 precision_(0.0) 00835 {} 00836 00837 00838 ProblemSpec & column_count(int in) 00839 { 00840 column_count_ = in; 00841 return *this; 00842 } 00843 00844 /**\brief supply with class labels - 00845 * 00846 * the preprocessor will not calculate the labels needed in this case. 00847 */ 00848 template<class C_Iter> 00849 ProblemSpec & classes_(C_Iter begin, C_Iter end) 00850 { 00851 int size = end-begin; 00852 for(int k=0; k<size; ++k, ++begin) 00853 classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin)); 00854 class_count_ = size; 00855 return *this; 00856 } 00857 00858 /** \brief supply with class weights - 00859 * 00860 * this is the only case where you would really have to 00861 * create a ProblemSpec object. 00862 */ 00863 template<class W_Iter> 00864 ProblemSpec & class_weights(W_Iter begin, W_Iter end) 00865 { 00866 class_weights_.insert(class_weights_.end(), begin, end); 00867 is_weighted = true; 00868 return *this; 00869 } 00870 00871 00872 00873 void clear() 00874 { 00875 used_ = false; 00876 classes.clear(); 00877 class_weights_.clear(); 00878 column_count_ = 0 ; 00879 class_count_ = 0; 00880 actual_mtry_ = 0; 00881 actual_msample_ = 0; 00882 problem_type_ = CHECKLATER; 00883 is_weighted = false; 00884 precision_ = 0.0; 00885 00886 } 00887 00888 bool used() const 00889 { 00890 return used_ != 0; 00891 } 00892 }; 00893 00894 00895 00896 } // namespace vigra 00897 00898 #endif //VIGRA_RF_COMMON_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|