dune-istl
2.2.0
|
00001 #ifndef DUNE_MultiTypeMATRIX_HH 00002 #define DUNE_MultiTypeMATRIX_HH 00003 00004 #include<cmath> 00005 #include<iostream> 00006 00007 #include "istlexception.hh" 00008 00009 #if HAVE_BOOST 00010 #ifdef HAVE_BOOST_FUSION 00011 00012 #include <boost/fusion/sequence.hpp> 00013 #include <boost/fusion/container.hpp> 00014 #include <boost/fusion/iterator.hpp> 00015 #include <boost/typeof/typeof.hpp> 00016 #include <boost/fusion/algorithm.hpp> 00017 00018 namespace mpl=boost::mpl; 00019 namespace fusion=boost::fusion; 00020 00021 // forward decl 00022 namespace Dune 00023 { 00024 template<typename T1, typename T2=fusion::void_, typename T3=fusion::void_, typename T4=fusion::void_, 00025 typename T5=fusion::void_, typename T6=fusion::void_, typename T7=fusion::void_, 00026 typename T8=fusion::void_, typename T9=fusion::void_> 00027 class MultiTypeBlockMatrix; 00028 00029 template<int I, int crow, int remain_row> 00030 class MultiTypeBlockMatrix_Solver; 00031 } 00032 00033 #include "gsetc.hh" 00034 00035 namespace Dune { 00036 00054 template<int crow, int remain_rows, int ccol, int remain_cols, 00055 typename TMatrix> 00056 class MultiTypeBlockMatrix_Print { 00057 public: 00058 00062 static void print(const TMatrix& m) { 00063 std::cout << "\t(" << crow << ", " << ccol << "): \n" << fusion::at_c<ccol>( fusion::at_c<crow>(m)); 00064 MultiTypeBlockMatrix_Print<crow,remain_rows,ccol+1,remain_cols-1,TMatrix>::print(m); //next column 00065 } 00066 }; 00067 template<int crow, int remain_rows, int ccol, typename TMatrix> //specialization for remain_cols=0 00068 class MultiTypeBlockMatrix_Print<crow,remain_rows,ccol,0,TMatrix> { 00069 public: static void print(const TMatrix& m) { 00070 static const int xlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00071 MultiTypeBlockMatrix_Print<crow+1,remain_rows-1,0,xlen,TMatrix>::print(m); //next row 00072 } 00073 }; 00074 00075 template<int crow, int ccol, int remain_cols, typename TMatrix> //recursion end: specialization for remain_rows=0 00076 class MultiTypeBlockMatrix_Print<crow,0,ccol,remain_cols,TMatrix> { 00077 public: 00078 static void print(const TMatrix& m) 00079 {std::cout << std::endl;} 00080 }; 00081 00082 00083 00084 //make MultiTypeBlockVector_Ident known (for MultiTypeBlockMatrix_Ident) 00085 template<int count, typename T1, typename T2> 00086 class MultiTypeBlockVector_Ident; 00087 00088 00101 template<int rowcount, typename T1, typename T2> 00102 class MultiTypeBlockMatrix_Ident { 00103 public: 00104 00109 static void equalize(T1& a, const T2& b) { 00110 MultiTypeBlockVector_Ident< mpl::size< typename mpl::at_c<T1,rowcount-1>::type >::value ,T1,T2>::equalize(a,b); //rows are cvectors 00111 MultiTypeBlockMatrix_Ident<rowcount-1,T1,T2>::equalize(a,b); //iterate over rows 00112 } 00113 }; 00114 00115 //recursion end for rowcount=0 00116 template<typename T1, typename T2> 00117 class MultiTypeBlockMatrix_Ident<0,T1,T2> { 00118 public: 00119 static void equalize (T1& a, const T2& b) 00120 {} 00121 }; 00122 00128 template<int crow, int remain_rows, int ccol, int remain_cols, 00129 typename TVecY, typename TMatrix, typename TVecX> 00130 class MultiTypeBlockMatrix_VectMul { 00131 public: 00132 00136 static void umv(TVecY& y, const TMatrix& A, const TVecX& x) { 00137 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).umv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) ); 00138 MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::umv(y, A, x); 00139 } 00140 00144 static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) { 00145 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) ); 00146 MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::mmv(y, A, x); 00147 } 00148 00149 template<typename AlphaType> 00150 static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) { 00151 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).usmv(alpha, fusion::at_c<ccol>(x), fusion::at_c<crow>(y) ); 00152 MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x); 00153 } 00154 00155 00156 }; 00157 00158 //specialization for remain_cols = 0 00159 template<int crow, int remain_rows,int ccol, typename TVecY, 00160 typename TMatrix, typename TVecX> 00161 class MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol,0,TVecY,TMatrix,TVecX> { //start iteration over next row 00162 00163 public: 00167 static void umv(TVecY& y, const TMatrix& A, const TVecX& x) { 00168 static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00169 MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::umv(y, A, x); 00170 } 00171 00175 static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) { 00176 static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00177 MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::mmv(y, A, x); 00178 } 00179 00180 template <typename AlphaType> 00181 static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) { 00182 static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value; 00183 MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x); 00184 } 00185 }; 00186 00187 //specialization for remain_rows = 0 00188 template<int crow, int ccol, int remain_cols, typename TVecY, 00189 typename TMatrix, typename TVecX> 00190 class MultiTypeBlockMatrix_VectMul<crow,0,ccol,remain_cols,TVecY,TMatrix,TVecX> { 00191 //end recursion 00192 public: 00193 static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {} 00194 static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {} 00195 00196 template<typename AlphaType> 00197 static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {} 00198 }; 00199 00200 00201 00202 00203 00204 00213 template<typename T1, typename T2, typename T3, typename T4, 00214 typename T5, typename T6, typename T7, typename T8, typename T9> 00215 class MultiTypeBlockMatrix : public fusion::vector<T1, T2, T3, T4, T5, T6, T7, T8, T9> { 00216 00217 public: 00218 00222 typedef MultiTypeBlockMatrix<T1, T2, T3, T4, T5, T6, T7, T8, T9> type; 00223 00224 typedef typename mpl::at_c<T1,0>::type field_type; 00225 00229 template<typename T> 00230 void operator= (const T& newval) {MultiTypeBlockMatrix_Ident<mpl::size<type>::value,type,T>::equalize(*this, newval); } 00231 00235 template<typename X, typename Y> 00236 void mv (const X& x, Y& y) const { 00237 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00238 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00239 00240 y = 0; //reset y (for mv uses umv) 00241 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x); //iterate over all matrix elements 00242 } 00243 00247 template<typename X, typename Y> 00248 void umv (const X& x, Y& y) const { 00249 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00250 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00251 00252 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x); //iterate over all matrix elements 00253 } 00254 00258 template<typename X, typename Y> 00259 void mmv (const X& x, Y& y) const { 00260 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00261 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00262 00263 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::mmv(y, *this, x); //iterate over all matrix elements 00264 } 00265 00267 template<typename AlphaType, typename X, typename Y> 00268 void usmv (const AlphaType& alpha, const X& x, Y& y) const { 00269 BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value); //make sure x's length matches row length 00270 BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value); //make sure y's length matches row count 00271 00272 MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::usmv(alpha,y, *this, x); //iterate over all matrix elements 00273 00274 } 00275 00276 00277 00278 }; 00279 00280 00281 00287 template<typename T1, typename T2, typename T3, typename T4, typename T5, 00288 typename T6, typename T7, typename T8, typename T9> 00289 std::ostream& operator<< (std::ostream& s, const MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>& m) { 00290 static const int i = mpl::size<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::value; //row count 00291 static const int j = mpl::size< typename mpl::at_c<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>,0>::type >::value; //col count of first row 00292 MultiTypeBlockMatrix_Print<0,i,0,j,MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::print(m); 00293 return s; 00294 } 00295 00296 00297 00298 00299 00300 //make algmeta_itsteps known 00301 template<int I> 00302 struct algmeta_itsteps; 00303 00304 00305 00306 00307 00308 00315 template<int I, int crow, int ccol, int remain_col> //MultiTypeBlockMatrix_Solver_Col: iterating over one row 00316 class MultiTypeBlockMatrix_Solver_Col { //calculating b- A[i][j]*x[j] 00317 public: 00321 template <typename Trhs, typename TVector, typename TMatrix, typename K> 00322 static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) { 00323 fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), b ); 00324 MultiTypeBlockMatrix_Solver_Col<I, crow, ccol+1, remain_col-1>::calc_rhs(A,x,v,b,w); //next column element 00325 } 00326 00327 }; 00328 template<int I, int crow, int ccol> //MultiTypeBlockMatrix_Solver_Col recursion end 00329 class MultiTypeBlockMatrix_Solver_Col<I,crow,ccol,0> { 00330 public: 00331 template <typename Trhs, typename TVector, typename TMatrix, typename K> 00332 static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {} 00333 }; 00334 00335 00336 00343 template<int I, int crow, int remain_row> 00344 class MultiTypeBlockMatrix_Solver { 00345 public: 00346 00350 template <typename TVector, typename TMatrix, typename K> 00351 static void dbgs(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00352 TVector xold(x); 00353 xold=x; //store old x values 00354 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbgs(A,x,x,b,w); 00355 x *= w; 00356 x.axpy(1-w,xold); //improve x 00357 } 00358 template <typename TVector, typename TMatrix, typename K> 00359 static void dbgs(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00360 typename mpl::at_c<TVector,crow>::type rhs; 00361 rhs = fusion::at_c<crow> (b); 00362 00363 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00364 //solve on blocklevel I-1 00365 algmeta_itsteps<I-1>::dbgs(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(x),rhs,w); 00366 MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbgs(A,x,v,b,w); //next row 00367 } 00368 00369 00370 00374 template <typename TVector, typename TMatrix, typename K> 00375 static void bsorf(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00376 TVector v; 00377 v=x; //use latest x values in right side calculation 00378 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorf(A,x,v,b,w); 00379 00380 } 00381 template <typename TVector, typename TMatrix, typename K> //recursion over all matrix rows (A) 00382 static void bsorf(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00383 typename mpl::at_c<TVector,crow>::type rhs; 00384 rhs = fusion::at_c<crow> (b); 00385 00386 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00387 //solve on blocklevel I-1 00388 algmeta_itsteps<I-1>::bsorf(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w); 00389 fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v)); 00390 MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::bsorf(A,x,v,b,w); //next row 00391 } 00392 00396 template <typename TVector, typename TMatrix, typename K> 00397 static void bsorb(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00398 TVector v; 00399 v=x; //use latest x values in right side calculation 00400 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorb(A,x,v,b,w); 00401 00402 } 00403 template <typename TVector, typename TMatrix, typename K> //recursion over all matrix rows (A) 00404 static void bsorb(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00405 typename mpl::at_c<TVector,crow>::type rhs; 00406 rhs = fusion::at_c<crow> (b); 00407 00408 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00409 //solve on blocklevel I-1 00410 algmeta_itsteps<I-1>::bsorb(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w); 00411 fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v)); 00412 MultiTypeBlockMatrix_Solver<I,crow-1,remain_row-1>::bsorb(A,x,v,b,w); //next row 00413 } 00414 00415 00419 template <typename TVector, typename TMatrix, typename K> 00420 static void dbjac(const TMatrix& A, TVector& x, const TVector& b, const K& w) { 00421 TVector v(x); 00422 v=0; //calc new x in v 00423 MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbjac(A,x,v,b,w); 00424 x.axpy(w,v); //improve x 00425 } 00426 template <typename TVector, typename TMatrix, typename K> 00427 static void dbjac(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) { 00428 typename mpl::at_c<TVector,crow>::type rhs; 00429 rhs = fusion::at_c<crow> (b); 00430 00431 MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w); // calculate right side of equation 00432 //solve on blocklevel I-1 00433 algmeta_itsteps<I-1>::dbjac(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w); 00434 MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbjac(A,x,v,b,w); //next row 00435 } 00436 00437 00438 00439 00440 }; 00441 template<int I, int crow> //recursion end for remain_row = 0 00442 class MultiTypeBlockMatrix_Solver<I,crow,0> { 00443 public: 00444 template <typename TVector, typename TMatrix, typename K> 00445 static void dbgs(const TMatrix& A, TVector& x, TVector& v, 00446 const TVector& b, const K& w) {} 00447 00448 template <typename TVector, typename TMatrix, typename K> 00449 static void bsorf(const TMatrix& A, TVector& x, TVector& v, 00450 const TVector& b, const K& w) {} 00451 00452 template <typename TVector, typename TMatrix, typename K> 00453 static void bsorb(const TMatrix& A, TVector& x, TVector& v, 00454 const TVector& b, const K& w) {} 00455 00456 template <typename TVector, typename TMatrix, typename K> 00457 static void dbjac(const TMatrix& A, TVector& x, TVector& v, 00458 const TVector& b, const K& w) {} 00459 }; 00460 00461 } // end namespace 00462 00463 #endif // HAVE_BOOST_FUSION 00464 #endif // HAVE_BOOST 00465 #endif 00466