[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_online_prediction_set.hxx
#include "../multi_array.hxx"
#include <cfloat>
#include <limits>
#include <set>
#include <vector>
00004 
00005 namespace vigra
00006 {
00007 
/** A range of positions (start, end) together with per-feature lower and
 *  upper boundaries.
 *
 *  NOTE(review): OnlinePredictionSet constructs the initial range as
 *  (0, number of rows), which suggests `end` is exclusive, i.e. [start, end)
 *  -- confirm against the code that consumes these ranges.
 */
template<class T>
struct SampleRange
{
    /** \param start         first position of the range.
     *  \param end           end position of the range.
     *  \param num_features  number of entries in each boundary vector.
     *
     *  All boundaries start out "unbounded":
     *  - for non-integer T: -FLT_MAX / +FLT_MAX, converted to T;
     *  - for integer T: std::numeric_limits<T>::min() / max(), because
     *    converting +/-FLT_MAX to an integer type is undefined behaviour.
     */
    SampleRange(int start, int end, int num_features)
    : start(start),
      end(end)
    {
        min_boundaries.resize(num_features,
                              lowestBoundary(IsIntegerTag<std::numeric_limits<T>::is_integer>()));
        max_boundaries.resize(num_features,
                              highestBoundary(IsIntegerTag<std::numeric_limits<T>::is_integer>()));
    }

    int start;

    // end and the boundary vectors are mutable so they can be updated on
    // elements already stored in a std::set: operator< only looks at
    // `start`, so modifying them cannot break the set's ordering.
    mutable int end;
    mutable std::vector<T> max_boundaries;
    mutable std::vector<T> min_boundaries;

    /** Orders ranges by *descending* start: a std::set<SampleRange<T> >
     *  iterates from the largest start to the smallest. */
    bool operator<(const SampleRange& o) const
    {
        return o.start < start;
    }

private:
    // Compile-time selector between the integer and non-integer limits.
    template<bool IsInteger>
    struct IsIntegerTag {};

    // Integer T: use the type's own extremes (FLT_MAX is not representable).
    static T lowestBoundary(IsIntegerTag<true>)
    {
        return std::numeric_limits<T>::min();
    }

    static T highestBoundary(IsIntegerTag<true>)
    {
        return std::numeric_limits<T>::max();
    }

    // Any other T: keep the historical +/-FLT_MAX sentinels, converted to T.
    static T lowestBoundary(IsIntegerTag<false>)
    {
        return -FLT_MAX;
    }

    static T highestBoundary(IsIntegerTag<false>)
    {
        return FLT_MAX;
    }
};
00027 
00028 template<class T>
00029 class OnlinePredictionSet
00030 {
00031 public:
00032     template<class U>
00033     OnlinePredictionSet(MultiArrayView<2,T,U>& features,int num_sets)
00034     {
00035         this->features=features;
00036         std::vector<int> init(features.shape(0));
00037         for(unsigned int i=0;i<init.size();++i)
00038             init[i]=i;
00039         indices.resize(num_sets,init);
00040         std::set<SampleRange<T> > set_init;
00041         set_init.insert(SampleRange<T>(0,init.size(),features.shape(1)));
00042         ranges.resize(num_sets,set_init);
00043     cumulativePredTime.resize(num_sets,0);
00044     }
00045     int get_worsed_tree()
00046     {
00047         int result=0;
00048     for(unsigned int i=0;i<cumulativePredTime.size();++i)
00049     {
00050         if(cumulativePredTime[i]>cumulativePredTime[result])
00051         {
00052             result=i;
00053         }
00054     }
00055     return result;
00056     }
00057     void reset_tree(int index)
00058     {
00059         index=index % ranges.size();
00060         std::set<SampleRange<T> > set_init;
00061         set_init.insert(SampleRange<T>(0,features.shape(0),features.shape(1)));
00062         ranges[index]=set_init;
00063     cumulativePredTime[index]=0;
00064     }
00065     std::vector<std::set<SampleRange<T> > > ranges;
00066     std::vector<std::vector<int> > indices;
00067     std::vector<int> cumulativePredTime;
00068     MultiArray<2,T> features;
00069 };
00070 
00071 }
00072 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)