Skip to content

Commit 257a10b

Browse files
committed
Lower _conj_copy operation.
1 parent 998ea37 commit 257a10b

File tree

4 files changed

+13
-1
lines changed

4 files changed

+13
-1
lines changed

codegen/xla_native_functions.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ full_codegen:
2626
- clamp.Tensor
2727
- clamp_max.Tensor
2828
- clamp_min.Tensor
29+
- _conj_copy
2930
- cos
3031
- cosh
3132
- elu

torch_xla/csrc/ops/ops_lower_fn.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,11 @@ torch_xla::XlaOpVector ClampMinTensor::Lower(LoweringContext* loctx) const {
342342
return ReturnOp(xla::Max(xla_input, xla_other), loctx);
343343
}
344344

345+
torch_xla::XlaOpVector ConjCopy::Lower(LoweringContext* loctx) const {
346+
xla::XlaOp input = loctx->GetOutputOp(operand(0));
347+
return ReturnOp(xla::Conj(input), loctx);
348+
}
349+
345350
torch_xla::XlaOpVector Cos::Lower(LoweringContext* loctx) const {
346351
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
347352
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,10 @@ xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input,
424424
lower_for_shape_fn);
425425
}
426426

427+
xla::Shape ConjCopyOutputShape(const torch::lazy::Value& input) {
428+
return GetXlaShape(input);
429+
}
430+
427431
xla::Shape CosOutputShape(const torch::lazy::Value& input) {
428432
xla::Shape result_shape = GetXlaShape(input);
429433
if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {

torch_xla/csrc/ops/ops_xla_shape_fn.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ xla::Shape ClampMaxTensorOutputShape(const torch::lazy::Value& input,
108108
xla::Shape ClampMinTensorOutputShape(const torch::lazy::Value& input,
109109
const torch::lazy::Value& target);
110110

111+
xla::Shape ConjCopyOutputShape(const torch::lazy::Value& input);
112+
111113
xla::Shape CosOutputShape(const torch::lazy::Value& input);
112114

113115
xla::Shape CoshOutputShape(const torch::lazy::Value& input);
@@ -287,4 +289,4 @@ xla::Shape TruncOutputShape(const torch::lazy::Value& input);
287289

288290
} // namespace torch_xla
289291

290-
#endif // XLA_TORCH_XLA_CSRC_OPS_OPS_XLA_SHAPE_FN_H_
292+
#endif // XLA_TORCH_XLA_CSRC_OPS_OPS_XLA_SHAPE_FN_H_

0 commit comments

Comments
 (0)