SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00013 #define _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00014 00015 #include "lib/common.h" 00016 #include "kernel/StringKernel.h" 00017 #include "kernel/WeightedDegreeStringKernel.h" 00018 #include "lib/Trie.h" 00019 00020 namespace shogun 00021 { 00022 00023 class CSVM; 00024 00048 class CWeightedDegreePositionStringKernel: public CStringKernel<char> 00049 { 00050 public: 00052 CWeightedDegreePositionStringKernel(void); 00053 00061 CWeightedDegreePositionStringKernel( 00062 int32_t size, int32_t degree, 00063 int32_t max_mismatch=0, int32_t mkl_stepsize=1); 00064 00075 CWeightedDegreePositionStringKernel( 00076 int32_t size, float64_t* weights, int32_t degree, 00077 int32_t max_mismatch, int32_t* shift, int32_t shift_len, 00078 int32_t mkl_stepsize=1); 00079 00086 CWeightedDegreePositionStringKernel( 00087 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t degree); 00088 00089 virtual ~CWeightedDegreePositionStringKernel(); 00090 00097 virtual bool init(CFeatures* l, CFeatures* r); 00098 00100 virtual void cleanup(); 00101 00106 virtual EKernelType get_kernel_type() { return K_WEIGHTEDDEGREEPOS; } 00107 00112 virtual const char* get_name() const { return "WeightedDegreePos"; } 00113 00121 inline virtual bool init_optimization( 00122 int32_t p_count, int32_t *IDX, float64_t * alphas) 00123 { 00124 return init_optimization(p_count, IDX, alphas, -1); 00125 } 00126 00138 virtual bool init_optimization( 00139 int32_t count, int32_t *IDX, float64_t * alphas, int32_t tree_num, 00140 int32_t upto_tree=-1); 00141 00146 virtual bool delete_optimization(); 00147 00153 inline virtual float64_t compute_optimized(int32_t idx) 00154 { 00155 ASSERT(get_is_initialized()); 00156 ASSERT(alphabet); 00157 ASSERT(alphabet->get_alphabet()==DNA || alphabet->get_alphabet()==RNA); 00158 return compute_by_tree(idx); 00159 } 00160 00165 static void* compute_batch_helper(void* p); 00166 00177 virtual void compute_batch( 00178 int32_t num_vec, int32_t* vec_idx, float64_t* target, 00179 int32_t num_suppvec, int32_t* IDX, float64_t* alphas, 00180 float64_t factor=1.0); 00181 00185 inline virtual void clear_normal() 00186 { 00187 if ((opt_type==FASTBUTMEMHUNGRY) && (tries.get_use_compact_terminal_nodes())) 00188 { 00189 tries.set_use_compact_terminal_nodes(false) ; 00190 SG_DEBUG( "disabling compact trie nodes with FASTBUTMEMHUNGRY\n") ; 00191 } 00192 00193 if (get_is_initialized()) 00194 { 00195 if (opt_type==SLOWBUTMEMEFFICIENT) 00196 tries.delete_trees(true); 00197 else if (opt_type==FASTBUTMEMHUNGRY) 00198 tries.delete_trees(false); // still buggy 00199 else 00200 SG_ERROR( "unknown optimization type\n"); 00201 00202 set_is_initialized(false); 00203 } 00204 } 00205 00211 inline virtual void add_to_normal(int32_t idx, float64_t weight) 00212 { 00213 add_example_to_tree(idx, weight); 00214 set_is_initialized(true); 00215 } 00216 00221 inline virtual int32_t get_num_subkernels() 00222 { 00223 if (position_weights!=NULL) 00224 return (int32_t) ceil(1.0*seq_length/mkl_stepsize) ; 00225 if (length==0) 00226 return (int32_t) ceil(1.0*get_degree()/mkl_stepsize); 00227 return (int32_t) ceil(1.0*get_degree()*length/mkl_stepsize) ; 00228 } 00229 00235 inline void compute_by_subkernel( 00236 int32_t idx, float64_t * subkernel_contrib) 00237 { 00238 if (get_is_initialized()) 00239 { 00240 compute_by_tree(idx, subkernel_contrib); 00241 return ; 00242 } 00243 00244 SG_ERROR( "CWeightedDegreePositionStringKernel optimization not initialized\n") ; 00245 } 00246 00252 inline const float64_t* get_subkernel_weights(int32_t& num_weights) 00253 { 00254 num_weights = get_num_subkernels() ; 00255 00256 delete[] weights_buffer ; 00257 weights_buffer = new float64_t[num_weights] ; 00258 00259 if (position_weights!=NULL) 00260 for (int32_t i=0; i<num_weights; i++) 00261 weights_buffer[i] = position_weights[i*mkl_stepsize] ; 00262 else 00263 for (int32_t i=0; i<num_weights; i++) 00264 weights_buffer[i] = weights[i*mkl_stepsize] ; 00265 00266 return weights_buffer ; 00267 } 00268 00274 inline void set_subkernel_weights( 00275 float64_t* weights2, int32_t num_weights2) 00276 { 00277 int32_t num_weights = get_num_subkernels() ; 00278 if (num_weights!=num_weights2) 00279 SG_ERROR( "number of weights do not match\n") ; 00280 00281 if (position_weights!=NULL) 00282 for (int32_t i=0; i<num_weights; i++) 00283 for (int32_t j=0; j<mkl_stepsize; j++) 00284 { 00285 if (i*mkl_stepsize+j<seq_length) 00286 position_weights[i*mkl_stepsize+j] = weights2[i] ; 00287 } 00288 else if (length==0) 00289 { 00290 for (int32_t i=0; i<num_weights; i++) 00291 for (int32_t j=0; j<mkl_stepsize; j++) 00292 if (i*mkl_stepsize+j<get_degree()) 00293 weights[i*mkl_stepsize+j] = weights2[i] ; 00294 } 00295 else 00296 { 00297 for (int32_t i=0; i<num_weights; i++) 00298 for (int32_t j=0; j<mkl_stepsize; j++) 00299 if (i*mkl_stepsize+j<get_degree()*length) 00300 weights[i*mkl_stepsize+j] = weights2[i] ; 00301 } 00302 } 00303 00304 // other kernel tree operations 00310 float64_t* compute_abs_weights(int32_t & len); 00311 00316 bool is_tree_initialized() { return tree_initialized; } 00317 00322 inline int32_t get_max_mismatch() { return max_mismatch; } 00323 00328 inline int32_t get_degree() { return degree; } 00329 00335 inline float64_t *get_degree_weights(int32_t& d, int32_t& len) 00336 { 00337 d=degree; 00338 len=length; 00339 return weights; 00340 } 00341 00347 inline float64_t *get_weights(int32_t& num_weights) 00348 { 00349 if (position_weights!=NULL) 00350 { 00351 num_weights = seq_length ; 00352 return position_weights ; 00353 } 00354 if (length==0) 00355 num_weights = degree ; 00356 else 00357 num_weights = degree*length ; 00358 return weights; 00359 } 00360 00366 inline float64_t *get_position_weights(int32_t& len) 00367 { 00368 len=seq_length; 00369 return position_weights; 00370 } 00371 00377 bool set_shifts(int32_t* shifts, int32_t len); 00378 00385 virtual bool set_weights(float64_t* weights, int32_t d, int32_t len=0); 00386 00391 virtual bool set_wd_weights(); 00392 00399 virtual bool set_position_weights(float64_t* pws, int32_t len); 00400 00408 bool set_position_weights_lhs(float64_t* pws, int32_t len, int32_t num); 00409 00417 bool set_position_weights_rhs(float64_t* pws, int32_t len, int32_t num); 00418 00423 bool init_block_weights(); 00424 00429 bool init_block_weights_from_wd(); 00430 00435 bool init_block_weights_from_wd_external(); 00436 00441 bool init_block_weights_const(); 00442 00447 bool init_block_weights_linear(); 00448 00453 bool init_block_weights_sqpoly(); 00454 00459 bool init_block_weights_cubicpoly(); 00460 00465 bool init_block_weights_exp(); 00466 00471 bool init_block_weights_log(); 00472 00477 bool delete_position_weights() 00478 { 00479 delete[] position_weights; 00480 position_weights=NULL; 00481 return true; 00482 } 00483 00488 bool delete_position_weights_lhs() 00489 { 00490 delete[] position_weights_lhs; 00491 position_weights_lhs=NULL; 00492 return true; 00493 } 00494 00499 bool delete_position_weights_rhs() 00500 { 00501 delete[] position_weights_rhs; 00502 position_weights_rhs=NULL; 00503 return true; 00504 } 00505 00511 virtual float64_t compute_by_tree(int32_t idx); 00512 00518 virtual void compute_by_tree(int32_t idx, float64_t* LevelContrib); 00519 00532 float64_t* compute_scoring( 00533 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00534 float64_t* target, int32_t num_suppvec, int32_t* IDX, 00535 float64_t* weights); 00536 00545 char* compute_consensus( 00546 int32_t &num_feat, int32_t num_suppvec, int32_t* IDX, 00547 float64_t* alphas); 00548 00560 float64_t* extract_w( 00561 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00562 float64_t* w_result, int32_t num_suppvec, int32_t* IDX, 00563 float64_t* alphas); 00564 00577 float64_t* compute_POIM( 00578 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00579 float64_t* poim_result, int32_t num_suppvec, int32_t* IDX, 00580 float64_t* alphas, float64_t* distrib); 00581 00588 void prepare_POIM2( 00589 float64_t* distrib, int32_t num_sym, int32_t num_feat); 00590 00597 void compute_POIM2(int32_t max_degree, CSVM* svm); 00598 00604 void get_POIM2(float64_t** poim, int32_t* result_len); 00605 00607 void cleanup_POIM2(); 00608 00609 protected: 00611 void create_empty_tries(); 00612 00618 virtual void add_example_to_tree( 00619 int32_t idx, float64_t weight); 00620 00627 void add_example_to_single_tree( 00628 int32_t idx, float64_t weight, int32_t tree_num); 00629 00638 virtual float64_t compute(int32_t idx_a, int32_t idx_b); 00639 00648 float64_t compute_with_mismatch( 00649 char* avec, int32_t alen, char* bvec, int32_t blen); 00650 00659 float64_t compute_without_mismatch( 00660 char* avec, int32_t alen, char* bvec, int32_t blen); 00661 00670 float64_t compute_without_mismatch_matrix( 00671 char* avec, int32_t alen, char* bvec, int32_t blen); 00672 00683 float64_t compute_without_mismatch_position_weights( 00684 char* avec, float64_t *posweights_lhs, int32_t alen, 00685 char* bvec, float64_t *posweights_rhs, int32_t blen); 00686 00688 virtual void remove_lhs(); 00689 00698 virtual void load_serializable_post(void) throw (ShogunException); 00699 00700 private: 00703 void init(); 00704 00705 protected: 00707 float64_t* weights; 00709 int32_t weights_degree; 00711 int32_t weights_length; 00712 00714 float64_t* position_weights; 00716 int32_t position_weights_len; 00717 00719 float64_t* position_weights_lhs; 00721 int32_t position_weights_lhs_len; 00723 float64_t* position_weights_rhs; 00725 int32_t position_weights_rhs_len; 00727 bool* position_mask; 00728 00730 float64_t* weights_buffer; 00732 int32_t mkl_stepsize; 00733 00735 int32_t degree; 00737 int32_t length; 00738 00740 int32_t max_mismatch; 00742 int32_t seq_length; 00743 00745 int32_t *shift; 00747 int32_t shift_len; 00749 int32_t max_shift; 00750 00752 bool block_computation; 00753 00755 float64_t* block_weights; 00757 EWDKernType type; 00759 int32_t which_degree; 00760 00762 CTrie<DNATrie> tries; 00764 CTrie<POIMTrie> poim_tries; 00765 00767 bool tree_initialized; 00769 bool use_poim_tries; 00770 00772 float64_t* m_poim_distrib; 00774 float64_t* m_poim; 00775 00777 int32_t m_poim_num_sym; 00779 int32_t m_poim_num_feat; 00781 int32_t m_poim_result_len; 00782 00784 CAlphabet* alphabet; 00785 }; 00786 } 00787 #endif /* _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H__ */