SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _PLUGINESTIMATE_H___ 00012 #define _PLUGINESTIMATE_H___ 00013 00014 #include "classifier/Classifier.h" 00015 #include "features/StringFeatures.h" 00016 #include "features/Labels.h" 00017 #include "distributions/LinearHMM.h" 00018 00019 namespace shogun 00020 { 00034 class CPluginEstimate: public CClassifier 00035 { 00036 public: 00041 CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10); 00042 virtual ~CPluginEstimate(); 00043 00052 virtual bool train(CFeatures* data=NULL); 00053 00058 CLabels* classify(); 00059 00065 virtual CLabels* classify(CFeatures* data); 00066 00071 virtual inline void set_features(CStringFeatures<uint16_t>* feat) 00072 { 00073 SG_UNREF(features); 00074 SG_REF(feat); 00075 features=feat; 00076 } 00077 00082 virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; } 00083 00085 float64_t classify_example(int32_t vec_idx); 00086 00093 inline float64_t posterior_log_odds_obsolete( 00094 uint16_t* vector, int32_t len) 00095 { 00096 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len); 00097 } 00098 00105 inline float64_t get_parameterwise_log_odds( 00106 uint16_t obs, int32_t position) 00107 { 00108 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position); 00109 } 00110 00117 inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos) 00118 { 00119 return pos_model->get_log_derivative_obsolete(obs, pos); 00120 } 00121 00128 inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos) 00129 { 00130 return neg_model->get_log_derivative_obsolete(obs, pos); 00131 } 00132 00141 inline bool get_model_params( 00142 float64_t*& pos_params, float64_t*& neg_params, 00143 int32_t &seq_length, int32_t &num_symbols) 00144 { 00145 int32_t num; 00146 00147 if ((!pos_model) || (!neg_model)) 00148 { 00149 SG_ERROR( "no model available\n"); 00150 return false; 00151 } 00152 00153 pos_model->get_log_transition_probs(&pos_params, &num); 00154 neg_model->get_log_transition_probs(&neg_params, &num); 00155 00156 seq_length = pos_model->get_sequence_length(); 00157 num_symbols = pos_model->get_num_symbols(); 00158 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters()); 00159 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols()); 00160 return true; 00161 } 00162 00169 inline void set_model_params( 00170 const float64_t* pos_params, const float64_t* neg_params, 00171 int32_t seq_length, int32_t num_symbols) 00172 { 00173 int32_t num_params; 00174 00175 SG_UNREF(pos_model); 00176 pos_model=new CLinearHMM(seq_length, num_symbols); 00177 SG_REF(pos_model); 00178 00179 00180 SG_UNREF(neg_model); 00181 neg_model=new CLinearHMM(seq_length, num_symbols); 00182 SG_REF(neg_model); 00183 00184 num_params=pos_model->get_num_model_parameters(); 00185 ASSERT(seq_length*num_symbols==num_params); 00186 ASSERT(num_params==neg_model->get_num_model_parameters()); 00187 00188 pos_model->set_log_transition_probs(pos_params, num_params); 00189 neg_model->set_log_transition_probs(neg_params, num_params); 00190 } 00191 00196 inline int32_t get_num_params() 00197 { 00198 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters(); 00199 } 00200 00205 inline bool check_models() 00206 { 00207 return ( (pos_model!=NULL) && (neg_model!=NULL) ); 00208 } 00209 00211 inline virtual const char* get_name() const { return "PluginEstimate"; } 00212 00213 protected: 00215 float64_t m_pos_pseudo; 00217 float64_t m_neg_pseudo; 00218 00220 CLinearHMM* pos_model; 00222 CLinearHMM* neg_model; 00223 00225 CStringFeatures<uint16_t>* features; 00226 }; 00227 } 00228 #endif