
Conversation


@Adityakushwaha2006 Adityakushwaha2006 commented Dec 26, 2025

Reference Issues/PRs

Fixes #1542
It is already compatible with the CuPy-based backend change introduced in #3211.

What does this implement/fix? Explain your changes.

This PR implements the GPU integration interface for RocketClassifier and RocketRegressor. It introduces a use_gpu parameter that, when set to True, switches the underlying transformer to the GPU implementation.

Key Implementation Details:

  • Zero-Config GPU: Users simply pass use_gpu=True.
  • Soft Dependency Handling: Uses _check_soft_dependencies("tensorflow") to ensure a clean error message if the user enables GPU mode without the backend installed.
  • Backend Decoupled: The implementation uses a dynamic import pattern. This ensures the Classifier/Regressor logic remains backend-agnostic (allowing for future swaps to CuPy/Torch without breaking the API).
  • Backward Compatibility: Default behavior (use_gpu=False) is strictly preserved and uses the standard CPU transformer.

Example Usage:

# Standard CPU usage (Unchanged)
clf = RocketClassifier(n_kernels=500)

# New GPU usage (Requires TensorFlow backend currently)
clf_gpu = RocketClassifier(n_kernels=500, use_gpu=True)
# clf_gpu.fit(X_train, y_train)

Does your contribution introduce a new dependency? If yes, which one?

No new mandatory dependencies. It adds tensorflow as an optional (soft) dependency; the code only attempts to import it when use_gpu=True is explicitly set.

Any other comments?

Context: I am picking this up following my check-in on Issue #1542 (assignee inactive since Aug 2024) to unblock the GPU integration roadmap requested by @hadifawaz1999.

Verification: Verified locally with the following matrix:

  • use_gpu=False (Default): Matches CPU baseline exactly.
  • use_gpu=True: Successfully initializes GPU transformer pipeline.
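One way to express the "matches CPU baseline" check from the matrix above is a small comparison helper. This is an illustrative sketch, not part of the PR; for a strict exact-match check, numpy.array_equal could be used instead of numpy.allclose:

```python
import numpy as np


def outputs_match(cpu_out, gpu_out, atol=1e-6):
    """Illustrative equivalence check between CPU and GPU transform outputs."""
    cpu_out = np.asarray(cpu_out)
    gpu_out = np.asarray(gpu_out)
    # Shapes must agree before comparing values element-wise.
    if cpu_out.shape != gpu_out.shape:
        return False
    return bool(np.allclose(cpu_out, gpu_out, atol=atol))
```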

Missing Dependency Check: Verified that setting use_gpu=True without TensorFlow installed raises a clear ImportError rather than an unhandled ModuleNotFoundError.

PR checklist

For all contributions
[x] I've added myself to the list of contributors. Alternatively, you can use the @all-contributors bot to do this for you after the PR has been merged.

[x] The PR title starts with either [ENH], [MNT], [DOC], [BUG], [REF], [DEP] or [GOV] indicating whether the PR topic is related to enhancement, maintenance, documentation, bugs, refactoring, deprecation or governance.

For new estimators and functions
[ ] I've added the estimator/function to the online API documentation.

[ ] (OPTIONAL) I've added myself as a maintainer at the top of relevant files and want to be contacted regarding its maintenance. Unmaintained files may be removed. This is for the full file, and you should not add yourself if you are just making minor changes or do not want to help maintain its contents.

For developers with write access
[ ] (OPTIONAL) I've updated aeon's CODEOWNERS to receive notifications about future changes to these files.

… handling

- Add use_gpu parameter to enable GPU acceleration
- Implement soft dependency checks for TensorFlow
- Add try-except with helpful error messages for missing dependencies
- Standardize parameter order across Classifier/Regressor
- Add type hints for API consistency
- Maintain 100% backward compatibility (default: CPU mode)

Code review improvements:
- Fixed critical import trap (ModuleNotFoundError -> helpful ImportError)
- Integrated aeon's _check_soft_dependencies validation system
- Aligned parameter order between Classifier and Regressor
- Added type hints to RocketRegressor for consistency

Resolves aeon-toolkit#1542
@aeon-actions-bot added the classification, enhancement, and regression labels on Dec 26, 2025.
@aeon-actions-bot

Thank you for contributing to aeon

I have added the following labels to this PR based on the title: [ enhancement ].
I have added the following labels to this PR based on the changes made: [ classification, regression ]. Feel free to change these if they do not properly represent the PR.

The Checks tab will show the status of our automated tests. You can click on individual test runs in the tab or "Details" in the panel below to see more information if there is a failure.

If our pre-commit code quality check fails, any trivial fixes will automatically be pushed to your PR unless it is a draft.

Don't hesitate to ask questions on the aeon Slack channel if you have any.

PR CI actions

These checkboxes will add labels to enable/disable CI functionality for this PR. This may not take effect immediately, and a new commit may be required to run the new configuration.

  • Run pre-commit checks for all files
  • Run mypy typecheck tests
  • Run all pytest tests and configurations
  • Run all notebook example tests
  • Run numba-disabled codecov tests
  • Stop automatic pre-commit fixes (always disabled for drafts)
  • Disable numba cache loading
  • Regenerate expected results for testing
  • Push an empty commit to re-run CI checks

Development

Successfully merging this pull request may close these issues.

[ENH] Adding Rocket GPU to rocket classifier/regressor
