00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014 #include "kernel/OligoStringKernel.h"
00015 #include "kernel/SqrtDiagKernelNormalizer.h"
00016 #include "features/StringFeatures.h"
00017
00018 #include <map>
00019 #include <vector>
00020 #include <algorithm>
00021
00022 using namespace shogun;
00023
00024 COligoStringKernel::COligoStringKernel(int32_t cache_sz, int32_t kmer_len, float64_t w)
00025 : CStringKernel<char>(cache_sz), k(kmer_len), width(w), gauss_table(NULL)
00026 {
00027 set_normalizer(new CSqrtDiagKernelNormalizer());
00028 }
00029
00030 COligoStringKernel::~COligoStringKernel()
00031 {
00032 cleanup();
00033 }
00034
00035 void COligoStringKernel::cleanup()
00036 {
00037 delete[] gauss_table;
00038 gauss_table=NULL;
00039
00040 CKernel::cleanup();
00041 }
00042
00043 bool COligoStringKernel::init(CFeatures* l, CFeatures* r)
00044 {
00045 cleanup();
00046
00047 CStringKernel<char>::init(l,r);
00048 int32_t max_len=CMath::max(
00049 ((CStringFeatures<char>*) l)->get_max_vector_length(),
00050 ((CStringFeatures<char>*) r)->get_max_vector_length()
00051 );
00052
00053 getExpFunctionCache(max_len);
00054 return init_normalizer();
00055 }
00056
00057 void COligoStringKernel::encodeOligo(
00058 const std::string& sequence, uint32_t k_mer_length,
00059 const std::string& allowed_characters,
00060 std::vector< std::pair<int32_t, float64_t> >& values)
00061 {
00062 float64_t oligo_value = 0.;
00063 float64_t factor = 1.;
00064 std::map<std::string::value_type, uint32_t> residue_values;
00065 uint32_t counter = 0;
00066 uint32_t number_of_residues = allowed_characters.size();
00067 uint32_t sequence_length = sequence.size();
00068 bool sequence_ok = true;
00069
00070
00071 for (uint32_t i = 0; i < sequence.size(); ++i)
00072 {
00073 if (allowed_characters.find(sequence.at(i)) == std::string::npos)
00074 sequence_ok = false;
00075 }
00076
00077 if (sequence_ok && k_mer_length <= sequence_length)
00078 {
00079 values.resize(sequence_length - k_mer_length + 1,
00080 std::pair<int32_t, float64_t>());
00081 for (uint32_t i = 0; i < number_of_residues; ++i)
00082 {
00083 residue_values.insert(std::make_pair(allowed_characters[i], counter));
00084 ++counter;
00085 }
00086 for (int32_t k = k_mer_length - 1; k >= 0; k--)
00087 {
00088 oligo_value += factor * residue_values[sequence[k]];
00089 factor *= number_of_residues;
00090 }
00091 factor /= number_of_residues;
00092 counter = 0;
00093 values[counter].first = 1;
00094 values[counter].second = oligo_value;
00095 ++counter;
00096
00097 for (uint32_t j = 1; j < sequence_length - k_mer_length + 1; j++)
00098 {
00099 oligo_value -= factor * residue_values[sequence[j - 1]];
00100 oligo_value = oligo_value * number_of_residues +
00101 residue_values[sequence[j + k_mer_length - 1]];
00102
00103 values[counter].first = j + 1;
00104 values[counter].second = oligo_value ;
00105 ++counter;
00106 }
00107 stable_sort(values.begin(), values.end(), cmpOligos_);
00108 }
00109 else
00110 {
00111 values.clear();
00112 }
00113 }
00114
00115 void COligoStringKernel::getSequences(
00116 const std::vector<std::string>& sequences, uint32_t k_mer_length,
00117 const std::string& allowed_characters,
00118 std::vector< std::vector< std::pair<int32_t, float64_t> > >& encoded_sequences)
00119 {
00120 std::vector< std::pair<int32_t, float64_t> > temp_vector;
00121 encoded_sequences.resize(sequences.size(),
00122 std::vector< std::pair<int32_t, float64_t> >());
00123
00124 for (uint32_t i = 0; i < sequences.size(); ++i)
00125 {
00126 encodeOligo(sequences[i], k_mer_length, allowed_characters, temp_vector);
00127 encoded_sequences[i] = temp_vector;
00128 }
00129 }
00130
00131 void COligoStringKernel::getExpFunctionCache(uint32_t sequence_length)
00132 {
00133 delete[] gauss_table;
00134 gauss_table=new float64_t[sequence_length];
00135
00136 gauss_table[0] = 1;
00137 for (uint32_t i = 1; i < sequence_length - 1; i++)
00138 gauss_table[i] = exp((-1 / (CMath::sq(width))) * CMath::sq(i));
00139 }
00140
00141 float64_t COligoStringKernel::kernelOligoFast(
00142 const std::vector< std::pair<int32_t, float64_t> >& x,
00143 const std::vector< std::pair<int32_t, float64_t> >& y,
00144 int32_t max_distance)
00145 {
00146 float64_t result = 0;
00147 int32_t i1 = 0;
00148 int32_t i2 = 0;
00149 int32_t c1 = 0;
00150 uint32_t x_size = x.size();
00151 uint32_t y_size = y.size();
00152
00153 while ((uint32_t) i1 < x_size && (uint32_t) i2 < y_size)
00154 {
00155 if (x[i1].second == y[i2].second)
00156 {
00157 if (max_distance < 0
00158 || (abs(x[i1].first - y[i2].first)) <= max_distance)
00159 {
00160 result += gauss_table[abs((x[i1].first - y[i2].first))];
00161 if (x[i1].second == x[i1 + 1].second)
00162 {
00163 i1++;
00164 c1++;
00165 }
00166 else if (y[i2].second == y[i2 + 1].second)
00167 {
00168 i2++;
00169 i1 -= c1;
00170 c1 = 0;
00171 }
00172 else
00173 {
00174 i1++;
00175 i2++;
00176 }
00177 }
00178 else
00179 {
00180 if (x[i1].first < y[i2].first)
00181 {
00182 if (x[i1].second == x[i1 + 1].second)
00183 {
00184 i1++;
00185 }
00186 else if (y[i2].second == y[i2 + 1].second)
00187 {
00188 while(y[i2++].second == y[i2].second)
00189 {
00190 ;
00191 }
00192 ++i1;
00193 c1 = 0;
00194 }
00195 else
00196 {
00197 i1++;
00198 i2++;
00199 c1 = 0;
00200 }
00201 }
00202 else
00203 {
00204 i2++;
00205 i1 -= c1;
00206 c1 = 0;
00207 }
00208 }
00209 }
00210 else
00211 {
00212 if (x[i1].second < y[i2].second)
00213 i1++;
00214 else
00215 i2++;
00216 c1 = 0;
00217 }
00218 }
00219 return result;
00220 }
00221
00222
00223 float64_t COligoStringKernel::compute(int32_t idx_a, int32_t idx_b)
00224 {
00225 int32_t alen, blen;
00226 bool free_a, free_b;
00227 char* avec=((CStringFeatures<char>*) lhs)->get_feature_vector(idx_a, alen, free_a);
00228 char* bvec=((CStringFeatures<char>*) rhs)->get_feature_vector(idx_b, blen, free_b);
00229 std::vector< std::pair<int32_t, float64_t> > aenc;
00230 std::vector< std::pair<int32_t, float64_t> > benc;
00231 encodeOligo(std::string(avec, alen), k, "ACGT", aenc);
00232 encodeOligo(std::string(bvec, alen), k, "ACGT", benc);
00233 float64_t result=kernelOligoFast(aenc, benc);
00234 ((CStringFeatures<char>*) lhs)->free_feature_vector(avec, idx_a, free_a);
00235 ((CStringFeatures<char>*) rhs)->free_feature_vector(bvec, idx_b, free_b);
00236 return result;
00237 }