37 #ifndef VIGRA_RF_COMMON_HXX
38 #define VIGRA_RF_COMMON_HXX
44 struct ClassificationTag
99 template<
class T,
class C>
104 static T & choose(T & t, C &)
111 class Value_Chooser<detail::RF_DEFAULT, C>
116 static C & choose(detail::RF_DEFAULT &, C & c)
133 static detail::RF_DEFAULT result;
176 double training_set_proportion_;
177 int training_set_size_;
178 int (*training_set_func_)(int);
180 training_set_calc_switch_;
182 bool sample_with_replacement_;
184 stratification_method_;
195 int (*mtry_func_)(int) ;
197 bool predict_weighted_;
199 int min_split_node_size_;
200 bool prepare_online_learning_;
204 typedef std::map<std::string, double_array> map_type;
206 int serialized_size()
const
// COMPARE(field) ANDs the equality of one member of *this and rhs into `result`.
// The macro body already ends in ';', so the ';' written after each use is
// just an extra empty statement.
// training_set_proportion_ is a double and is compared with exact ==.
// NOTE(review): among the lines shown here there is no COMPARE for
// prepare_online_learning_ or for the function-pointer members - confirm
// whether leaving them out of the comparison is intended.
215 #define COMPARE(field) result = result && (this->field == rhs.field);
216 COMPARE(training_set_proportion_);
217 COMPARE(training_set_size_);
218 COMPARE(training_set_calc_switch_);
219 COMPARE(sample_with_replacement_);
220 COMPARE(stratification_method_);
221 COMPARE(mtry_switch_);
223 COMPARE(tree_count_);
224 COMPARE(min_split_node_size_);
225 COMPARE(predict_weighted_);
232 return !(*
this == rhs_);
235 void unserialize(Iter
const & begin, Iter
const & end)
// The input range must contain exactly serialized_size() values.
// (The two adjacent string literals are concatenated without a separating space.)
238 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
239 "RandomForestOptions::unserialize():"
240 "wrong number of parameters");
// PULL(item_, type_) converts the value at `iter` with type_(...), stores it
// in item_ and advances `iter` by one position.
// Passing "0 !=" as type_ expands to `item_ = 0 != (*iter);`, so any non-zero
// stored value is read back as true.
241 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
242 PULL(training_set_proportion_,
double);
243 PULL(training_set_size_,
int);
246 PULL(sample_with_replacement_, 0 != );
251 PULL(tree_count_,
int);
252 PULL(min_split_node_size_,
int);
253 PULL(predict_weighted_, 0 !=);
257 void serialize(Iter
const & begin, Iter
const & end)
const
// The output range must have room for exactly serialized_size() values.
260 vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
261 "RandomForestOptions::serialize():"
262 "wrong number of parameters");
// PUSH(item_) writes the member, converted to double, through `iter` and
// advances `iter`. bool members are therefore stored as 0.0 or 1.0.
// NOTE(review): the order of PUSH calls has to match the order of the PULL
// calls in unserialize(); the lines visible here agree - confirm for the rest.
263 #define PUSH(item_) *iter = double(item_); ++iter;
264 PUSH(training_set_proportion_);
265 PUSH(training_set_size_);
266 if(training_set_func_ != 0)
274 PUSH(training_set_calc_switch_);
275 PUSH(sample_with_replacement_);
276 PUSH(stratification_method_);
288 PUSH(min_split_node_size_);
289 PUSH(predict_weighted_);
293 void make_from_map(map_type & in)
// PULL(item_, type_) looks up the map entry whose key is the member's own
// name (#item_ stringizes it, trailing underscore included) and stores the
// first element of that entry, converted with type_(...), in the member.
// PULLBOOL does the same lookup but stores whether the first element is > 0.
// NOTE(review): `in[...]` on a std::map default-inserts a missing key, and
// `[0]` on the resulting array presumably reads out of range - confirm that
// callers always provide every key used below.
296 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
297 #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
298 PULL(training_set_proportion_,
double);
299 PULL(training_set_size_,
int);
301 PULL(tree_count_,
int);
302 PULL(min_split_node_size_,
int);
303 PULLBOOL(sample_with_replacement_,
bool);
304 PULLBOOL(prepare_online_learning_,
bool);
305 PULLBOOL(predict_weighted_,
bool);
318 void make_map(map_type & in)
const
// PUSH(item_, type_) stores the member, converted to double, under the key
// equal to the member's own name (#item_). The value is wrapped as
// double_array(1, value) - presumably a (size, fill) constructor giving a
// one-element array; confirm against double_array's definition.
// PUSHFUNC stores only whether the function pointer is non-null (1.0 or 0.0);
// the function itself is not recorded.
// The type_ argument is not used by either macro.
321 #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
322 #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0));
323 PUSH(training_set_proportion_,
double);
324 PUSH(training_set_size_,
int);
326 PUSH(tree_count_,
int);
327 PUSH(min_split_node_size_,
int);
328 PUSH(sample_with_replacement_,
bool);
329 PUSH(prepare_online_learning_,
bool);
330 PUSH(predict_weighted_,
bool);
336 PUSHFUNC(mtry_func_,
int);
337 PUSHFUNC(training_set_func_,
int);
350 training_set_proportion_(1.0),
351 training_set_size_(0),
352 training_set_func_(0),
353 training_set_calc_switch_(RF_PROPORTIONAL),
354 sample_with_replacement_(true),
355 stratification_method_(RF_NONE),
356 mtry_switch_(RF_SQRT),
359 predict_weighted_(false),
361 min_split_node_size_(1),
362 prepare_online_learning_(false)
378 vigra_precondition(in == RF_EQUAL ||
379 in == RF_PROPORTIONAL ||
382 "RandomForestOptions::use_stratification()"
383 "input must be RF_EQUAL, RF_PROPORTIONAL,"
384 "RF_EXTERNAL or RF_NONE");
385 stratification_method_ = in;
391 prepare_online_learning_=in;
401 sample_with_replacement_ = in;
415 training_set_proportion_ = in;
416 training_set_calc_switch_ = RF_PROPORTIONAL;
424 training_set_size_ = in;
425 training_set_calc_switch_ = RF_CONST;
437 training_set_func_ = in;
438 training_set_calc_switch_ = RF_FUNCTION;
446 predict_weighted_ =
true;
459 vigra_precondition(in == RF_LOG ||
462 "RandomForestOptions()::features_per_node():"
463 "input must be of type RF_LOG or RF_SQRT");
477 mtry_switch_ = RF_CONST;
489 mtry_switch_ = RF_FUNCTION;
513 min_split_node_size_ = in;
534 template<
class LabelType =
double>
547 typedef std::map<std::string, double_array> map_type;
// Writes the entry of `classes` at position `index`, converted to T, into `out`.
// No range check on `index` is visible in the lines shown here.
565 void to_classlabel(
int index, T & out)
const
567 out = T(classes[index]);
// Returns the position of the first element of `classes` that compares equal
// to the argument (which, despite its name `index`, is the value searched for).
// If std::find finds no match the result is end() - begin().
570 int to_classIndex(T index)
const
572 return std::find(classes.
begin(), classes.
end(), index) - classes.
begin();
575 #define EQUALS(field) field(rhs.field)
578 EQUALS(column_count_),
579 EQUALS(class_count_),
581 EQUALS(actual_mtry_),
582 EQUALS(actual_msample_),
583 EQUALS(problem_type_),
585 EQUALS(class_weights_),
586 EQUALS(is_weighted_),
588 EQUALS(response_size_)
590 std::back_insert_iterator<ArrayVector<Label_t> >
592 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
595 #define EQUALS(field) field(rhs.field)
599 EQUALS(column_count_),
600 EQUALS(class_count_),
602 EQUALS(actual_mtry_),
603 EQUALS(actual_msample_),
604 EQUALS(problem_type_),
606 EQUALS(class_weights_),
607 EQUALS(is_weighted_),
609 EQUALS(response_size_)
611 std::back_insert_iterator<ArrayVector<Label_t> >
613 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
// In this form EQUALS(field) is a member-wise assignment statement
// (this->field = rhs.field;), unlike the initializer form field(rhs.field)
// defined under the same macro name earlier in the file.
// The macro body ends in ';', so the use without a trailing ';' below is
// still a complete statement.
617 #define EQUALS(field) (this->field = rhs.field);
620 EQUALS(column_count_);
621 EQUALS(class_count_);
623 EQUALS(actual_mtry_);
624 EQUALS(actual_msample_);
625 EQUALS(problem_type_);
627 EQUALS(is_weighted_);
629 EQUALS(response_size_)
// class_weights_ is emptied and then refilled element by element from rhs
// through a back_insert_iterator rather than assigned with EQUALS.
630 class_weights_.clear();
631 std::back_insert_iterator<ArrayVector<
double> >
632 iter2(class_weights_);
633 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
// The class labels of rhs are copied the same way through `iter`
// (its construction is in a line not shown here).
635 std::back_insert_iterator<ArrayVector<
Label_t> >
637 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
644 EQUALS(column_count_);
645 EQUALS(class_count_);
647 EQUALS(actual_mtry_);
648 EQUALS(actual_msample_);
649 EQUALS(problem_type_);
651 EQUALS(is_weighted_);
653 EQUALS(response_size_)
654 class_weights_.clear();
655 std::back_insert_iterator<ArrayVector<
double> >
656 iter2(class_weights_);
657 std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
659 std::back_insert_iterator<ArrayVector<
Label_t> >
661 std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
667 bool operator==(ProblemSpec<T>
const & rhs)
// COMPARE(field) ANDs the equality of one member of *this and rhs into `result`.
// The macro body already ends in ';', which is why the last use below is a
// complete statement even though no ';' follows it.
670 #define COMPARE(field) result = result && (this->field == rhs.field);
671 COMPARE(column_count_);
672 COMPARE(class_count_);
674 COMPARE(actual_mtry_);
675 COMPARE(actual_msample_);
676 COMPARE(problem_type_);
677 COMPARE(is_weighted_);
680 COMPARE(class_weights_);
682 COMPARE(response_size_)
689 return !(*
this == rhs);
// Number of values in the serialized form: 9 + class_count_ when
// is_weighted_ is 0, and 9 + 2*class_count_ when is_weighted_ is 1.
// NOTE(review): presumably 9 fixed fields, class_count_ labels and, if
// weighted, class_count_ weights - confirm against serialize()/unserialize().
693 size_t serialized_size()
const
695 return 9 + class_count_ *int(is_weighted_+1);
700 void unserialize(Iter
const & begin, Iter
const & end)
// First check: the range must hold at least 9 values before anything is read.
703 vigra_precondition(end - begin >= 9,
704 "ProblemSpec::unserialize():"
705 "wrong number of parameters");
// PULL(item_, type_) converts the value at `iter` with type_(...), stores it
// in item_ and advances `iter` by one position.
706 #define PULL(item_, type_) item_ = type_(*iter); ++iter;
707 PULL(column_count_,
int);
708 PULL(class_count_,
int);
// Second check: now that class_count_ is known, the range must hold at least
// 9 + class_count_ values.
710 vigra_precondition(end - begin >= 9 + class_count_,
711 "ProblemSpec::unserialize(): 1");
712 PULL(row_count_,
int);
713 PULL(actual_mtry_,
int);
714 PULL(actual_msample_,
int);
716 PULL(is_weighted_,
int);
718 PULL(precision_,
double);
719 PULL(response_size_,
int);
// Third check: exactly 9 + 2*class_count_ values are required before the
// class weights are read.
// NOTE(review): presumably this check and the weight read below are guarded
// by is_weighted_ in lines not shown here - confirm.
722 vigra_precondition(end - begin == 9 + 2*class_count_,
723 "ProblemSpec::unserialize(): 2");
724 class_weights_.insert(class_weights_.end(),
726 iter + class_count_);
727 iter += class_count_;
// Everything from `iter` to `end` is appended to `classes`.
// NOTE(review): both inserts append at end(); no clearing of class_weights_
// or classes is visible here - confirm they are empty on entry.
729 classes.insert(classes.end(), iter, end);
735 void serialize(Iter
const & begin, Iter
const & end)
const
// The output range must have room for exactly serialized_size() values.
// The message names ProblemSpec because this precondition belongs to
// ProblemSpec's serialize(), matching the "ProblemSpec::unserialize()"
// messages of its counterpart.
738 vigra_precondition(end - begin == serialized_size(),
739 "ProblemSpec::serialize():"
740 "wrong number of parameters");
741 #define PUSH(item_) *iter = double(item_); ++iter;
746 PUSH(actual_msample_);
751 PUSH(response_size_);
754 std::copy(class_weights_.begin(),
755 class_weights_.end(),
757 iter += class_count_;
759 std::copy(classes.begin(),
765 void make_from_map(map_type & in)
767 typedef MultiArrayShape<2>::type Shp;
768 #define PULL(item_, type_) item_ = type_(in[#item_][0]);
769 PULL(column_count_,
int);
770 PULL(class_count_,
int);
771 PULL(row_count_,
int);
772 PULL(actual_mtry_,
int);
773 PULL(actual_msample_,
int);
775 PULL(is_weighted_,
int);
777 PULL(precision_,
double);
778 PULL(response_size_,
int);
779 class_weights_ = in[
"class_weights_"];
782 void make_map(map_type & in)
const
784 typedef MultiArrayShape<2>::type Shp;
785 #define PUSH(item_) in[#item_] = double_array(1, double(item_));
790 PUSH(actual_msample_);
795 PUSH(response_size_);
796 in["class_weights_"] = class_weights_;
808 problem_type_(CHECKLATER),
826 template<
class C_Iter>
829 int size = end-begin;
830 for(
int k=0; k<size; ++k, ++begin)
831 classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
841 template<
class W_Iter>
844 class_weights_.insert(class_weights_.end(), begin, end);
855 class_weights_.clear();
860 problem_type_ = CHECKLATER;
861 is_weighted_ =
false;
885 int min_split_node_size_;
889 : min_split_node_size_(opt.min_split_node_size_)
893 void set_external_parameters(
ProblemSpec<T>const &,
int = 0,
bool =
false)
// Returns true when `region` holds fewer than min_split_node_size_ elements.
// NOTE(review): presumably a true result means the node is not split any
// further - confirm against the caller.
// NOTE(review): min_split_node_size_ is an int; if Region::size() returns an
// unsigned type this is a mixed signed/unsigned comparison - confirm.
896 template<
class Region>
897 bool operator()(Region& region)
899 return region.size() < min_split_node_size_;
902 template<
class WeightIter,
class T,
class C>
912 #endif //VIGRA_RF_COMMON_HXX