Skip to content

Commit d1d65ee

Browse files
committed
update corpus test and print a summary
1 parent 1863569 commit d1d65ee

File tree

8 files changed

+148
-55
lines changed

8 files changed

+148
-55
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = ["setuptools >= 61.0", "wheel"]
55
[project]
66
name = "trimesh"
77
requires-python = ">=3.8"
8-
version = "4.9.0"
8+
version = "4.10.0"
99
authors = [{name = "Michael Dawson-Haggerty", email = "[email protected]"}]
1010
license = {file = "LICENSE.md"}
1111
description = "Import, export, process, analyze and view triangular meshes."
@@ -168,7 +168,7 @@ flake8-implicit-str-concat = {"allow-multiline" = false}
168168
# disallow things that have caused problems
169169
[tool.ruff.lint.flake8-tidy-imports.banned-api]
170170
"IPython.embed".msg = "you forgot to remove a debug embed ;)"
171-
# "matplotlib".msg = "you forgot to remove a debug plot"
171+
"matplotlib".msg = "you forgot to remove a debug plot"
172172
"numpy.empty".msg = "uninitialized arrays are haunted: try numpy.zeros"
173173
"numpy.empty_like".msg = "uninitialized arrays are haunted: try numpy.zeros"
174174

tests/corpus.py

Lines changed: 137 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
will download more than a gigabyte to your home directory!
77
"""
88

9+
import argparse
910
import json
1011
import sys
1112
import time
13+
from collections import defaultdict
1214
from dataclasses import asdict, dataclass
1315

1416
import numpy as np
@@ -28,6 +30,12 @@ class LoadReport:
2830
# i.e 'glb'
2931
file_type: str
3032

33+
# how long did this take
34+
duration: float
35+
36+
# how many bytes was this file?
37+
file_size: float
38+
3139
# i.e. 'Scene'
3240
type_load: Optional[str] = None
3341

@@ -53,34 +61,112 @@ class Report:
5361
# a pyinstrument.renderers.JSONRenderer output
5462
profile: str
5563

56-
def compare(self, other: "Report"):
64+
def summary(self) -> str:
5765
"""
58-
Compare this load report to another.
66+
Prints a nice markdown table of load results including both overall
67+
and per-format statistics.
5968
"""
60-
# what files were loaded by both versions
61-
self_type = {o.file_name: o.type_load for o in self.load}
62-
other_type = {n.file_name: n.type_load for n in other.load}
69+
# Group loads by file type and split success/failure
70+
by_type = defaultdict(list)
71+
for load in self.load:
72+
by_type[load.file_type].append(load)
73+
74+
# Count exceptions per type
75+
exc_counts = defaultdict(lambda: defaultdict(int))
76+
for load in self.load:
77+
if load.exception:
78+
exc_counts[load.file_type][load.exception] += 1
79+
80+
# Extract successful load metrics for overall stats
81+
successful = [load for load in self.load if load.exception is None]
82+
duration = np.array([load.duration for load in successful])
83+
size = np.array([load.file_size for load in successful])
84+
85+
lines = []
86+
87+
# Build exception table if there are any exceptions
88+
if exc_counts:
89+
lines.append("Exceptions\n=================\n")
90+
exc_rows = [
91+
(ftype.upper(), count, str(exc)[:70])
92+
for ftype in sorted(exc_counts.keys())
93+
for exc, count in exc_counts[ftype].items()
94+
]
95+
lines.append(markdown_table(("Format", "Count", "Exception"), exc_rows))
96+
lines.append("")
97+
98+
# Build main results table
99+
rows = []
100+
# Add overall row
101+
success = float(len(successful)) / len(self.load) if len(self.load) > 0 else 0.0
102+
rows.append(
103+
(
104+
"Overall",
105+
f"{len(successful)}/{len(self.load)} ({success * 100.0:.2f}%)",
106+
f"{duration.mean():.3f} ± {duration.std():.3f}",
107+
f"{size.mean() / 1e6:.2f} ± {size.std() / 1e6:.2f}",
108+
)
109+
)
110+
111+
# Add per-format rows
112+
for ftype in sorted(by_type.keys()):
113+
loads = by_type[ftype]
114+
ok = [load for load in loads if load.exception is None]
115+
if ok:
116+
dur = np.array([load.duration for load in ok])
117+
sz = np.array([load.file_size for load in ok])
118+
success = float(len(ok)) / len(loads) if len(loads) > 0 else 0.0
119+
rows.append(
120+
(
121+
ftype.upper(),
122+
f"{len(ok)}/{len(loads)} ({success * 100.0:.2f}%)",
123+
f"{dur.mean():.3f} ± {dur.std():.3f}",
124+
f"{sz.mean() / 1e6:.2f} ± {sz.std() / 1e6:.2f}",
125+
)
126+
)
63127

64-
both = set(self_type.keys()).intersection(other_type.keys())
65-
matches = np.array([self_type[k] == other_type[k] for k in both])
66-
percent = matches.sum() / len(matches)
128+
lines.append("\nLoad Results\n=================\n")
129+
lines.append(markdown_table(("Format", "Loaded", "Time (s)", "Size (MB)"), rows))
67130

68-
print(f"Comparing `{self.version}` against `{other.version}`")
69-
print(f"Return types matched {percent * 100.0:0.3f}% of the time")
70-
print(f"Loaded {len(self.load)} vs Loaded {len(other.load)}")
131+
return "\n".join(lines)
71132

72133

73-
def from_dict(data: dict) -> Report:
134+
def markdown_table(headers: tuple[str, ...], rows: list[tuple]) -> str:
74135
"""
75-
Parse a `Report` which has been exported using `dataclasses.asdict`
76-
into a Report object.
136+
Print a markdown-formatted table.
137+
138+
Parameters
139+
----------
140+
headers
141+
Column headers as a tuple of strings.
142+
rows
143+
List of tuples, where each tuple represents a row of data.
144+
145+
Returns
146+
-------
147+
table
148+
A string containing the markdown-formatted table.
77149
"""
78-
return Report(
79-
load=[LoadReport(**r) for r in data.get("load", [])],
80-
version=data.get("version"),
81-
profile=data.get("profile"),
150+
# set column widths based on the longest item in each column
151+
col_widths = [
152+
max(len(h), max(len(str(row[i])) for row in rows)) for i, h in enumerate(headers)
153+
]
154+
# start with header row and separator row
155+
lines = [
156+
"| " + " | ".join(h.ljust(col_widths[i]) for i, h in enumerate(headers)) + " |",
157+
"| " + " | ".join("-" * w for w in col_widths) + " |",
158+
]
159+
160+
# extend with data rows
161+
lines.extend(
162+
"| "
163+
+ " | ".join(str(cell).ljust(col_widths[i]) for i, cell in enumerate(row))
164+
+ " |"
165+
for row in rows
82166
)
83167

168+
return "\n".join(lines)
169+
84170

85171
def on_repo(
86172
repo: str, commit: str, available: set, root: Optional[str] = None
@@ -123,33 +209,39 @@ def on_repo(
123209
should_raise = any(b in check for b in broke)
124210
raised = False
125211

212+
blob = resolver.get(name)
213+
126214
# start collecting data about the current load attempt
127-
current = LoadReport(file_name=name, file_type=trimesh.util.split_extension(name))
215+
current = LoadReport(
216+
file_name=name,
217+
duration=0.0,
218+
file_size=len(blob),
219+
file_type=trimesh.util.split_extension(name),
220+
)
128221

129222
print(f"Attempting: {name}")
130223

131224
try:
132-
m = trimesh.load(
225+
tic = time.time()
226+
m = trimesh.load_scene(
133227
file_obj=wrap_as_stream(resolver.get(name)),
134228
file_type=name,
135229
resolver=resolver,
136230
)
137-
138-
# save the load types
139-
current.type_load = m.__class__.__name__
140-
if isinstance(m, trimesh.Scene):
141-
# save geometry types
142-
current.type_geometry = tuple(
143-
[g.__class__.__name__ for g in m.geometry.values()]
144-
)
231+
toc = time.time()
232+
# save geometry types
233+
current.type_geometry = tuple(
234+
{g.__class__.__name__ for g in m.geometry.values()}
235+
)
236+
current.duration = toc - tic
145237
# save the <Trimesh ...> repr
146238
current.repr_load = str(m)
147239

148240
# if our source was a GLTF we should be able to roundtrip without
149241
# dropping
150242
if name.lower().split(".")[-1] in ("gltf", "glb") and len(m.geometry) > 0:
151243
# try round-tripping the file
152-
e = trimesh.load(
244+
e = trimesh.load_scene(
153245
file_obj=wrap_as_stream(m.export(file_type="glb")),
154246
file_type="glb",
155247
process=False,
@@ -236,7 +328,7 @@ def equal(a, b):
236328
return a == b
237329

238330

239-
def run(save: bool = False):
331+
def run(save: bool = False) -> Report:
240332
"""
241333
Try to load and export every mesh we can get our hands on.
242334
@@ -257,14 +349,11 @@ def run(save: bool = False):
257349
]
258350
)
259351

260-
# TODO : waiting on a release containing pycollada/pycollada/147
261-
available.difference_update({"dae"})
262-
263352
with Profiler() as P:
264353
# check against the small trimesh corpus
265354
loads = on_repo(
266355
repo="mikedh/trimesh",
267-
commit="2fcb2b2ea8085d253e692ecd4f71b8f450890d51",
356+
commit="76b6dd1a2f552673b3b38ffd44ce4342d4e95273",
268357
available=available,
269358
root="models",
270359
)
@@ -273,7 +362,7 @@ def run(save: bool = False):
273362
loads.extend(
274363
on_repo(
275364
repo="assimp/assimp",
276-
commit="1e44036c363f64d57e9f799beb9f06d4d3389a87",
365+
commit="ab28db52f022a7268ffff499cd85bbabf84c4271",
277366
available=available,
278367
root="test",
279368
)
@@ -312,16 +401,19 @@ def run(save: bool = False):
312401

313402

314403
if __name__ == "__main__":
315-
trimesh.util.attach_to_log()
316-
317-
if "-run" in " ".join(sys.argv):
318-
run()
404+
parser = argparse.ArgumentParser(
405+
description="Test trimesh loaders against large corpuses (downloads >1GB to ~/.trimesh-cache)"
406+
)
407+
parser.add_argument("-run", action="store_true", help="Run the corpus test")
408+
parser.add_argument("-save", action="store_true", help="Save JSON report")
319409

320-
if "-compare" in " ".join(sys.argv):
321-
with open("trimesh.4.5.3.1737061410.json") as f:
322-
old = from_dict(json.load(f))
410+
args = parser.parse_args()
411+
if len(sys.argv) == 1:
412+
parser.print_help()
413+
sys.exit(0)
323414

324-
with open("trimesh.4.6.0.1737060030.json") as f:
325-
new = from_dict(json.load(f))
415+
trimesh.util.attach_to_log()
326416

327-
new.compare(old)
417+
if args.run:
418+
report = run(save=args.save)
419+
print(report.summary())

tests/test_color.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def test_interpolate():
365365

366366
# now see if we match matplotlib if it's installed
367367
try:
368-
from matplotlib.pyplot import get_cmap
368+
from matplotlib.pyplot import get_cmap # noqa
369369
except ImportError:
370370
return
371371

trimesh/path/entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ def plot(self, vertices, show=False):
436436
if vertices.shape[1] != 2:
437437
raise ValueError("only for 2D points!")
438438

439-
import matplotlib.pyplot as plt
439+
import matplotlib.pyplot as plt # noqa
440440

441441
# get rotation angle in degrees
442442
angle = np.degrees(self.angle(vertices))

trimesh/path/path.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,7 +1437,7 @@ def plot_discrete(self, show=False, annotations=True):
14371437
"""
14381438
Plot the closed curves of the path.
14391439
"""
1440-
import matplotlib.pyplot as plt
1440+
import matplotlib.pyplot as plt # noqa
14411441

14421442
axis = plt.gca()
14431443
axis.set_aspect("equal", "datalim")
@@ -1469,7 +1469,7 @@ def plot_entities(self, show=False, annotations=True, color=None):
14691469
color : str
14701470
Override entity colors and make them all this color.
14711471
"""
1472-
import matplotlib.pyplot as plt
1472+
import matplotlib.pyplot as plt # noqa
14731473

14741474
# keep plot axis scaled the same
14751475
axis = plt.gca()

trimesh/path/polygons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def plot(polygon=None, show=True, axes=None, **kwargs):
294294
**kwargs
295295
Passed to plt.plot
296296
"""
297-
import matplotlib.pyplot as plt
297+
import matplotlib.pyplot as plt # noqa
298298

299299
def plot_single(single):
300300
axes.plot(*single.exterior.xy, **kwargs)

trimesh/points.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,9 @@ def plot_points(points, show=True):
374374
show : bool
375375
If False, will not show until plt.show() is called
376376
"""
377-
import matplotlib.pyplot as plt
378-
from mpl_toolkits.mplot3d import Axes3D # NOQA
377+
# TODO : should this just use SceneViewer?
378+
import matplotlib.pyplot as plt # noqa
379+
from mpl_toolkits.mplot3d import Axes3D # noqa
379380

380381
points = np.asanyarray(points, dtype=float64)
381382

trimesh/scene/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def show(self, **kwargs):
414414
kwargs : dict
415415
Passed to `networkx.draw_networkx`
416416
"""
417-
import matplotlib.pyplot as plt
417+
import matplotlib.pyplot as plt # noqa
418418
import networkx
419419

420420
# default kwargs will only be set if not

0 commit comments

Comments
 (0)