This can be implemented directly or with support from PyTorch Lightning. How should we go about it?