Skip to content

Commit 4850da9

Browse files
Rewrite vtp plot to TubeFilter(File)
1 parent 927a29d commit 4850da9

File tree

2 files changed

+84
-51
lines changed

2 files changed

+84
-51
lines changed

graphnics/plot.py

Lines changed: 79 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,102 @@
11
import networkx as nx
22
from .fenics_graph import *
33
from vtk import *
4+
import os
45

56
'''
6-
Write .vtp files for functions defined on the graph
7-
'''
7+
Overloaded File class for writing .vtp files for functions defined on the graph
8+
This allows for using the TubeFilter in paraview
89
10+
TODO: Allow for writing time-dependent functions
11+
'''
912

10-
def write_vtp(G, functions=[], fname='plot.vtp'):
11-
'''
12-
Write file for plotting graph mesh and associated scalar fields in paraview
13-
14-
Args:
15-
G (FenicsGraph): graph
16-
functions (list of tuples): fenics functions to plot and their name
17-
fname (str): file name for plot
13+
class TubeFile(File):
14+
def __init__(self, G, fname, **kwargs):
15+
"""
16+
.vtp file with network function and radius, made for TubeFilter in paraview
17+
18+
Args:
19+
G (FenicsGraph): graph with mesh
20+
fname (str): location and name for file
21+
22+
Usage:
23+
>> G = make_Y_bifurcation(dim=3)
24+
>> V = FunctionSpace(G.global_mesh, 'CG', 1)
25+
>> radius_i = interpolate(Expression('x[1]+0.1*x[0]', degree=2), V)
26+
>> f_i = interpolate(Expression('x[0]', degree=2))
27+
>> f_i.rename('f', '0.0')
28+
>> TubeFile(G, 'test.vtp') << (val, radius_i)
29+
"""
30+
31+
f_name, f_ext = os.path.splitext(fname)
32+
assert f_ext == '.vtp', 'TubeFile must have .vtp file ending'
33+
34+
self.fname = fname
35+
self.G = G
1836

19-
Example:
20-
>> G = make_double_Y_bifurcation(dim=3)
21-
>> radius = Expression('x[1]+0.1*x[0]', degree=2)
22-
>> val = Expression('x[0]', degree=2)
23-
>> write_vtp(G, functions=[(radius, 'radius'), (val, 'val')])
2437

25-
The function values are assigned at each vertex.
26-
'''
27-
28-
# Store points in vtkPoints
29-
coords = G.global_mesh.coordinates()
30-
points = vtkPoints()
31-
for c in coords:
32-
points.InsertNextPoint(list(c))
38+
def __lshift__(self, func_and_radius):
39+
"""
40+
Write function to .vtp file
41+
42+
Args:
43+
func_and_radius (tuple):
44+
- func: function to plot
45+
- radius (function): network radius
46+
"""
47+
48+
func, radius = func_and_radius
49+
50+
assert self.G.geom_dim==3, f'Coordinates are {self.G.geom_dim}d, they need to be 3d'
51+
52+
# Store points in vtkPoints
53+
coords = self.G.global_mesh.coordinates()
54+
points = vtkPoints()
55+
for c in coords:
56+
points.InsertNextPoint(list(c))
3357

34-
# Store edges in cell array
35-
lines = vtkCellArray()
36-
edge_to_vertices = G.global_mesh.cells()
58+
# Store edges in cell array
59+
lines = vtkCellArray()
60+
edge_to_vertices = self.G.global_mesh.cells()
3761

38-
for vs in edge_to_vertices:
39-
line = vtkLine()
40-
line.GetPointIds().SetId(0, vs[0])
41-
line.GetPointIds().SetId(1, vs[1])
42-
lines.InsertNextCell(line)
62+
for vs in edge_to_vertices:
63+
line = vtkLine()
64+
line.GetPointIds().SetId(0, vs[0])
65+
line.GetPointIds().SetId(1, vs[1])
66+
lines.InsertNextCell(line)
4367

44-
# Create a polydata to store 1d mesh in
45-
linesPolyData = vtkPolyData()
46-
linesPolyData.SetPoints(points)
47-
linesPolyData.SetLines(lines)
68+
# Create a polydata to store 1d mesh in
69+
linesPolyData = vtkPolyData()
70+
linesPolyData.SetPoints(points)
71+
linesPolyData.SetLines(lines)
4872

4973

50-
# Add data from associated functions
51-
for func, name in functions:
74+
# Write data from associated function
5275
data = vtkDoubleArray()
53-
data.SetName(name)
76+
data.SetName(func.name())
5477
data.SetNumberOfComponents(1)
55-
56-
# Store value of function at each coordinates
78+
79+
# store value of function at each coordinates
5780
for c in coords:
5881
data.InsertNextTuple([func(c)])
5982

6083
linesPolyData.GetPointData().AddArray(data)
6184

85+
86+
# Write radius data
87+
data = vtkDoubleArray()
88+
data.SetName('radius')
89+
data.SetNumberOfComponents(1)
90+
91+
# store value of function at each coordinates
92+
for c in coords:
93+
data.InsertNextTuple([func(c)])
6294

63-
# Write to file
64-
writer = vtkXMLPolyDataWriter()
65-
writer.SetFileName(fname)
66-
writer.SetInputData(linesPolyData)
67-
writer.Update()
68-
writer.Write()
95+
linesPolyData.GetPointData().AddArray(data)
6996

97+
# Write to file
98+
writer = vtkXMLPolyDataWriter()
99+
writer.SetFileName(self.fname)
100+
writer.SetInputData(linesPolyData)
101+
writer.Update()
102+
writer.Write()

tests/test_vtp_plot.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
def test_vtp_plot():
1212

1313
G = make_double_Y_bifurcation(dim=3)
14-
G.make_mesh(3)
14+
V = FunctionSpace(G.global_mesh, 'CG', 1)
15+
v = Function(V)
16+
v.rename('u', '0.0')
17+
v.name()
1518

16-
f = Expression('x[1]+0.1*x[0]', degree=2)
17-
radius = Expression('x[1]+x[0]', degree=2)
18-
19-
write_vtp(G, functions=[(f, 'f'), (radius, 'radius')], fname='test.vtp')
19+
TubeFile(G, 'test.vtp') << (v, radius)
2020

2121
# Remove file again
2222
import os

0 commit comments

Comments
 (0)