dune-istl  2.2.0
matrixmatrix.hh
Go to the documentation of this file.
00001 #ifndef DUNE_MATRIXMATRIX_HH
00002 #define DUNE_MATRIXMATRIX_HH
00003 #include <dune/istl/bcrsmatrix.hh>
00004 #include <dune/common/fmatrix.hh>
00005 #include <dune/common/tuples.hh>
00006 #include <dune/common/timer.hh>
00007 namespace Dune
00008 {  
00009 
00020   namespace
00021     {
00022       
/**
 * @brief Unspecialized traversal helper for sparse matrix-matrix products.
 *
 * This generic version is intentionally empty. The specializations for the
 * tags 0, 1 and 2 in this file each provide a static traverse() method.
 *
 * @tparam b Integer tag selecting the specialization.
 */
template<int b>
struct NonzeroPatternTraverser
{
};
00034 
00035       
00036       template<>
00037       struct NonzeroPatternTraverser<0>
00038       {
00039          template<class T,class A1, class A2, class F, int n, int m, int k>
00040          static void traverse(const Dune::BCRSMatrix<Dune::FieldMatrix<T,n,k>,A1>& A,
00041                        const Dune::BCRSMatrix<Dune::FieldMatrix<T,k,m>,A2>& B,
00042                        F& func)
00043         {
00044           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,m>,A2>::size_type size_type;
00045           
00046           if(A.M()!=B.N())
00047             DUNE_THROW(ISTLError, "The sizes of the matrices do not match: "<<A.M()<<"!="<<B.N());
00048           
00049           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,n,k>,A1>::ConstRowIterator Row;
00050           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,n,k>,A1>::ConstColIterator Col;
00051           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,m>,A2>::ConstRowIterator BRow;
00052           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,m>,A2>::ConstColIterator BCol;
00053           for(Row row= A.begin(); row != A.end(); ++row){
00054             // Loop over all column entries
00055             for(Col col = row->begin(); col != row->end(); ++col){
00056               // entry at i,k
00057               // search for all nonzeros in row k
00058               for(BCol bcol = B[col.index()].begin(); bcol != B[col.index()].end(); ++bcol){
00059                 func(*col, *bcol, row.index(), bcol.index());
00060               }
00061             }
00062           }
00063         }
00064         
00065       };
00066       
00067       template<>
00068       struct NonzeroPatternTraverser<1>
00069       {
00070         template<class T, class A1, class A2, class F, int n, int m, int k>
00071         static void traverse(const Dune::BCRSMatrix<Dune::FieldMatrix<T,k,n>,A1>& A,
00072                              const Dune::BCRSMatrix<Dune::FieldMatrix<T,k,m>,A2>& B,
00073                              F& func)
00074         {
00075           
00076           if(A.N()!=B.N())
00077             DUNE_THROW(ISTLError, "The sizes of the matrices do not match: "<<A.N()<<"!="<<B.N());
00078           
00079           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,n>,A1>::ConstRowIterator Row;
00080           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,n>,A1>::ConstColIterator Col;
00081           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,m>,A2>::ConstColIterator BCol;
00082           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,n>,A1>::size_type size_t1;
00083           typedef typename Dune::BCRSMatrix<Dune::FieldMatrix<T,k,m>,A2>::size_type size_t2;
00084           
00085           for(Row row=A.begin(); row!=A.end(); ++row){
00086             for(Col col=row->begin(); col!=row->end(); ++col){
00087               for(BCol bcol  = B[row.index()].begin(); bcol !=  B[row.index()].end(); ++bcol){
00088                 func(*col, *bcol, col.index(), bcol.index());
00089               }
00090             }
00091           }
00092         }
00093       };
00094       
00095         template<>
00096         struct NonzeroPatternTraverser<2>
00097         {
00098             template<class T, class A1, class A2, class F, int n, int m, int k>
00099             static void traverse(const BCRSMatrix<FieldMatrix<T,n,m>,A1>& mat, 
00100                 const BCRSMatrix<FieldMatrix<T,k,m>,A2>& matt,
00101                 F& func)
00102             {
00103                 if(mat.M()!=matt.M())
00104                     DUNE_THROW(ISTLError, "The sizes of the matrices do not match: "<<mat.N()<<"!="<<matt.N());
00105 
00106                 typedef typename BCRSMatrix<FieldMatrix<T,n,m>,A1>::ConstRowIterator row_iterator;
00107                 typedef typename BCRSMatrix<FieldMatrix<T,n,m>,A1>::ConstColIterator col_iterator; 
00108                 typedef typename BCRSMatrix<FieldMatrix<T,k,m>,A2>::ConstRowIterator row_iterator_t;
00109                 typedef typename BCRSMatrix<FieldMatrix<T,k,m>,A2>::ConstColIterator col_iterator_t;
00110         
00111                 for(row_iterator mrow=mat.begin(); mrow != mat.end(); ++mrow){
00112                     //iterate over the column entries
00113                     // mt is a transposed matrix crs therefore it is treated as a ccs matrix
00114                     // and the row_iterator iterates over the columns of the transposed matrix.
00115                     // search the row of the transposed matrix for an entry with the same index
00116                     // as the mcol iterator
00117           
00118                     for(row_iterator_t mtcol=matt.begin(); mtcol != matt.end(); ++mtcol, func.nextCol()){
00119                         //Search for col entries in mat that have a corrsponding row index in matt
00120                         // (i.e. corresponding col index in the as this is the transposed matrix
00121                         col_iterator_t mtrow=mtcol->begin();
00122             
00123                         for(col_iterator mcol=mrow->begin(); mcol != mrow->end(); ++mcol){
00124                             // search 
00125                             // TODO: This should probably be substituted by a binary search
00126                             for( ;mtrow != mtcol->end(); ++mtrow)
00127                                 if(mtrow.index()>=mcol.index())
00128                                     break;
00129                             if(mtrow != mtcol->end() && mtrow.index()==mcol.index()){
00130                                 func(*mcol, *mtrow, mtcol.index());
00131                                 // In some cases we only search for one pair, then we break here
00132                                 // and continue with the next column.
00133                                 if(F::do_break)
00134                                     break;
00135                             }
00136                         }
00137                     }
00138                     func.nextRow();
00139                 }
00140             } 
00141         };
00142 
00143       
00144 
00145         template<class T, class A, int n, int m>
00146         class SparsityPatternInitializer
00147         {
00148         public:
00149             enum{do_break=true};
00150             typedef typename BCRSMatrix<FieldMatrix<T,n,m>,A>::CreateIterator CreateIterator;
00151             typedef typename BCRSMatrix<FieldMatrix<T,n,m>,A>::size_type size_type;
00152           
00153             SparsityPatternInitializer(CreateIterator iter)
00154                 : rowiter(iter)
00155             {}
00156       
00157             template<class T1, class T2>
00158             void operator()(const T1& t1, const T2& t2, size_type j)
00159             {
00160                 rowiter.insert(j);
00161             }
00162       
00163             void nextRow()
00164             {
00165                 ++rowiter;
00166             }
00167             void nextCol()
00168             {}
00169       
00170         private:
00171             CreateIterator rowiter;
00172         };
00173 
00174 
00175       template<int transpose, class T, class TA, int n, int m>
00176       class MatrixInitializer
00177       {
00178       public:
00179         enum{do_break=true};
00180         typedef typename Dune::BCRSMatrix<FieldMatrix<T,n,m>,TA> Matrix;
00181         typedef typename Matrix::CreateIterator CreateIterator;
00182         typedef typename Matrix::size_type size_type;
00183 
00184             MatrixInitializer(Matrix& A_, size_type rows)
00185               : count(0), A(A_)
00186             {}
00187             template<class T1, class T2>
00188             void operator()(const T1& t1, const T2& t2, int j)
00189             {
00190                 ++count;
00191             }
00192       
00193             void nextCol()
00194             {}
00195       
00196             void nextRow()
00197             {}
00198       
00199             std::size_t nonzeros()
00200             {
00201                 return count;
00202             }
00203       
00204           template<class A1, class A2, int n2, int m2, int n3, int m3>
00205           void initPattern(const BCRSMatrix<FieldMatrix<T,n2,m2>,A1>& mat1, 
00206                            const BCRSMatrix<FieldMatrix<T,n3,m3>,A2>& mat2)
00207           {
00208             SparsityPatternInitializer<T, TA, n, m> sparsity(A.createbegin());
00209             NonzeroPatternTraverser<transpose>::traverse(mat1,mat2,sparsity);
00210           }
00211           
00212       private:
00213         std::size_t count;
00214         Matrix& A;
00215       };
00216 
00217       template<class T, class TA, int n, int m>
00218       class MatrixInitializer<1,T,TA,n,m>
00219       {
00220       public:
00221         enum{do_break=false};
00222         typedef Dune::BCRSMatrix<Dune::FieldMatrix<T,n,m>,TA> Matrix;
00223         typedef typename Matrix::CreateIterator CreateIterator;
00224         typedef typename Matrix::size_type size_type;
00225           
00226         MatrixInitializer(Matrix& A_, size_type rows)
00227           :  A(A_), entries(rows)
00228         {}
00229         
00230         template<class T1, class T2>
00231         void operator()(const T1& t1, const T2& t2, size_type i, size_type j)
00232         {
00233           entries[i].insert(j);
00234         }
00235 
00236         void nextCol()
00237         {}
00238         
00239         size_type nonzeros()
00240         {
00241           size_type nnz=0;
00242           typedef typename std::vector<std::set<size_t> >::const_iterator Iter;
00243           for(Iter iter = entries.begin(); iter != entries.end(); ++iter)
00244             nnz+=(*iter).size();
00245           return nnz;
00246         }
00247         template<class A1, class A2, int n2, int m2, int n3, int m3>
00248         void initPattern(const BCRSMatrix<FieldMatrix<T,n2,m2>,A1>& mat1, 
00249                          const BCRSMatrix<FieldMatrix<T,n3,m3>,A2>& mat2)
00250         {
00251           typedef typename std::vector<std::set<size_t> >::const_iterator Iter;
00252           CreateIterator citer = A.createbegin();
00253           for(Iter iter = entries.begin(); iter != entries.end(); ++iter, ++citer){
00254             typedef std::set<size_t>::const_iterator SetIter;
00255             for(SetIter index=iter->begin(); index != iter->end(); ++index)
00256               citer.insert(*index);
00257           }
00258         }
00259         
00260       private:
00261         Matrix& A;
00262         std::vector<std::set<size_t> > entries;
00263       };
00264 
00265       template<class T, class TA, int n, int m>
00266       struct MatrixInitializer<0,T,TA,n,m>
00267         : public MatrixInitializer<1,T,TA,n,m>
00268       {
00269         MatrixInitializer(Dune::BCRSMatrix<Dune::FieldMatrix<T,n,m>,TA>& A_, 
00270                           typename Dune::BCRSMatrix<Dune::FieldMatrix<T,n,m>,TA>::size_type rows)
00271           :MatrixInitializer<1,T,TA,n,m>(A_,rows)
00272         {}
00273       };
00274       
00275     
00276         template<class T, class T1, class T2, int n, int m, int k>
00277         void addMatMultTransposeMat(FieldMatrix<T,n,k>& res, const FieldMatrix<T1,n,m>& mat, 
00278             const FieldMatrix<T2,k,m>& matt)
00279         {
00280             typedef typename FieldMatrix<T,n,k>::size_type size_type;
00281           
00282             for(size_type row=0; row<n; ++row)
00283                 for(size_type col=0; col<k;++col){
00284                     for(size_type i=0; i < m; ++i)
00285                         res[row][col]+=mat[row][i]*matt[col][i];
00286                 }
00287         }
00288 
00289         template<class T, class T1, class T2, int n, int m, int k>
00290         void addTransposeMatMultMat(FieldMatrix<T,n,k>& res, const FieldMatrix<T1,m,n>& mat, 
00291             const FieldMatrix<T2,m,k>& matt)
00292         {
00293             typedef typename FieldMatrix<T,n,k>::size_type size_type;
00294             for(size_type i=0; i<m; ++i)
00295                 for(size_type row=0; row<n;++row){
00296                     for(size_type col=0; col < k; ++col)
00297                         res[row][col]+=mat[i][row]*matt[i][col];
00298                 }
00299         }
00300 
00301         template<class T, class T1, class T2, int n, int m, int k>
00302         void addMatMultMat(FieldMatrix<T,n,m>& res, const FieldMatrix<T1,n,k>& mat, 
00303             const FieldMatrix<T2,k,m>& matt)
00304         {
00305             typedef typename FieldMatrix<T,n,k>::size_type size_type;
00306             for(size_type row=0; row<n; ++row)
00307                 for(size_type col=0; col<m;++col){
00308                     for(size_type i=0; i < k; ++i)
00309                         res[row][col]+=mat[row][i]*matt[i][col];
00310                 }
00311         }
00312 
00313     
00314         template<class T, class A, int n, int m>
00315         class EntryAccumulatorFather
00316         {
00317         public:
00318             enum{do_break=false};
00319             typedef BCRSMatrix<FieldMatrix<T,n,m>,A> Matrix;
00320             typedef typename Matrix::RowIterator Row;
00321             typedef typename Matrix::ColIterator Col;
00322       
00323             EntryAccumulatorFather(Matrix& mat_)
00324                 :mat(mat_), row(mat.begin())
00325             {
00326                 mat=0;
00327                 col=row->begin();
00328             }
00329             void nextRow()
00330             {
00331                 ++row;
00332                 if(row!=mat.end())
00333                     col=row->begin();
00334             }
00335       
00336             void nextCol()
00337             {
00338                 ++this->col;
00339             }
00340         protected:
00341           Matrix& mat;
00342         private:
00343             Row row;
00344         protected:
00345             Col col;
00346         };
00347     
00348         template<class T, class A, int n, int m, int transpose>
00349         class EntryAccumulator 
00350             : public EntryAccumulatorFather<T,A,n,m>
00351         {
00352         public:      
00353             typedef BCRSMatrix<FieldMatrix<T,n,m>,A> Matrix;
00354             typedef typename Matrix::size_type size_type;
00355           
00356             EntryAccumulator(Matrix& mat_)
00357                 : EntryAccumulatorFather<T,A,n,m>(mat_)
00358             {}
00359       
00360             template<class T1, class T2>
00361             void operator()(const T1& t1, const T2& t2, size_type i)
00362             {
00363               assert(this->col.index()==i);
00364               addMatMultMat(*(this->col),t1,t2);
00365             }   
00366         };
00367 
00368       template<class T, class A, int n, int m>
00369         class EntryAccumulator<T,A,n,m,0>
00370             : public EntryAccumulatorFather<T,A,n,m>
00371         {
00372         public:      
00373             typedef BCRSMatrix<FieldMatrix<T,n,m>,A> Matrix;
00374             typedef typename Matrix::size_type size_type;
00375       
00376             EntryAccumulator(Matrix& mat_)
00377                 : EntryAccumulatorFather<T,A,n,m>(mat_)
00378             {}
00379       
00380             template<class T1, class T2>
00381             void operator()(const T1& t1, const T2& t2, size_type i, size_type j)
00382             {
00383                 addMatMultMat(this->mat[i][j], t1, t2);
00384             }
00385         };
00386 
00387         template<class T, class A, int n, int m>
00388         class EntryAccumulator<T,A,n,m,1>
00389             : public EntryAccumulatorFather<T,A,n,m>
00390         {
00391         public:      
00392             typedef BCRSMatrix<FieldMatrix<T,n,m>,A> Matrix;
00393             typedef typename Matrix::size_type size_type;
00394       
00395             EntryAccumulator(Matrix& mat_)
00396                 : EntryAccumulatorFather<T,A,n,m>(mat_)
00397             {}
00398       
00399             template<class T1, class T2>
00400             void operator()(const T1& t1, const T2& t2, size_type i, size_type j)
00401             {
00402                 addTransposeMatMultMat(this->mat[i][j], t1, t2);
00403             }
00404         };
00405 
00406         template<class T, class A, int n, int m>
00407         class EntryAccumulator<T,A,n,m,2>
00408             : public EntryAccumulatorFather<T,A,n,m>
00409         {
00410         public:      
00411             typedef BCRSMatrix<FieldMatrix<T,n,m>,A> Matrix;
00412             typedef typename Matrix::size_type size_type;
00413       
00414             EntryAccumulator(Matrix& mat_)
00415                 : EntryAccumulatorFather<T,A,n,m>(mat_)
00416             {}
00417       
00418             template<class T1, class T2>
00419             void operator()(const T1& t1, const T2& t2, size_type i)
00420             {
00421                 assert(this->col.index()==i);
00422                 addMatMultTransposeMat(*this->col,t1,t2);
00423             }
00424         };
00425     
00426     
/**
 * @brief Unspecialized helper for selecting the size of the result matrix.
 *
 * Intentionally empty; the specializations for the tags 0, 1 and 2 in this
 * file each provide a static size() method.
 *
 * @tparam transpose Integer tag selecting the specialization.
 */
template<int transpose>
struct SizeSelector
{};
00431       
00432         template<>
00433         struct SizeSelector<0>
00434         {
00435             template<class M1, class M2>
00436             static tuple<typename M1::size_type, typename M2::size_type> 
00437             size(const M1& m1, const M2& m2)
00438             {
00439                 return make_tuple(m1.N(), m2.M());
00440             }
00441         };
00442       
00443         template<>
00444         struct SizeSelector<1>
00445         {
00446             template<class M1, class M2>
00447             static tuple<typename M1::size_type, typename M2::size_type> 
00448             size(const M1& m1, const M2& m2)
00449             {
00450                 return make_tuple(m1.M(), m2.M());
00451             }
00452         };
00453       
00454       
00455         template<>
00456         struct SizeSelector<2>
00457         {
00458             template<class M1, class M2>
00459             static tuple<typename M1::size_type, typename M2::size_type> 
00460             size(const M1& m1, const M2& m2)
00461             {
00462                 return make_tuple(m1.N(), m2.N());
00463             }
00464         };
00465       
00466         template<int transpose, class T, class A, class A1, class A2, int n1, int m1, int n2, int m2, int n3, int m3>
00467         void matMultMat(BCRSMatrix<FieldMatrix<T,n1,m1>,A>& res, const BCRSMatrix<FieldMatrix<T,n2,m2>,A1>& mat1, 
00468             const BCRSMatrix<FieldMatrix<T,n3,m3>,A2>& mat2)
00469         {    
00470             // First step is to count the number of nonzeros
00471             typename BCRSMatrix<FieldMatrix<T,n1,m1>,A>::size_type rows, cols;
00472             tie(rows,cols)=SizeSelector<transpose>::size(mat1, mat2);
00473             MatrixInitializer<transpose,T,A,n1,m1> patternInit(res, rows);
00474             Timer timer;
00475             NonzeroPatternTraverser<transpose>::traverse(mat1,mat2,patternInit);
00476             res.setSize(rows, cols, patternInit.nonzeros());
00477             res.setBuildMode(BCRSMatrix<FieldMatrix<T,n1,m1>,A>::row_wise);
00478             
00479             std::cout<<"Counting nonzeros took "<<timer.elapsed()<<std::endl;
00480             timer.reset();
00481             
00482             // Second step is to allocate the storage for the result and initialize the nonzero pattern
00483             patternInit.initPattern(mat1, mat2);
00484                     
00485             std::cout<<"Setting up sparsity pattern took "<<timer.elapsed()<<std::endl;
00486             timer.reset();
00487             // As a last step calculate the entries
00488             EntryAccumulator<T,A,n1,m1, transpose> entriesAccu(res);
00489             NonzeroPatternTraverser<transpose>::traverse(mat1,mat2,entriesAccu);
00490             std::cout<<"Calculating entries took "<<timer.elapsed()<<std::endl;
00491         }
00492     
00493     }
00494 
/**
 * @brief Helper for determining the type of a matrix-matrix product.
 *
 * This unspecialized version is empty; the specializations in this file
 * provide the member typedef type.
 *
 * @tparam M1 type of the left operand
 * @tparam M2 type of the right operand
 */
template<typename M1, typename M2>
struct MatMultMatResult
{};
00506 
00507   template<typename T, int n, int k, int m>
00508   struct MatMultMatResult<FieldMatrix<T,n,k>,FieldMatrix<T,k,m> >
00509   {
00510     typedef FieldMatrix<T,n,m> type;
00511   };
00512   
00513   template<typename T, typename A, typename A1, int n, int k, int m>
00514   struct MatMultMatResult<BCRSMatrix<FieldMatrix<T,n,k>,A >,BCRSMatrix<FieldMatrix<T,k,m>,A1 > >
00515   {
00516     typedef BCRSMatrix<typename MatMultMatResult<FieldMatrix<T,n,k>,
00517                                                  FieldMatrix<T,k,m> >::type,A> type;
00518   };
00519 
00528     template<class T, class A, class A1, class A2, int n, int m, int k>
00529     void matMultTransposeMat(BCRSMatrix<FieldMatrix<T,n,k>,A>& res, const BCRSMatrix<FieldMatrix<T,n,m>,A1>& mat, 
00530         const BCRSMatrix<FieldMatrix<T,k,m>,A2>& matt, bool tryHard=false)
00531     {
00532         matMultMat<2>(res,mat, matt);
00533     }
00534 
00543     template<class T, class A, class A1, class A2, int n, int m, int k>
00544     void matMultMat(BCRSMatrix<FieldMatrix<T,n,m>,A>& res, const BCRSMatrix<FieldMatrix<T,n,k>,A1>& mat, 
00545         const BCRSMatrix<FieldMatrix<T,k,m>,A2>& matt, bool tryHard=false)
00546     {
00547         matMultMat<0>(res,mat, matt);
00548     }
00549 
00558     template<class T, class A, class A1, class A2, int n, int m, int k>
00559     void transposeMatMultMat(BCRSMatrix<FieldMatrix<T,n,m>,A>& res, const BCRSMatrix<FieldMatrix<T,k,n>,A1>& mat, 
00560         const BCRSMatrix<FieldMatrix<T,k,m>,A2>& matt, bool tryHard=false)
00561     {
00562         matMultMat<1>(res,mat, matt);
00563     }
00564 
00565 }
00566 #endif
00567