forked from NVIDIA/physicsnemo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvtk_tools.py
More file actions
149 lines (121 loc) · 4.14 KB
/
vtk_tools.py
File metadata and controls
149 lines (121 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# ignore_header_test
# Copyright 2023 Stanford University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import vtk
from vtk.util.numpy_support import vtk_to_numpy as v2n
import numpy as np
def read_geo(fname):
    """
    Open a VTK XML geometry file and return the reader that parsed it.

    Arguments:
        fname: Path to a ``.vtp`` (vtkXMLPolyDataReader) or ``.vtu``
            (vtkXMLUnstructuredGridReader) file. The extension comparison
            is case-sensitive.
    Returns:
        The vtk reader, on which ``Update()`` has already been called.
    Raises:
        ValueError: if the extension is neither ``.vtp`` nor ``.vtu``.
    """
    extension = os.path.splitext(fname)[1]
    # reject unsupported files before constructing any VTK object
    if extension not in (".vtp", ".vtu"):
        raise ValueError(f"File extension {extension} unknown.")
    reader = (
        vtk.vtkXMLPolyDataReader()
        if extension == ".vtp"
        else vtk.vtkXMLUnstructuredGridReader()
    )
    reader.SetFileName(fname)
    reader.Update()
    return reader
def get_all_arrays(geo, components=None):
    """
    Extract point data, cell data and point coordinates from a geometry.

    Arguments:
        geo: Input geometry; its GetCellData(), GetPointData() and
            GetPoints() are read.
        components (int): Number of array components to keep, forwarded
            to collect_arrays and collect_points.
            Default: None -> keep all
    Returns:
        Point data dictionary (key: array name, value: numpy array)
        Cell data dictionary (key: array name, value: numpy array)
        Points (numpy array)
    """
    cell_arrays = collect_arrays(geo.GetCellData(), components)
    point_arrays = collect_arrays(geo.GetPointData(), components)
    coordinates = collect_points(geo.GetPoints(), components)
    # note the return order: point data first, then cell data
    return point_arrays, cell_arrays, coordinates
def get_edges(geo):
    """
    Get the two end-node indices of every cell in a geometry.

    Only the first two point ids of each cell are read, so this is meant
    for geometries whose cells are line segments; any further point ids
    of a cell are ignored.

    Arguments:
        geo: Input geometry (anything exposing GetNumberOfCells() and
            GetCell(i), e.g. a vtkPolyData)
    Returns:
        Integer numpy array of node indices (first node of each edge)
        Integer numpy array of node indices (second node of each edge)
    """
    first_nodes = []
    second_nodes = []
    for i in range(geo.GetNumberOfCells()):
        # fetch the cell's id list once instead of calling GetCell twice
        point_ids = geo.GetCell(i).GetPointIds()
        first_nodes.append(int(point_ids.GetId(0)))
        second_nodes.append(int(point_ids.GetId(1)))
    # explicit dtype so a geometry with no cells still yields integer
    # arrays (np.array([]) would default to float64)
    return np.array(first_nodes, dtype=int), np.array(second_nodes, dtype=int)
def collect_arrays(celldata, components=None):
    """
    Collect all arrays of a cell data or point data object as float32.

    Arguments:
        celldata: Field data object (e.g. geo.GetCellData() or
            geo.GetPointData()); every array in it is converted.
        components (int): Number of array components (columns) to keep
            for multi-component arrays. Single-component (1-D) arrays are
            returned whole.
            Default: None -> keep all
    Returns:
        A dictionary of arrays (key: array name, value: float32 numpy array)
    """
    res = {}
    for i in range(celldata.GetNumberOfArrays()):
        name = celldata.GetArrayName(i)
        values = v2n(celldata.GetArray(i))
        if components is not None and values.ndim > 1:
            # vtk_to_numpy returns (tuples, components) for multi-component
            # data: trim the component axis; a bare [:components] slice
            # would drop tuples (rows) instead
            values = values[:, :components]
        # every array (including integer ones) is cast to float32
        res[name] = values.astype(np.float32)
    return res
def collect_points(celldata, components=None):
    """
    Collect point coordinates as a float32 numpy array.

    Arguments:
        celldata: Points container (e.g. the result of geo.GetPoints());
            only its GetData() array is read.
        components (int): Number of coordinate components (columns) to
            keep for each point.
            Default: None -> keep all
    Returns:
        The array of points (float32 numpy array)
    """
    coords = v2n(celldata.GetData())
    if components is not None and coords.ndim > 1:
        # slice the component axis; a bare [:components] slice would keep
        # only the first `components` points instead
        coords = coords[:, :components]
    return coords.astype(np.float32)
def gather_array(arrays, arrayname, mintime=1e-12):
    """
    Gather the time series stored under names "<arrayname>_<time>".

    Only names that start with ``arrayname + "_"`` and whose remainder
    parses as a float are used. Other names that merely contain
    ``arrayname`` (e.g. ``"<arrayname>"`` alone or ``"<arrayname>_avg"``)
    are skipped instead of aborting the whole gather.

    Arguments:
        arrays: Dictionary of arrays (key: array name) to look into.
        arrayname (string): Base name of the field to gather.
        mintime (float): Only times strictly greater than this are kept.
            Default value = 1e-12.
    Returns:
        Dictionary of arrays (key: time, value: numpy array)
    """
    prefix = arrayname + "_"
    out = {}
    for name, values in arrays.items():
        if not name.startswith(prefix):
            continue
        try:
            time = float(name[len(prefix):])
        except ValueError:
            # shares the prefix but carries no numeric time suffix
            continue
        if time > mintime:
            out[time] = values
    return out