Skip to content

Drop torch.jit.script from SelectOutput and ConnectOutput#10701

Open
V-3604 wants to merge 1 commit into
pyg-team:masterfrom
V-3604:fix/jit-script-pool-outputs
Open

Drop torch.jit.script from SelectOutput and ConnectOutput#10701
V-3604 wants to merge 1 commit into
pyg-team:masterfrom
V-3604:fix/jit-script-pool-outputs

Conversation

@V-3604
Copy link
Copy Markdown

@V-3604 V-3604 commented May 19, 2026

Fixes #10697.

torch.jit.script is deprecated in PyTorch 2.12 and emits a DeprecationWarning whether its called as a function or used as a decorator. The module level calls:

SelectOutput = torch.jit.script(SelectOutput)
ConnectOutput = torch.jit.script(ConnectOutput)

in select/base.py and connect/base.py therefore raise the warning at import time, which pyproject.toml''s error:.*torch_geometric.* filter treats as a test error. There is a workaround in place (ignore:.*torch.jit.* further down in the filter list) but the deprecated call itself is still there.

This drops the scripting on both classes. They are now plain Python classes with type annotations and a custom __init__. @dataclass(init=False) is removed because it isn''t needed once we are not scripting the result (and the synthesized __repr__ actually trips up TorchScript when downstream modules that return these types are scripted).

Verified on torch==2.12.0:

  • import is clean (-W error::DeprecationWarning passes)
  • test/nn/pool/select/test_select_topk.py and test/nn/pool/connect/test_filter_edges.py both pass (the is_full_test() JIT branches are not exercised by default, but I separately confirmed torch.jit.script(FilterEdges()) still scripts correctly with the refactored ConnectOutput)
  • broader test/nn/pool/ suite: 28 passed, 16 skipped (unrelated optional deps)

The ignore:.*torch.jit.*:DeprecationWarning filter in pyproject.toml is left in place since the remaining torch.jit.script(module) calls in the is_full_test() paths still emit the same warning.

Screenshot:
Untitled

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

DeprecationWarning: torch.jit.script is deprecated. Please switch to torch.compile or torch.export.

1 participant