**Is your feature request related to a problem? Please describe.**
It would be good to remove the Megatron tensor parallelism code from NeoX; OSLO currently has support for this, with a slightly nicer interface.

**Describe the solution you'd like**
Steps:
- Rewrite all current modules as plain PyTorch implementations, removing the `mpu` dependency from any internal code as much as possible. So, anything that's currently an `mpu.[Column|Row]ParallelLinear` or `mpu.VocabParallelEmbedding` should be replaced with its plain PyTorch equivalent (`nn.Linear` / `nn.Embedding` respectively); see the sketch after this list.
- Write a mapping for NeoX modules, which OSLO uses to handle parallelization (an illustrative sketch also follows below).
- Ensure backwards compatibility (e.g. loading existing tensor-parallel checkpoints; a rough sketch for that is included below as well).
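
For the module rewrite, here is a minimal sketch of the kind of substitution meant above, using an MLP block as an example. The layer names and constructor arguments are illustrative, based on the usual Megatron-style signatures, not a verbatim copy of NeoX's code:

```python
import torch.nn as nn

# Before (Megatron-style tensor-parallel layers), roughly:
#   self.dense_h_to_4h = mpu.ColumnParallelLinear(hidden_size, 4 * hidden_size, gather_output=False)
#   self.dense_4h_to_h = mpu.RowParallelLinear(4 * hidden_size, hidden_size, input_is_parallel=True)
#   self.word_embeddings = mpu.VocabParallelEmbedding(vocab_size, hidden_size)

class MLP(nn.Module):
    """Plain-PyTorch equivalent of a Megatron-style MLP block (illustrative)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # ColumnParallelLinear -> nn.Linear over the full (unsharded) output dim
        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
        self.activation = nn.GELU()
        # RowParallelLinear -> nn.Linear over the full (unsharded) input dim
        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)

    def forward(self, hidden_states):
        return self.dense_4h_to_h(self.activation(self.dense_h_to_4h(hidden_states)))

# VocabParallelEmbedding -> nn.Embedding over the full vocabulary:
# word_embeddings = nn.Embedding(vocab_size, hidden_size)
```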
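
For the mapping step, the exact format OSLO expects isn't spelled out here, so the following is only a conceptual sketch of the information such a mapping needs to carry: for each NeoX submodule, which tensor-parallel style it previously had under `mpu`. The module names follow the usual NeoX/Megatron layer names, but treat both the names and the dict format as assumptions rather than OSLO's real API:

```python
from typing import Optional

# Hypothetical mapping from NeoX submodule names to the tensor-parallel style
# they previously had under mpu. OSLO's actual mapping interface may differ;
# this only illustrates the information the mapping has to encode.
NEOX_TENSOR_PARALLEL_MAPPING = {
    # Column-parallel: weight split along the output dimension.
    "attention.query_key_value": "column",
    "mlp.dense_h_to_4h": "column",
    # Row-parallel: weight split along the input dimension.
    "attention.dense": "row",
    "mlp.dense_4h_to_h": "row",
    # Vocab-parallel embedding: split along the vocabulary dimension.
    "word_embeddings": "vocab",
}


def parallel_style(param_name: str) -> Optional[str]:
    """Look up the tensor-parallel style for a (possibly nested) parameter name."""
    for suffix, style in NEOX_TENSOR_PARALLEL_MAPPING.items():
        if suffix in param_name:
            return style
    return None  # replicated / not tensor-parallel
```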
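
For backwards compatibility, one likely requirement is being able to load checkpoints saved with the old `mpu` sharding into the new plain modules. A rough sketch, assuming Megatron's usual partitioning conventions (column-parallel parameters sharded on the output dim, row-parallel weights on the input dim, vocab-parallel embeddings on the vocab dim); the file paths and key names in the usage comment are hypothetical:

```python
import torch

def merge_tp_shards(shard_state_dicts, mapping):
    """Concatenate per-rank mpu weight shards back into full, unsharded tensors.

    `mapping` is a dict like NEOX_TENSOR_PARALLEL_MAPPING above:
    submodule name fragment -> "column" | "row" | "vocab".
    """
    merged = {}
    for key in shard_state_dicts[0]:
        style = next((s for name, s in mapping.items() if name in key), None)
        shards = [sd[key] for sd in shard_state_dicts]
        if style == "column":
            # Weight ([out, in]) and bias ([out]) are both sharded on the output dim.
            merged[key] = torch.cat(shards, dim=0)
        elif style == "row" and key.endswith(".weight"):
            # Weight is sharded on the input dim; the row-parallel bias is not sharded.
            merged[key] = torch.cat(shards, dim=1)
        elif style == "vocab":
            merged[key] = torch.cat(shards, dim=0)
        else:
            merged[key] = shards[0]  # replicated parameters are identical across ranks
    return merged

# Usage (hypothetical checkpoint layout):
# shards = [torch.load(f"mp_rank_{i:02d}/model_states.pt")["module"] for i in range(tp_size)]
# model.load_state_dict(merge_tp_shards(shards, NEOX_TENSOR_PARALLEL_MAPPING))
```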