Register Random OPs, eager mode, forward pass #1227
Conversation
| m.impl("bernoulli", TORCH_FN(tt_eager::ext::unary_random_seeded<ttnn::bernoulli>::invoke)); | ||
| // schema: bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) | ||
| m.impl("bernoulli.out", TORCH_FN(tt_eager::ext::unary_random_seeded<ttnn::bernoulli>::invoke_into)); | ||
| // bernoulli_.Tensor |
Can we separate the OPs that aren't implemented from the ones that are? (I think you did this in one of the other PRs.) It makes it clearer where the gaps are; see the sketch below.
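A minimal sketch of one way to do that, reusing the registrations from this diff. The PrivateUse1 dispatch key is an assumption; the repo may register under a different backend key.

// Sketch of the suggested split (dispatch key assumed, not repo code).
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
    // Implemented random OPs
    m.impl("bernoulli", TORCH_FN(tt_eager::ext::unary_random_seeded<ttnn::bernoulli>::invoke));
    m.impl("bernoulli.out", TORCH_FN(tt_eager::ext::unary_random_seeded<ttnn::bernoulli>::invoke_into));

    // Not yet implemented -- kept together and commented out so the gaps
    // are visible at a glance
    // m.impl("bernoulli_.Tensor", ...);
}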
#include <ttnn/operations/rand/rand.hpp>
#include <ttnn/operations/bernoulli/bernoulli.hpp>
#include <ttnn/operations/uniform/uniform.hpp>
#include <ttnn/operations/eltwise/unary/unary.hpp>
I noticed this in some of the other PRs too. Is there a reason we include unary.hpp in all of the eager wrapper files?
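One plausible reason, judging from the code further down: the wrappers call elementwise unary ops such as ttnn::floor (used in cast_after_sampling below), which live behind unary.hpp. A minimal sketch of that dependency; the helper name truncate_sample is illustrative, not repo code.

#include <ttnn/operations/eltwise/unary/unary.hpp>  // provides ttnn::floor and friends

// Illustrative helper (not repo code): the random wrappers lean on
// elementwise unary ops such as ttnn::floor when truncating sampled
// floats, which is the likely reason unary.hpp gets included.
static inline ttnn::Tensor truncate_sample(const ttnn::Tensor& sampled) {
    return ttnn::floor(sampled);
}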
[[nodiscard]] static at::Tensor& invoke_into(
    const at::Tensor& input, c10::optional<at::Generator> generator, at::Tensor& out) {
    ttnn::Tensor in_tile = tt_eager::ext::tileify(input);
    static thread_local std::mt19937 rng(std::random_device{}());
Is this the random number generator used by PyTorch / tt-metal?
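In the visible snippet it is a plain std::mt19937 seeded from std::random_device, which would ignore any at::Generator the caller passes in. One possible direction, sketched under assumptions: derive the seed from the Generator that the schema already provides, falling back to PyTorch's default CPU generator. at::Generator::current_seed() and at::detail::getDefaultCPUGenerator() are real ATen APIs; the helper name resolve_seed and the fallback behavior are hypothetical.

#include <ATen/core/Generator.h>
#include <ATen/CPUGeneratorImpl.h>  // at::detail::getDefaultCPUGenerator
#include <c10/util/Optional.h>

// Sketch (assumption, not repo code): seed the device RNG from the
// at::Generator that the schema already passes, so torch.manual_seed()
// is honored; otherwise fall back to PyTorch's default CPU generator.
static uint64_t resolve_seed(const c10::optional<at::Generator>& generator) {
    const at::Generator gen = generator.has_value()
        ? *generator
        : at::detail::getDefaultCPUGenerator();
    return gen.current_seed();
}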
    *device,
    ttnn::DataType::FLOAT32,
    layout,
    ttnn::DRAM_MEMORY_CONFIG,
DRAM is good for a first cut, but there are big perf implications of DRAM vs. SRAM (L1). We should explore how difficult it would be to support SRAM in the future (for all OPs, not just the random ones).
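For reference, the code change might be as small as swapping the memory-config constant; the fragment below mirrors the elided call from the diff above with only that argument changed. ttnn::L1_MEMORY_CONFIG is a real ttnn constant, but L1 capacity is limited, so this only holds for tensors that fit on-chip.

// Sketch: same call as in the diff, with only the memory config swapped.
// A size check with a DRAM fallback would likely be needed in practice.
    *device,
    ttnn::DataType::FLOAT32,
    layout,
    ttnn::L1_MEMORY_CONFIG,  // was ttnn::DRAM_MEMORY_CONFIG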
}
static inline ttnn::Tensor cast_after_sampling(const ttnn::Tensor& src, at::ScalarType st, bool is_int) {
    if (is_int) {
        auto floored = ttnn::floor(src);
nit: names like floored don't make it clear where the original value came from or what it represents
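A possible rename, sketched below. The branch bodies after the floor are cut off in the diff, so the returns here are placeholders, and the name sampled_truncated is only a suggestion.

// Sketch of the rename: tie the intermediate to its provenance.
static inline ttnn::Tensor cast_after_sampling(const ttnn::Tensor& src, at::ScalarType st, bool is_int) {
    if (is_int) {
        // Sampled floats truncated to whole values before the integer cast;
        // the cast to `st` itself is elided here, as it is in the diff.
        ttnn::Tensor sampled_truncated = ttnn::floor(src);
        return sampled_truncated;
    }
    return src;
}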
Registering Random TTNN OPs (PR on top of #1210)
Problem description
Existing TTNN OPs should be registered with the PyTorch dispatcher.
The scope of this PR is random OPs only; other groups of OPs will be added in follow-up PRs.
What's changed
Random wrappers are introduced and the random TTNN OPs are registered.
TODO: