Labels.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef _LABELS__H__
00013 #define _LABELS__H__
00014 
00015 #include "lib/common.h"
00016 #include "lib/io.h"
00017 #include "base/SGObject.h"
00018 
00019 namespace shogun
00020 {
00026 class CLabels : public CSGObject
00027 {
00028     public:
00030         CLabels();
00031 
00036         CLabels(int32_t num_labels);
00037 
00043         CLabels(float64_t* src, int32_t len);
00044 
00049         CLabels(char* fname);
00050         virtual ~CLabels();
00051 
00057         bool load(char* fname);
00058 
00064         bool save(char* fname);
00065 
00072         inline bool set_label(int32_t idx, float64_t label)
00073         { 
00074             if (labels && idx<num_labels)
00075             {
00076                 labels[idx]=label;
00077                 return true;
00078             }
00079             else 
00080                 return false;
00081         }
00082 
00089         inline bool set_int_label(int32_t idx, int32_t label)
00090         { 
00091             if (labels && idx<num_labels)
00092             {
00093                 labels[idx]= (float64_t) label;
00094                 return true;
00095             }
00096             else 
00097                 return false;
00098         }
00099 
00105         inline float64_t get_label(int32_t idx)
00106         {
00107             if (labels && idx<num_labels)
00108                 return labels[idx];
00109             else
00110                 return -1;
00111         }
00112 
00118         inline int32_t get_int_label(int32_t idx)
00119         {
00120             if (labels && idx<num_labels)
00121             {
00122                 ASSERT(labels[idx]== ((float64_t) ((int32_t) labels[idx])));
00123                 return ((int32_t) labels[idx]);
00124             }
00125             else
00126                 return -1;
00127         }
00128 
00133         bool is_two_class_labeling();
00134 
00141         int32_t get_num_classes();
00142 
00149         float64_t* get_labels(int32_t &len);
00150         
00156         void get_labels(float64_t** dst, int32_t* len);
00157 
00163         void set_labels(float64_t* src, int32_t len);
00164 
00171         int32_t* get_int_labels(int32_t &len);
00172 
00179         void set_int_labels(int32_t *labels, int32_t len) ;
00180 
00185         inline int32_t get_num_labels() { return num_labels; }
00186 
00188         inline virtual const char* get_name() const { return "Labels"; }
00189 
00190 #ifdef HAVE_BOOST_SERIALIZATION
00191     private:
00192 
00193         // serialization needs to split up in save/load because 
00194         // the serialization of pointers to natives (int* & friends) 
00195         // requires a workaround 
00196         friend class ::boost::serialization::access;
00197         template<class Archive>
00198             void save(Archive & ar, const unsigned int archive_version) const
00199             {
00200 
00201                 SG_DEBUG("archiving Labels\n");
00202 
00203                 ar & ::boost::serialization::base_object<CSGObject>(*this);
00204 
00205                 ar & num_labels;
00206                 for (int32_t i=0; i < num_labels; ++i) 
00207                 {
00208                     ar & labels[i];
00209                 }
00210 
00211                 SG_DEBUG("done with Labels\n");
00212 
00213             }
00214 
00215         template<class Archive>
00216             void load(Archive & ar, const unsigned int archive_version)
00217             {
00218 
00219                 SG_DEBUG("archiving Labels\n");
00220 
00221                 ar & ::boost::serialization::base_object<CSGObject>(*this);
00222 
00223                 ar & num_labels;
00224 
00225                 SG_DEBUG("num_labels: %i\n", num_labels);
00226 
00227                 if (num_labels > 0)
00228                 {
00229 
00230                     labels = new float64_t[num_labels];
00231                     for (int32_t i=0; i< num_labels; ++i) 
00232                     {
00233                         ar & labels[i];
00234                     }
00235 
00236                 }
00237 
00238                 SG_DEBUG("done with Labels\n");
00239 
00240             }
00241 
00242         GLOBAL_BOOST_SERIALIZATION_SPLIT_MEMBER();
00243 
00244 
00245     public:
00246 
00247         virtual std::string toString() const
00248         {
00249             std::ostringstream s;
00250 
00251             ::boost::archive::text_oarchive oa(s);
00252 
00253             oa << *this;
00254 
00255             return s.str();
00256         }
00257 
00258 
00259         virtual void fromString(std::string str)
00260         {
00261 
00262             std::istringstream is(str);
00263 
00264             ::boost::archive::text_iarchive ia(is);
00265 
00266             ia >> *this;
00267 
00268         }
00269 #endif //HAVE_BOOST_SERIALIZATION
00270 
00271     protected:
00273         int32_t num_labels;
00275         float64_t* labels;
00276 };
00277 }
00278 #endif

SHOGUN Machine Learning Toolbox - Documentation