Labels.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _LABELS__H__
00013 #define _LABELS__H__
00014
00015 #include "lib/common.h"
00016 #include "lib/io.h"
00017 #include "base/SGObject.h"
00018
00019 namespace shogun
00020 {
00026 class CLabels : public CSGObject
00027 {
00028 public:
00030 CLabels();
00031
00036 CLabels(int32_t num_labels);
00037
00043 CLabels(float64_t* src, int32_t len);
00044
00049 CLabels(char* fname);
00050 virtual ~CLabels();
00051
00057 bool load(char* fname);
00058
00064 bool save(char* fname);
00065
00072 inline bool set_label(int32_t idx, float64_t label)
00073 {
00074 if (labels && idx<num_labels)
00075 {
00076 labels[idx]=label;
00077 return true;
00078 }
00079 else
00080 return false;
00081 }
00082
00089 inline bool set_int_label(int32_t idx, int32_t label)
00090 {
00091 if (labels && idx<num_labels)
00092 {
00093 labels[idx]= (float64_t) label;
00094 return true;
00095 }
00096 else
00097 return false;
00098 }
00099
00105 inline float64_t get_label(int32_t idx)
00106 {
00107 if (labels && idx<num_labels)
00108 return labels[idx];
00109 else
00110 return -1;
00111 }
00112
00118 inline int32_t get_int_label(int32_t idx)
00119 {
00120 if (labels && idx<num_labels)
00121 {
00122 ASSERT(labels[idx]== ((float64_t) ((int32_t) labels[idx])));
00123 return ((int32_t) labels[idx]);
00124 }
00125 else
00126 return -1;
00127 }
00128
00133 bool is_two_class_labeling();
00134
00141 int32_t get_num_classes();
00142
00149 float64_t* get_labels(int32_t &len);
00150
00156 void get_labels(float64_t** dst, int32_t* len);
00157
00163 void set_labels(float64_t* src, int32_t len);
00164
00171 int32_t* get_int_labels(int32_t &len);
00172
00179 void set_int_labels(int32_t *labels, int32_t len) ;
00180
00185 inline int32_t get_num_labels() { return num_labels; }
00186
00188 inline virtual const char* get_name() const { return "Labels"; }
00189
00190 #ifdef HAVE_BOOST_SERIALIZATION
00191 private:
00192
00193
00194
00195
00196 friend class ::boost::serialization::access;
00197 template<class Archive>
00198 void save(Archive & ar, const unsigned int archive_version) const
00199 {
00200
00201 SG_DEBUG("archiving Labels\n");
00202
00203 ar & ::boost::serialization::base_object<CSGObject>(*this);
00204
00205 ar & num_labels;
00206 for (int32_t i=0; i < num_labels; ++i)
00207 {
00208 ar & labels[i];
00209 }
00210
00211 SG_DEBUG("done with Labels\n");
00212
00213 }
00214
00215 template<class Archive>
00216 void load(Archive & ar, const unsigned int archive_version)
00217 {
00218
00219 SG_DEBUG("archiving Labels\n");
00220
00221 ar & ::boost::serialization::base_object<CSGObject>(*this);
00222
00223 ar & num_labels;
00224
00225 SG_DEBUG("num_labels: %i\n", num_labels);
00226
00227 if (num_labels > 0)
00228 {
00229
00230 labels = new float64_t[num_labels];
00231 for (int32_t i=0; i< num_labels; ++i)
00232 {
00233 ar & labels[i];
00234 }
00235
00236 }
00237
00238 SG_DEBUG("done with Labels\n");
00239
00240 }
00241
00242 GLOBAL_BOOST_SERIALIZATION_SPLIT_MEMBER();
00243
00244
00245 public:
00246
00247 virtual std::string toString() const
00248 {
00249 std::ostringstream s;
00250
00251 ::boost::archive::text_oarchive oa(s);
00252
00253 oa << *this;
00254
00255 return s.str();
00256 }
00257
00258
00259 virtual void fromString(std::string str)
00260 {
00261
00262 std::istringstream is(str);
00263
00264 ::boost::archive::text_iarchive ia(is);
00265
00266 ia >> *this;
00267
00268 }
00269 #endif //HAVE_BOOST_SERIALIZATION
00270
00271 protected:
00273 int32_t num_labels;
00275 float64_t* labels;
00276 };
00277 }
00278 #endif