77class  FiniteElement :
88    """Abstract finite element.""" 
99
10-     def  __init__ (self , reference , basis , dofs , domain_dim , range_dim ):
10+     def  __init__ (self , reference , basis , dofs , domain_dim , range_dim ,
11+                  range_shape = None ):
1112        assert  len (basis ) ==  len (dofs )
1213        self .reference  =  reference 
1314        self .basis  =  basis 
1415        self .dofs  =  dofs 
1516        self .domain_dim  =  domain_dim 
1617        self .range_dim  =  range_dim 
18+         self .range_shape  =  range_shape 
1719        self .space_dim  =  len (dofs )
1820        self ._basis_functions  =  None 
21+         self ._reshaped_basis_functions  =  None 
1922
20-     def  get_basis_functions (self ):
23+     def  get_polynomial_basis (self , reshape = True ):
24+         """Get the polynomial basis for the element.""" 
25+         if  reshape  and  self .range_shape  is  not None :
26+             if  len (self .range_shape ) !=  2 :
27+                 raise  NotImplementedError 
28+             assert  self .range_shape [0 ] *  self .range_shape [1 ] ==  self .range_dim 
29+             return  [sympy .Matrix (
30+                 [b [i  *  self .range_shape [1 ]: (i  +  1 ) *  self .range_shape [1 ]]
31+                  for  i  in  range (self .range_shape [0 ])]) for  b  in  self .basis ]
32+ 
33+         return  self .basis 
34+ 
35+     def  get_basis_functions (self , reshape = True ):
2136        """Get the basis functions of the element.""" 
2237        if  self ._basis_functions  is  None :
2338            mat  =  []
@@ -44,6 +59,14 @@ def get_basis_functions(self):
4459                            b [j ] +=  c  *  d_j 
4560                    self ._basis_functions .append (b )
4661
62+         if  reshape  and  self .range_shape  is  not None :
63+             if  len (self .range_shape ) !=  2 :
64+                 raise  NotImplementedError 
65+             assert  self .range_shape [0 ] *  self .range_shape [1 ] ==  self .range_dim 
66+             return  [sympy .Matrix (
67+                 [b [i  *  self .range_shape [1 ]: (i  +  1 ) *  self .range_shape [1 ]]
68+                  for  i  in  range (self .range_shape [0 ])]) for  b  in  self ._basis_functions ]
69+ 
4770        return  self ._basis_functions 
4871
4972    def  tabulate_basis (self , points , order = "xyzxyz" ):
@@ -52,7 +75,7 @@ def tabulate_basis(self, points, order="xyzxyz"):
5275            output  =  []
5376            for  p  in  points :
5477                row  =  []
55-                 for  b  in  self .get_basis_functions ():
78+                 for  b  in  self .get_basis_functions (False ):
5679                    row .append (subs (b , x , p ))
5780                output .append (row )
5881            return  output 
@@ -62,15 +85,15 @@ def tabulate_basis(self, points, order="xyzxyz"):
6285            for  p  in  points :
6386                row  =  []
6487                for  d  in  range (self .range_dim ):
65-                     for  b  in  self .get_basis_functions ():
88+                     for  b  in  self .get_basis_functions (False ):
6689                        row .append (subs (b [d ], x , p ))
6790                output .append (row )
6891            return  output 
6992        if  order  ==  "xyzxyz" :
7093            output  =  []
7194            for  p  in  points :
7295                row  =  []
73-                 for  b  in  self .get_basis_functions ():
96+                 for  b  in  self .get_basis_functions (False ):
7497                    for  i  in  subs (b , x , p ):
7598                        row .append (i )
7699                output .append (row )
@@ -79,7 +102,7 @@ def tabulate_basis(self, points, order="xyzxyz"):
79102            output  =  []
80103            for  p  in  points :
81104                row  =  []
82-                 for  b  in  self .get_basis_functions ():
105+                 for  b  in  self .get_basis_functions (False ):
83106                    row .append (subs (b , x , p ))
84107                output .append (row )
85108            return  output 
0 commit comments