
Lower _conj_copy and alias operation. #8686

Merged: 7 commits, Mar 6, 2025
Conversation

@ysiraichi (Collaborator, PR author):

Fix: #3070

This PR adds a lowering for _conj_copy. This operation is called by torch.conj and was previously executed through the fallback path. With this PR, torch.conj and its decompositions no longer fall back.
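
For readers new to torch_xla, here is a rough sketch of what "adding a lowering" means in practice. This is an illustration only, not the PR's actual diff: the ConjCopy IR node is taken from the discussion below, and bridge::AtenFromXlaTensor is assumed from the usual torch_xla bridge layout.

```cpp
// Hypothetical sketch of a _conj_copy lowering (illustration, not the PR's code).
at::Tensor XLANativeFunctions::_conj_copy(const at::Tensor& self) {
  // Fetch the lazy XLA tensor backing this at::Tensor.
  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
  // Build an IR node for the conjugate instead of dispatching to the CPU fallback.
  torch::lazy::Value conj =
      torch::lazy::Value(torch_xla::MakeNode<ConjCopy>(self_tensor->GetIrValue()));
  // Wrap the new lazy tensor back into an at::Tensor.
  return bridge::AtenFromXlaTensor(self_tensor->CreateFrom(conj));
}
```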

@ysiraichi added the "lowering" label (ATen Operation lowering) on Feb 6, 2025
@ysiraichi marked this pull request as ready for review on February 6, 2025 16:38
@ysiraichi (Collaborator, PR author):

Update: I'm currently investigating this odd CI failure when functionalization is disabled.

  • It looks like I'm not able to get the XLATensorImpl instance of the input when cloning (inside ConjugateFallback.cpp)
    • i.e. there's a tensor whose device is XLA that doesn't hold an XLATensorImpl instance
  • Not sure why...
 Traceback (most recent call last):
  File "/__w/xla/xla/pytorch/xla/test/test_operations.py", line 2397, in test_conj_no_fallback
    self.assertEqual(actual, expected.cpu())
RuntimeError: torch_xla/csrc/aten_xla_bridge.cpp:110 : Check failed: xtensor 
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::bridge::GetXlaTensor(at::Tensor const&)
	torch_xla::XLANativeFunctions::clone(at::Tensor const&, std::optional<c10::MemoryFormat>)
	
	at::_ops::clone::call(at::Tensor const&, std::optional<c10::MemoryFormat>)
	
	
	at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
	
	
	at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
	at::native::to(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)
	
	at::_ops::to_dtype_layout::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)
	at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const
	...
*** End stack trace ***
Input tensor is not an XLA tensor: XLAComplexFloatType
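
For context, the failing check enforces that any tensor reported as being on an XLA device is actually backed by an XLATensorImpl. Below is a minimal sketch of that invariant; it is an assumption about the check's shape, not the verbatim aten_xla_bridge.cpp code.

```cpp
// Sketch of the invariant behind "Check failed: xtensor" (assumed, not verbatim).
torch_xla::XLATensorImpl* impl =
    dynamic_cast<torch_xla::XLATensorImpl*>(tensor.unsafeGetTensorImpl());
XLA_CHECK(impl != nullptr)
    << "Input tensor is not an XLA tensor: " << tensor.toString();
```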

@miladm requested review from pgmoka and lsy323 on February 11, 2025 19:02
@ysiraichi force-pushed the ysiraichi/lower-_conj_copy branch from 75f6b7d to 4488c47 on February 17, 2025 18:48
@ysiraichi (Collaborator, PR author):

Update: I have found the solution to the previous error. This PR should pass all tests now.

  • The alias() operation was not lowered
    • Its default implementation instantiates a new TensorImpl (not an XLATensorImpl) that shares the same storage
    • This explains the non-XLA tensor that reports an XLA device
  • The conjugate fallback wasn't working as expected
    • It calls clone() in order to materialize the conjugate
      • The default implementation of clone() calls copy_, which did materialize the conjugate
    • Our clone() is a lowered operation, which didn't materialize the conjugate
      • I had to do so explicitly by creating a new IR node, ConjCopy (see the sketch below)
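
For illustration, here is a minimal sketch of the lowered clone() with the conjugate materialized explicitly, pieced together from the diff fragment and review comments further down; the exact body in the PR may differ.

```cpp
// Sketch assembled from the diff context and review comments (not verbatim PR code).
XLATensorPtr clone(const XLATensorPtr& input, bool is_conj) {
  torch::lazy::Value ir = input->GetIrValue();
  if (is_conj) {
    // Materialize the conjugate explicitly: the lowered clone() no longer goes
    // through copy_, which is what materialized it on the fallback path.
    ir = torch::lazy::Value(torch_xla::MakeNode<ConjCopy>(ir));
  }
  return input->CreateFrom(ir);
}
```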

@ysiraichi force-pushed the ysiraichi/lower-_conj_copy branch from 4488c47 to 4c61272 on February 18, 2025 12:18
@ysiraichi requested a review from tengyifei on February 18, 2025 18:19
@ysiraichi (Collaborator, PR author):

Update: tests are all passing now.

@tengyifei @pgmoka @lsy323 Could you review it whenever you have some time?

@pgmoka (Collaborator) left a comment:

Apologies for the delay on this review. I mainly have questions around tests.

@@ -2384,6 +2384,19 @@ def test_cummax_0_sized_dimension(self):
 
         self.assertEqual(actual, expected)
 
+    def test_conj_no_fallback(self):
Collaborator:

It seems to me that this test is checking two behaviors:

  1. Is the lowered operation being used?
  2. Is the value of the lowered operation what is expected?

I think these two should be two separate tests for ease of debugging.

Collaborator (PR author):

What do you think of adding an error message instead?

Collaborator:

I think two different tests are preferable. It is easier to debug unit tests when each one covers a single behavior.

With #8725 filed as a follow-up, I think it also makes sense to keep these tests separate, since the operation-lowering behavior tests might be refactored in the future. Having this as a separate test would provide an easier example case to base the refactor on.

Collaborator (PR author):

I have two reasons to prefer a single test:

  1. Both tests would have almost the same code, except for the self.assertEqual(...) calls. Duplicating it would make it harder to maintain.
  2. Although this test also checks the actual output, it does so just to make sure the results are equal. The main thing being tested here is whether the lowerings are actually used. The tests for the actual operation can be found in test_ops.py.

Let me know if you still think I should split this test.

Collaborator:

> Both tests would have almost the same code, except for the self.assertEqual(...) calls. Duplicating it would make it harder to maintain.

I wouldn't be too worried about code duplication here: we are looking at a small amount of code, and the meat of it lives in a helper method that could easily be factored out.

> Although this test also checks the actual output, it does so just to make sure the results are equal. The main thing being tested here is whether the lowerings are actually used. The tests for the actual operation can be found in test_ops.py.

I think this is a good point, and a reasonable justification for not splitting the tests. Since we already have such a test, there is no reason to duplicate it.

Overall, given that justification, I think adding an error message here should be sufficient.

Collaborator:

I couldn't find a test file for this module. Should we create a bug for this?

Collaborator (PR author):

That sounds good. But, in the end, I think this would already get exercised by the Python operation tests above.

Collaborator:

My opinion here would be based on test coverage. Some of these operations seem complex enough to warrant a small unit test. Something like parametrized tests could be a clean way to add them, and would make additional tests easier to implement.

@ysiraichi changed the title from "Lower _conj_copy operation." to "Lower _conj_copy and alias operation." on Feb 19, 2025
@ysiraichi (Collaborator, PR author):

This TPU CI failure looks unrelated to me. Let me undo my changes and see if the issue persists.

@@ -1229,8 +1229,12 @@ XLATensorPtr clamp(const XLATensorPtr& input,
                      Clamp(input->GetIrValue(), min_max.min, min_max.max));
 }
 
-XLATensorPtr clone(const XLATensorPtr& input) {
-  XLATensorPtr cloned = input->CreateFrom(input->GetIrValue());
+XLATensorPtr clone(const XLATensorPtr& input, bool is_conj) {
Collaborator:

Why does the clone method implementation depend on conj? This seems like a layering violation.

Collaborator (PR author):

Because that's where the conjugate is supposed to be materialized. PyTorch's conjugation is implemented via a dispatch key, which makes it lazy: it's only reflected in the actual data on clone() calls.
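
For context, the lazy conjugate bit can be observed in plain eager ATen, independent of XLA; a small self-contained example:

```cpp
// Standalone ATen example: conj() only sets the conjugate bit; the data is
// materialized later (here via resolve_conj()).
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::randn({3}, at::kComplexFloat);
  at::Tensor y = x.conj();           // lazy: no data is touched, the conj bit is set
  std::cout << y.is_conj() << "\n";  // prints 1
  at::Tensor z = y.resolve_conj();   // materializes the conjugated values
  std::cout << z.is_conj() << "\n";  // prints 0
  return 0;
}
```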

Collaborator:

I see. Could we factor the ir = torch::lazy::Value(torch_xla::MakeNode<ConjCopy>(ir)); into a separate conj_copy function instead of adding an is_conj flag to clone()? I understand the need to match PyTorch semantics, but it would still be good to avoid mixing concerns.
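
For clarity, one possible shape of that suggested refactor (a sketch of the reviewer's idea, not code from this PR):

```cpp
// Sketch of the suggested split: keep clone() free of conjugation logic and add
// a dedicated conj_copy tensor method that wraps the IR value in a ConjCopy node.
XLATensorPtr conj_copy(const XLATensorPtr& input) {
  return input->CreateFrom(
      torch::lazy::Value(torch_xla::MakeNode<ConjCopy>(input->GetIrValue())));
}

XLATensorPtr clone(const XLATensorPtr& input) {
  return input->CreateFrom(input->GetIrValue());
}
```

The caller implementing the ATen clone override would then pick conj_copy when the input's conjugate bit is set, keeping the two concerns separate.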

@tengyifei (Collaborator):

I left some questions. Apologies if they don't make sense. I'm just trying to understand this PR.

@ysiraichi (Collaborator, PR author):

@tengyifei Thank you for your review. In fact, you saw lots of comments because I was trying to figure out whether the TPU CI failure was this PR's fault or not. I will ping you for another review once I have that figured out.

@ysiraichi force-pushed the ysiraichi/lower-_conj_copy branch from d1025b0 to 67c574d on February 25, 2025 18:31
@ysiraichi (Collaborator, PR author):

Ok. This flakiness issue was kind of frustrating. TPU CI seems to be okay with my changes now.
@tengyifei Could you take a look at this PR whenever you have some time?

@ysiraichi (Collaborator, PR author):

I will rebase these changes and wait for CI once more.

@ysiraichi (Collaborator, PR author):

I'm pretty sure this error is unrelated to this PR.

@tengyifei (Collaborator):

That's my bad: #8747

@ysiraichi (Collaborator, PR author):

No problem. But I think this only further highlights #8745.

@ysiraichi force-pushed the ysiraichi/lower-_conj_copy branch from 916112d to 0adc64f on February 27, 2025 12:22
@ysiraichi (Collaborator, PR author):

CI is green.
@tengyifei @pgmoka Could you review this PR whenever you have some time?

@pgmoka (Collaborator) left a comment:

LGTM minus final comments. Once addressed, feel free to @ me again or CC me for a final look. I would also be curious to get an LGTM from one of the other folks on the PR.

@ysiraichi (Collaborator, PR author) commented on Feb 28, 2025:

@tengyifei @pgmoka If you have no more change requests or observations, could you stamp this PR, please?

@pgmoka (Collaborator) left a comment:

LGTM. Please address the other comments before submitting.

@tengyifei (Collaborator) left a comment:

Sorry, we have some flaky tests again.

@pgmoka (Collaborator) left a comment:

Approved

@tengyifei merged commit 6c53a1e into master on Mar 6, 2025
21 of 23 checks passed
Labels: lowering (ATen Operation lowering)
Projects: none yet
Development: successfully merging this pull request may close the issue "lower complex number operations (view_as_real, view_as_complex, conj, abs)"
3 participants