SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _LINEARHMM_H__ 00013 #define _LINEARHMM_H__ 00014 00015 #include "features/StringFeatures.h" 00016 #include "features/Labels.h" 00017 #include "distributions/Distribution.h" 00018 00019 namespace shogun 00020 { 00039 class CLinearHMM : public CDistribution 00040 { 00041 public: 00043 CLinearHMM(); 00044 00049 CLinearHMM(CStringFeatures<uint16_t>* f); 00050 00056 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols); 00057 00058 virtual ~CLinearHMM(); 00059 00068 virtual bool train(CFeatures* data=NULL); 00069 00077 bool train( 00078 const int32_t* indizes, int32_t num_indizes, 00079 float64_t pseudo_count); 00080 00087 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len); 00088 00095 float64_t get_likelihood_example(uint16_t* vector, int32_t len); 00096 00102 virtual float64_t get_log_likelihood_example(int32_t num_example); 00103 00110 virtual float64_t get_log_derivative( 00111 int32_t num_param, int32_t num_example); 00112 00119 virtual inline float64_t get_log_derivative_obsolete( 00120 uint16_t obs, int32_t pos) 00121 { 00122 return 1.0/transition_probs[pos*num_symbols+obs]; 00123 } 00124 00131 virtual inline float64_t get_derivative_obsolete( 00132 uint16_t* vector, int32_t len, int32_t pos) 00133 { 00134 ASSERT(pos<len); 00135 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]]; 00136 } 00137 00142 virtual inline int32_t get_sequence_length() { return sequence_length; } 00143 00148 virtual inline int32_t get_num_symbols() { return num_symbols; } 00149 00154 virtual inline int32_t get_num_model_parameters() { return num_params; } 00155 00162 virtual inline float64_t get_positional_log_parameter( 00163 uint16_t obs, int32_t position) 00164 { 00165 return log_transition_probs[position*num_symbols+obs]; 00166 } 00167 00173 virtual inline float64_t get_log_model_parameter(int32_t num_param) 00174 { 00175 ASSERT(log_transition_probs); 00176 ASSERT(num_param<num_params); 00177 00178 return log_transition_probs[num_param]; 00179 } 00180 00188 virtual void get_log_transition_probs(float64_t** dst, int32_t* num); 00189 00196 virtual bool set_log_transition_probs( 00197 const float64_t* src, int32_t num); 00198 00204 virtual void get_transition_probs(float64_t** dst, int32_t* num); 00205 00212 virtual bool set_transition_probs(const float64_t* src, int32_t num); 00213 00215 inline virtual const char* get_name() const { return "LinearHMM"; } 00216 00217 protected: 00218 virtual void load_serializable_post() throw (ShogunException); 00219 00220 private: 00221 void init(); 00222 00223 protected: 00225 int32_t sequence_length; 00227 int32_t num_symbols; 00229 int32_t num_params; 00231 float64_t* transition_probs; 00233 float64_t* log_transition_probs; 00234 }; 00235 } 00236 #endif