Graph breaks in .forward() for PMECalculator #220

@hejamu

With #219, the whole PME calculation can now be wrapped in torch.compile(). This already yields a nice speedup for large systems:

PMECalculator, RTX 5080, fp32, bulk-water density (~100 atoms/nm^3)

| N | eager | jit.script | torch.compile | torch.compile (reduce-overhead) |
|---|---|---|---|---|
| 1000 | 2.08 ms | 1.13 ms | 1.33 ms | 0.89 ms |
| 4000 | 2.22 ms | 1.27 ms | 1.52 ms | 1.08 ms |
| 16000 | 2.52 ms | 1.65 ms | 1.51 ms | 1.07 ms |

There is, however, a large constant overhead caused by graph breaks, which reduce-overhead cannot fuse into one CUDA graph because CPU syncs are required.

The graph breaks are caused by the mesh_interpolator mutating its own state in mesh_interpolator.update().
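One way to locate graph breaks like these is to compile with fullgraph=True, which makes torch.compile raise at the first break instead of silently splitting the graph. Below is a minimal sketch of how the compile variants compared here could be set up. TinyCalculator and build_variants are hypothetical stand-ins for illustration, not part of the library, and the jit.script variant is left out:

```python
import torch


class TinyCalculator(torch.nn.Module):
    """Hypothetical stand-in for a calculator module (not the real PMECalculator)."""

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # Placeholder computation; a real calculator would evaluate the potential.
        return positions.square().sum()


def build_variants(module: torch.nn.Module) -> dict:
    """Wrap one module in several torch.compile configurations.

    torch.compile is lazy: nothing is traced or compiled until the
    returned wrapper is called for the first time.
    """
    return {
        "eager": module,
        "compile": torch.compile(module),
        # reduce-overhead uses CUDA graphs on GPU to cut launch overhead.
        "reduce-overhead": torch.compile(module, mode="reduce-overhead"),
        # fullgraph=True turns any graph break into an error on first call.
        "fullgraph-check": torch.compile(module, fullgraph=True),
    }


variants = build_variants(TinyCalculator())
```

Calling variants["fullgraph-check"] on real inputs should then either succeed, meaning the forward pass traces as a single graph, or fail with an error that points at the offending line.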

This can be fixed relatively easily by computing the required mesh parameters on the fly instead of writing them to self first. On the other hand, this is a slight refactor with possible API changes.
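To illustrate the idea, here is a rough sketch contrasting the two patterns. All names here (StatefulInterpolator, mesh_coordinates, the cell and ns_mesh arguments) are assumptions made for this example and do not reproduce the actual mesh_interpolator code:

```python
import torch


class StatefulInterpolator(torch.nn.Module):
    """Hypothetical pattern: update() caches mesh parameters on self."""

    def __init__(self) -> None:
        super().__init__()
        self.cell = torch.eye(3)
        self.ns_mesh = torch.ones(3, dtype=torch.long)

    def update(self, cell: torch.Tensor, ns_mesh: torch.Tensor) -> None:
        # Attribute writes during forward() are side effects that the
        # compiler has to track; this is the mutation described above.
        self.cell = cell
        self.ns_mesh = ns_mesh

    def mesh_coordinates(self, positions: torch.Tensor) -> torch.Tensor:
        inv_cell = torch.linalg.inv(self.cell)
        return (positions @ inv_cell) * self.ns_mesh.to(positions.dtype)


def mesh_coordinates(
    positions: torch.Tensor, cell: torch.Tensor, ns_mesh: torch.Tensor
) -> torch.Tensor:
    """Stateless variant: mesh parameters are derived from the arguments."""
    # Fractional coordinates, assuming the rows of `cell` are lattice vectors.
    inv_cell = torch.linalg.inv(cell)
    # Scale by the number of mesh points along each axis.
    return (positions @ inv_cell) * ns_mesh.to(positions.dtype)
```

Both variants return the same values. The difference is that the stateless function takes cell and ns_mesh as explicit arguments, which is where the possible API change comes from.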

I tested this in a quick-and-dirty way, and the numbers look worth the effort:

(Same config as table above)

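| Code | N | eager | compile (default) | reduce-overhead (CUDA Graph) |
|---|---|---|---|---|
| PR #219 | 1000 | 2.06 ms | 1.33 ms | 0.89 ms |
| PR #219 | 4000 | 2.20 ms | 1.51 ms | 1.06 ms |
| PR #219 | 16000 | 2.54 ms | 1.52 ms | 1.06 ms |
| no state mutation | 1000 | 1.82 ms | 0.56 ms | 0.16 ms |
| no state mutation | 4000 | 1.84 ms | 0.58 ms | 0.23 ms |
| no state mutation | 16000 | 1.84 ms | 0.57 ms | 0.19 ms |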

I can clean up my changes and open a PR if this is of interest to others.
