dune-istl  2.2.0
multitypeblockmatrix.hh
Go to the documentation of this file.
00001 #ifndef DUNE_MultiTypeMATRIX_HH
00002 #define DUNE_MultiTypeMATRIX_HH
00003 
00004 #include<cmath>
00005 #include<iostream>
00006 
00007 #include "istlexception.hh"
00008 
00009 #if HAVE_BOOST
00010 #ifdef HAVE_BOOST_FUSION
00011 
00012 #include <boost/fusion/sequence.hpp>
00013 #include <boost/fusion/container.hpp>
00014 #include <boost/fusion/iterator.hpp>
00015 #include <boost/typeof/typeof.hpp>
00016 #include <boost/fusion/algorithm.hpp>
00017 
00018 namespace mpl=boost::mpl;
00019 namespace fusion=boost::fusion;
00020 
00021 // forward decl
00022 namespace Dune
00023 {
00024     template<typename T1, typename T2=fusion::void_, typename T3=fusion::void_, typename T4=fusion::void_,
00025              typename T5=fusion::void_, typename T6=fusion::void_, typename T7=fusion::void_,
00026              typename T8=fusion::void_, typename T9=fusion::void_>
00027     class MultiTypeBlockMatrix;
00028 
00029     template<int I, int crow, int remain_row>
00030     class MultiTypeBlockMatrix_Solver;
00031 }
00032 
00033 #include "gsetc.hh"
00034 
00035 namespace Dune {
00036 
00054   template<int crow, int remain_rows, int ccol, int remain_cols, 
00055            typename TMatrix>
00056   class MultiTypeBlockMatrix_Print {
00057   public:
00058 
00062     static void print(const TMatrix& m) {
00063       std::cout << "\t(" << crow << ", " << ccol << "): \n" << fusion::at_c<ccol>( fusion::at_c<crow>(m));
00064       MultiTypeBlockMatrix_Print<crow,remain_rows,ccol+1,remain_cols-1,TMatrix>::print(m);         //next column
00065     }
00066   };
00067   template<int crow, int remain_rows, int ccol, typename TMatrix> //specialization for remain_cols=0
00068   class MultiTypeBlockMatrix_Print<crow,remain_rows,ccol,0,TMatrix> {
00069   public: static void print(const TMatrix& m) {
00070     static const int xlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00071     MultiTypeBlockMatrix_Print<crow+1,remain_rows-1,0,xlen,TMatrix>::print(m);                   //next row
00072   }
00073   };
00074 
00075   template<int crow, int ccol, int remain_cols, typename TMatrix> //recursion end: specialization for remain_rows=0
00076   class MultiTypeBlockMatrix_Print<crow,0,ccol,remain_cols,TMatrix> {
00077   public: 
00078     static void print(const TMatrix& m) 
00079     {std::cout << std::endl;} 
00080   };
00081 
00082 
00083 
00084   //make MultiTypeBlockVector_Ident known (for MultiTypeBlockMatrix_Ident)
00085   template<int count, typename T1, typename T2>
00086   class MultiTypeBlockVector_Ident;
00087 
00088 
00101   template<int rowcount, typename T1, typename T2>
00102   class MultiTypeBlockMatrix_Ident {
00103   public:
00104 
00109     static void equalize(T1& a, const T2& b) {
00110       MultiTypeBlockVector_Ident< mpl::size< typename mpl::at_c<T1,rowcount-1>::type >::value ,T1,T2>::equalize(a,b);              //rows are cvectors
00111       MultiTypeBlockMatrix_Ident<rowcount-1,T1,T2>::equalize(a,b);         //iterate over rows
00112     }
00113   };
00114 
00115   //recursion end for rowcount=0
00116   template<typename T1, typename T2>
00117   class MultiTypeBlockMatrix_Ident<0,T1,T2> {
00118   public: 
00119     static void equalize (T1& a, const T2& b) 
00120     {} 
00121   };
00122 
00128   template<int crow, int remain_rows, int ccol, int remain_cols, 
00129            typename TVecY, typename TMatrix, typename TVecX>
00130   class MultiTypeBlockMatrix_VectMul {
00131   public:
00132 
00136     static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {                   
00137       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).umv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
00138       MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::umv(y, A, x);
00139     }
00140 
00144     static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {                   
00145       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
00146       MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::mmv(y, A, x);
00147     }
00148 
00149     template<typename AlphaType>
00150     static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {                  
00151       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).usmv(alpha, fusion::at_c<ccol>(x), fusion::at_c<crow>(y) );
00152       MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol+1,remain_cols-1,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
00153     }
00154 
00155                                 
00156   };
00157 
00158   //specialization for remain_cols = 0
00159   template<int crow, int remain_rows,int ccol, typename TVecY, 
00160            typename TMatrix, typename TVecX>
00161   class MultiTypeBlockMatrix_VectMul<crow,remain_rows,ccol,0,TVecY,TMatrix,TVecX> {                                    //start iteration over next row
00162         
00163   public:
00167     static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {
00168       static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00169       MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::umv(y, A, x);
00170     }
00171 
00175     static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {
00176       static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00177       MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::mmv(y, A, x);
00178     }
00179 
00180     template <typename AlphaType>
00181     static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {
00182       static const int rowlen = mpl::size< typename mpl::at_c<TMatrix,crow>::type >::value;
00183       MultiTypeBlockMatrix_VectMul<crow+1,remain_rows-1,0,rowlen,TVecY,TMatrix,TVecX>::usmv(alpha,y, A, x);
00184     }
00185   };
00186 
00187    //specialization for remain_rows = 0
00188   template<int crow, int ccol, int remain_cols, typename TVecY, 
00189            typename TMatrix, typename TVecX>
00190   class MultiTypeBlockMatrix_VectMul<crow,0,ccol,remain_cols,TVecY,TMatrix,TVecX> { 
00191     //end recursion
00192   public:
00193     static void umv(TVecY& y, const TMatrix& A, const TVecX& x) {}
00194     static void mmv(TVecY& y, const TMatrix& A, const TVecX& x) {}
00195 
00196     template<typename AlphaType>
00197     static void usmv(const AlphaType& alpha, TVecY& y, const TMatrix& A, const TVecX& x) {}
00198   };
00199 
00200 
00201 
00202 
00203 
00204 
00213   template<typename T1, typename T2, typename T3, typename T4,
00214            typename T5, typename T6, typename T7, typename T8, typename T9>
00215   class MultiTypeBlockMatrix : public fusion::vector<T1, T2, T3, T4, T5, T6, T7, T8, T9> {
00216 
00217   public:
00218 
00222     typedef MultiTypeBlockMatrix<T1, T2, T3, T4, T5, T6, T7, T8, T9> type;
00223 
00224     typedef typename mpl::at_c<T1,0>::type field_type;
00225 
00229     template<typename T>
00230     void operator= (const T& newval) {MultiTypeBlockMatrix_Ident<mpl::size<type>::value,type,T>::equalize(*this, newval); }
00231 
00235     template<typename X, typename Y>
00236     void mv (const X& x, Y& y) const {
00237       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00238       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00239 
00240       y = 0;                                                                  //reset y (for mv uses umv)
00241       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x);    //iterate over all matrix elements
00242     }
00243 
00247     template<typename X, typename Y>
00248     void umv (const X& x, Y& y) const {
00249       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00250       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00251 
00252       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::umv(y, *this, x);    //iterate over all matrix elements
00253     }
00254 
00258     template<typename X, typename Y>
00259     void mmv (const X& x, Y& y) const {
00260       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00261       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00262 
00263       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::mmv(y, *this, x);    //iterate over all matrix elements
00264     }
00265 
00267     template<typename AlphaType, typename X, typename Y>
00268     void usmv (const AlphaType& alpha, const X& x, Y& y) const {
00269       BOOST_STATIC_ASSERT(mpl::size<X>::value == mpl::size<T1>::value);       //make sure x's length matches row length
00270       BOOST_STATIC_ASSERT(mpl::size<Y>::value == mpl::size<type>::value);     //make sure y's length matches row count
00271 
00272       MultiTypeBlockMatrix_VectMul<0,mpl::size<type>::value,0,mpl::size<T1>::value,Y,type,X>::usmv(alpha,y, *this, x);     //iterate over all matrix elements
00273         
00274     }
00275 
00276 
00277 
00278   };
00279 
00280 
00281 
00287   template<typename T1, typename T2, typename T3, typename T4, typename T5, 
00288            typename T6, typename T7, typename T8, typename T9>
00289   std::ostream& operator<< (std::ostream& s, const MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>& m) {
00290     static const int i = mpl::size<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::value;            //row count
00291     static const int j = mpl::size< typename mpl::at_c<MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9>,0>::type >::value;       //col count of first row
00292     MultiTypeBlockMatrix_Print<0,i,0,j,MultiTypeBlockMatrix<T1,T2,T3,T4,T5,T6,T7,T8,T9> >::print(m);
00293     return s;
00294   }
00295 
00296 
00297 
00298 
00299 
// Declaration of the per-level block smoother algmeta_itsteps<level>, whose
// dbgs/bsorf/bsorb/dbjac are called on the diagonal blocks by
// MultiTypeBlockMatrix_Solver below (definition presumably in gsetc.hh).
template<int level>
struct algmeta_itsteps;
00303 
00304 
00305 
00306 
00307 
00308 
00315   template<int I, int crow, int ccol, int remain_col>                             //MultiTypeBlockMatrix_Solver_Col: iterating over one row
00316   class MultiTypeBlockMatrix_Solver_Col {                                                      //calculating b- A[i][j]*x[j]
00317   public:
00321     template <typename Trhs, typename TVector, typename TMatrix, typename K>
00322     static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {
00323       fusion::at_c<ccol>( fusion::at_c<crow>(A) ).mmv( fusion::at_c<ccol>(x), b );
00324       MultiTypeBlockMatrix_Solver_Col<I, crow, ccol+1, remain_col-1>::calc_rhs(A,x,v,b,w); //next column element
00325     }
00326 
00327   };
00328   template<int I, int crow, int ccol>                                             //MultiTypeBlockMatrix_Solver_Col recursion end
00329   class MultiTypeBlockMatrix_Solver_Col<I,crow,ccol,0> {
00330   public:
00331     template <typename Trhs, typename TVector, typename TMatrix, typename K>
00332     static void calc_rhs(const TMatrix& A, TVector& x, TVector& v, Trhs& b, const K& w) {}
00333   };
00334 
00335 
00336 
00343   template<int I, int crow, int remain_row>
00344   class MultiTypeBlockMatrix_Solver {
00345   public:
00346 
00350     template <typename TVector, typename TMatrix, typename K>
00351     static void dbgs(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00352       TVector xold(x);
00353       xold=x;                                                         //store old x values
00354       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbgs(A,x,x,b,w);
00355       x *= w;
00356       x.axpy(1-w,xold);                                                       //improve x
00357     }
00358     template <typename TVector, typename TMatrix, typename K>
00359     static void dbgs(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00360       typename mpl::at_c<TVector,crow>::type rhs;
00361       rhs = fusion::at_c<crow> (b);
00362 
00363       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00364       //solve on blocklevel I-1
00365       algmeta_itsteps<I-1>::dbgs(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(x),rhs,w);
00366       MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbgs(A,x,v,b,w); //next row
00367     }
00368 
00369 
00370 
00374     template <typename TVector, typename TMatrix, typename K>
00375     static void bsorf(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00376       TVector v;
00377       v=x;                                                            //use latest x values in right side calculation
00378       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorf(A,x,v,b,w);
00379                 
00380     }
00381     template <typename TVector, typename TMatrix, typename K>               //recursion over all matrix rows (A)
00382     static void bsorf(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00383       typename mpl::at_c<TVector,crow>::type rhs;
00384       rhs = fusion::at_c<crow> (b);
00385 
00386       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00387       //solve on blocklevel I-1
00388       algmeta_itsteps<I-1>::bsorf(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
00389       fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v));
00390       MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::bsorf(A,x,v,b,w);        //next row
00391     }
00392 
00396     template <typename TVector, typename TMatrix, typename K>
00397     static void bsorb(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00398       TVector v;
00399       v=x;                                                            //use latest x values in right side calculation
00400       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::bsorb(A,x,v,b,w);
00401                 
00402     }
00403     template <typename TVector, typename TMatrix, typename K>               //recursion over all matrix rows (A)
00404     static void bsorb(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00405       typename mpl::at_c<TVector,crow>::type rhs;
00406       rhs = fusion::at_c<crow> (b);
00407 
00408       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00409       //solve on blocklevel I-1
00410       algmeta_itsteps<I-1>::bsorb(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
00411       fusion::at_c<crow>(x).axpy(w,fusion::at_c<crow>(v));
00412       MultiTypeBlockMatrix_Solver<I,crow-1,remain_row-1>::bsorb(A,x,v,b,w);        //next row
00413     }
00414 
00415 
00419     template <typename TVector, typename TMatrix, typename K>
00420     static void dbjac(const TMatrix& A, TVector& x, const TVector& b, const K& w) {
00421       TVector v(x);
00422       v=0;                                                            //calc new x in v
00423       MultiTypeBlockMatrix_Solver<I,crow,remain_row>::dbjac(A,x,v,b,w);
00424       x.axpy(w,v);                                                    //improve x
00425     }
00426     template <typename TVector, typename TMatrix, typename K>
00427     static void dbjac(const TMatrix& A, TVector& x, TVector& v, const TVector& b, const K& w) {
00428       typename mpl::at_c<TVector,crow>::type rhs;
00429       rhs = fusion::at_c<crow> (b);
00430 
00431       MultiTypeBlockMatrix_Solver_Col<I,crow,0, mpl::size<typename mpl::at_c<TMatrix,crow>::type>::value>::calc_rhs(A,x,v,rhs,w);  // calculate right side of equation
00432       //solve on blocklevel I-1
00433       algmeta_itsteps<I-1>::dbjac(fusion::at_c<crow>( fusion::at_c<crow>(A)), fusion::at_c<crow>(v),rhs,w);
00434       MultiTypeBlockMatrix_Solver<I,crow+1,remain_row-1>::dbjac(A,x,v,b,w);        //next row
00435     }
00436 
00437 
00438 
00439 
00440   };
00441   template<int I, int crow>                                                       //recursion end for remain_row = 0
00442   class MultiTypeBlockMatrix_Solver<I,crow,0> {
00443   public:
00444     template <typename TVector, typename TMatrix, typename K>
00445     static void dbgs(const TMatrix& A, TVector& x, TVector& v, 
00446                      const TVector& b, const K& w) {}
00447 
00448     template <typename TVector, typename TMatrix, typename K>
00449     static void bsorf(const TMatrix& A, TVector& x, TVector& v, 
00450                       const TVector& b, const K& w) {}
00451 
00452     template <typename TVector, typename TMatrix, typename K>
00453     static void bsorb(const TMatrix& A, TVector& x, TVector& v, 
00454                       const TVector& b, const K& w) {}
00455 
00456     template <typename TVector, typename TMatrix, typename K>
00457     static void dbjac(const TMatrix& A, TVector& x, TVector& v, 
00458                       const TVector& b, const K& w) {}
00459   };
00460 
00461 } // end namespace
00462 
00463 #endif // HAVE_BOOST_FUSION
00464 #endif // HAVE_BOOST
00465 #endif
00466