Refactor ProgramNode::call to flatten+call_flat+unflatten#16298
Conversation
|
The diff on |
Coverage Report for CI Build 26904388576Warning Build has drifted: This PR's base is out of sync with its target branch, so coverage data may include unrelated changes. Coverage decreased (-0.05%) to 87.482%Details
Uncovered Changes
Coverage Regressions576 previously-covered lines in 13 files lost coverage.
Coverage Stats
💛 - Coveralls |
…elpers Split the ProgramNode trait so implementations only define `call_flat(args: &[Tensor]) -> Result<Vec<Tensor>, _>` while a blanket ProgramNodeExt provides the DataTree-shaped `call(args: &DataTree<Tensor>)` on top. Adds the supporting DataTree machinery (`flatten_against`, `unflatten`, `map_leaves`, `into_leaves`, `TreeMatchError`, `ArityMismatch`) and structured call errors (`CallInputError`, `CallError`, `MissingCallError`). Store is rewritten on top of `call_flat` and stores its tensors as a flat `Vec<Tensor>` aligned with `output_types`.
|
One or more of the following people are relevant to this code:
|
mtreinish
left a comment
There was a problem hiding this comment.
I've made it through the data tree changes, but I'm out of time for today to review the rest of of the changes so I'll leave the comments I have so far. Besides the inline comments I'm wondering if we should split this PR up a bit to make it easier to review a lot of the data tree changes in isolation, I think you left a comment this already. I guess it's kind of moot now since I made it through most of the data tree changes already though so probably not worth changing at this point.
I think my biggest concern right now is the change to the iterators to basically collect everything into a vec up front and then iterate over that vec. That kind of eliminates the benefits of using an iterator. If we need that new form for some reason I'd say we should just return the Vec directly and let the user iterate over it. That'll be more explicit about how they work and also give the caller more control on how to use the result.
| iter_paths_inner(self, &mut Vec::new(), &mut out); | ||
| out.into_iter() |
There was a problem hiding this comment.
I'm not sure how I feel about this change and the other iterators. The previous method was actually an iterator, it was building the output as it was going through the data tree. This changes it so it's not really working as an iterator, it's completely iterating over the entire DataTree building a vec of the expected iterator output and then iterating over that vec as the output.
There was a problem hiding this comment.
I'm terribly sorry, but I moved this out of draft mode thinking that I had successfully pushed a change that reverts many unnecessary changes to this file, including this one. The version you have reviewed makes too many out-of-scope changes.
There was a problem hiding this comment.
Please excuse the force-push-after-review, but it just seemed cleaner as the correction commit would have been pretty messy. In the new diff, you can see that this PR now touches significantly less stuff. Except in one or two small spots, it only touches things it needs to.
c4f4b64 to
9e03221
Compare
| /// assert_eq!(leaves, expected); | ||
| /// ``` | ||
| pub fn iter_leaves(&self) -> IterLeaves<'_, T> { | ||
| pub fn iter_leaves(&self) -> impl Iterator<Item = &T> { |
There was a problem hiding this comment.
I can revert if you'd like this in a different PR, it's not directly related to this one, but I thought it would be nice to take the iterator classes out of the public API of this class so that we're free to update them as needed.
There was a problem hiding this comment.
This is fine, I don't think it matters a ton either way. The rust APIs are all private since we don't publish them so we're free to change this at any point. But this is probably easier for people to read when working in the rust code.
| ) -> Result<DataTree<Tensor>, CallError<Self::CallError>> { | ||
| let flat = self | ||
| .input_types() | ||
| .flatten_against(args) |
There was a problem hiding this comment.
Are you worried about the copy this requires for every call? It's not necessarily an issue to start I guess, but my concern is that we'll paint ourselves into a corner in the public interface if we're requiring a flattened input in python and c impl's call_flat() and we can't avoid a copy.
I just worry how large these data trees could get and whether we really need to be copying it all for call, especially as we're going to be passing a slice to call_flat here so the implementor doesn't need an owned copy necessarily.
There was a problem hiding this comment.
I should lead with two things that affect the calculus here quite a bit, but have not been merged in yet as they appear later in the PR stack:
QuantumProgramdoes not callcall(), it callscall_flat(), when evaluating the graph. Indeed, the main point of this refactor is to giveQuantumProgramthat ability. This means thatflatten_againstis only invoked once on entry, and so the relevant question becomes "how many inputs will aQuantumProgramhave?" instead of "how many times will a node be called"? My expectation is usually 0, and typically <100, but I can't predict the future.- The next PR switches from
TensortoArcTensor, so cloning the list doesn't copy the tensor data, it just increments some counter.
Now given that, we can still switch to call_flat accepting &[&Tensor] (and whatever that implies for flatten_against. Would you like to do that? It doesn't seem like too much work.
There was a problem hiding this comment.
I had overlooked the next PR in the chain with Arc so my concern about cloning here is mitigated quite a bit by that. We'll still be allocating a vec of n pointers for each call to call() but that's not terrible. So I'm fine with this given the follow on PRs.
mtreinish
left a comment
There was a problem hiding this comment.
This LGTM now thanks for all the updates on this.
| loop { | ||
| let top = self.stack.last_mut()?; | ||
| match top.next() { | ||
| None => { | ||
| self.stack.pop(); | ||
| } | ||
| Some(DataTree::Leaf(v)) => return Some(v), | ||
| Some(DataTree::Branch(b)) => self.stack.push(b.data.into_iter()), | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
I was thinking it'd be better to do this with a while loop, but testing this locally you either drop the ? from last_mut() and then add an explicit None return at the end of the function. Or you end up with while let top = self.stack.last_mut() which I assume you tried because clippy told me I should use a loop { ... } with a let inside instead. So I'm fine with this.
This PR split the
ProgramNodetrait so implementations only definecall_flat(args: &[Tensor]) -> Result<Vec<Tensor>, _>while a newProgramNodeExttrait provides the DataTree-I/Ocall(args: &DataTree<Tensor>)on top (which uses the blanket impl of ProgramNode trick to disallow specializations). This comes with three types of enum error classes: (CallInputError,CallError,MissingCallError).Since
Storeis the only implementation ofProgramNodethat's been merged into main so far, it's the only one that needs to change. However, you can (and should) especially check outQuantumProgramhigher in the PR stack to get a sense for howcall_flatis actually used:QuantumProgramuses it directly, and gets to totally avoid dealing with hash maps because it can reason about all arguments positionally.The above
ProgramNoderefactor is itself straight forward. However, this PR also adds/updates a bunch of supportingDataTreemachinery.PR Stack
AI/LLM disclosure