Add an API to change the datatype of an initialized module #627

@pranavm-nvidia

We should have an API for changing the dtype of a module after it has been initialized.
Without one, users have to thread a `dtype` argument through the constructors of the module and all of its submodules, which is cumbersome.

Possible APIs:

```py
module = MyModule()

# Option 1
module.to(tp.float16)

# Option 2
module.cast(tp.float16)

# Option 3: reassign the returned module
module = tp.cast(module, tp.float16)
```
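To make the trade-off between these options concrete, here is a minimal pure-Python sketch of how a recursive cast could work. It does not use tripy: `Module`, `Parameter`, `Linear`, the string dtype names, and the `cast` helper are all hypothetical stand-ins, and float16 rounding is emulated with a `struct` pack/unpack round trip.

```python
import copy
import struct
from dataclasses import dataclass
from typing import List

# Hypothetical dtype names standing in for tp.float16 / tp.float32;
# each maps to a `struct` format code of the matching width.
_STRUCT_CODES = {"float16": "e", "float32": "f", "float64": "d"}


def _round_to(values: List[float], dtype: str) -> List[float]:
    """Round each value to the precision of `dtype` via a pack/unpack round trip."""
    code = _STRUCT_CODES[dtype]
    return [struct.unpack(code, struct.pack(code, v))[0] for v in values]


@dataclass
class Parameter:
    """Hypothetical parameter: a flat list of values plus a dtype tag."""
    values: List[float]
    dtype: str = "float32"


class Module:
    """Toy module base class; NOT tripy's actual Module implementation."""

    def _children(self):
        # Treat every attribute that is itself a Module as a submodule.
        return [v for v in vars(self).values() if isinstance(v, Module)]

    def _parameters(self):
        return [v for v in vars(self).values() if isinstance(v, Parameter)]

    def to(self, dtype: str) -> "Module":
        """Cast all parameters in place, recursing into submodules."""
        for param in self._parameters():
            param.values = _round_to(param.values, dtype)
            param.dtype = dtype
        for child in self._children():
            child.to(dtype)
        # Returning self lets callers write either `m.to(d)` or `m = m.to(d)`.
        return self


def cast(module: Module, dtype: str) -> Module:
    """Functional variant: cast a deep copy and leave `module` untouched."""
    return copy.deepcopy(module).to(dtype)


# Example modules whose constructors take no dtype argument at all.
class Linear(Module):
    def __init__(self, n: int):
        self.weight = Parameter([0.1] * n)
        self.bias = Parameter([0.0] * n)


class MyModule(Module):
    def __init__(self):
        self.fc1 = Linear(2)
        self.fc2 = Linear(2)


module = MyModule()
module.to("float16")                 # in place, like Options 1 and 2

original = MyModule()
half = cast(original, "float16")     # returns a copy, like Option 3
print(half.fc2.weight.dtype, original.fc2.weight.dtype)  # float16 float32
```

In this sketch the in-place method corresponds to the `module.to(...)` / `module.cast(...)` proposals, while the copying `cast` mirrors `module = tp.cast(module, ...)`; the copy avoids surprising callers who still hold a reference to the original module, at the cost of duplicating its weights.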

Labels: tripy
