What's the right way to construct custom ops with the same name but different output types? #23891

Closed
@yuanyao-nv

Description

I'm following the examples here to construct custom ops. In my particular use case I want to use an attribute to determine the output type of the custom op, similar to how the standard Cast op chooses its output type depending on the "to" attribute. However, if I register several custom ops with the same name but different output types, anything other than the first registered version will give a type inference error.

Here's a sketch of my code:

#include <iostream>
#include <memory>
#include "onnxruntime_lite_custom_op.h"

template <typename out_t>
struct KernelCast {
  int64_t output_dtype_attr = 1;  // ONNX data type code read from the attribute; 1 = FLOAT

  KernelCast(const OrtApi* ort_api, const OrtKernelInfo* info) {
    int64_t output_dtype = 1;
    OrtStatus* status = ort_api->KernelInfoGetAttribute_int64(info, "output_dtype", &output_dtype);
    if (status != nullptr) {
      // Attribute missing or unreadable: keep the default and release the status to avoid a leak.
      std::cout << "Read output_dtype attr error, using default value of 1" << std::endl;
      ort_api->ReleaseStatus(status);
    }
    output_dtype_attr = output_dtype;
  }

  void Compute(const Ort::Custom::Tensor<float>& x_,
               Ort::Custom::Tensor<out_t>& y_) {
    // Do something
  }
};

void RegisterOps(Ort::CustomOpDomain& domain) {
  // Two kernels registered under the same op name "CustomCast", differing only in output type.
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpCastFloat16{
      Ort::Custom::CreateLiteCustomOp<KernelCast<Ort::Float16_t>>("CustomCast", "CPUExecutionProvider")};
  domain.Add(c_CustomOpCastFloat16.get());
  static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpCastFloat{
      Ort::Custom::CreateLiteCustomOp<KernelCast<float>>("CustomCast", "CPUExecutionProvider")};
  domain.Add(c_CustomOpCastFloat.get());
}

And here's the error: onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (tensor(float)) of output arg (y) of node () does not match expected type (tensor(float16)).

What's the right way to let it dispatch to different output type kernels depending on the attribute? Any examples available to reference?
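The attribute-driven dispatch the question describes can be sketched independently of the ONNX Runtime API: one entry point reads an integer attribute and picks the output element type at runtime, rather than relying on several registrations under one name. This is a minimal illustration under stated assumptions, not ONNX Runtime's mechanism. It assumes `output_dtype` uses the same ONNX `TensorProto.DataType` codes as Cast's `to` attribute (1 = FLOAT, 6 = INT32, 11 = DOUBLE). `OnnxDType`, `CastTo`, `CastResult`, and `CustomCastCompute` are hypothetical names. FLOAT16 (10) is left out only because standard C++ has no half-precision type; in an ORT kernel it would be one more case using `Ort::Float16_t`, as in the code above.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <variant>
#include <vector>

// Assumed subset of ONNX TensorProto.DataType codes (the codes Cast's "to" uses).
enum class OnnxDType : int64_t { FLOAT = 1, INT32 = 6, DOUBLE = 11 };

// One possible output buffer per supported output element type.
using CastResult =
    std::variant<std::vector<float>, std::vector<int32_t>, std::vector<double>>;

// Type-specific work, instantiated once per output type at compile time.
template <typename OutT>
std::vector<OutT> CastTo(const std::vector<float>& x) {
  std::vector<OutT> y;
  y.reserve(x.size());
  for (float v : x) y.push_back(static_cast<OutT>(v));
  return y;
}

// Hypothetical single entry point: one op, output type chosen at runtime
// from the attribute value instead of from separate registrations.
CastResult CustomCastCompute(int64_t output_dtype, const std::vector<float>& x) {
  switch (static_cast<OnnxDType>(output_dtype)) {
    case OnnxDType::FLOAT:
      return CastTo<float>(x);
    case OnnxDType::INT32:
      return CastTo<int32_t>(x);
    case OnnxDType::DOUBLE:
      return CastTo<double>(x);
  }
  throw std::invalid_argument("unsupported output_dtype");
}
```

For example, `CustomCastCompute(6, {1.5f, -2.7f})` holds a `std::vector<int32_t>` with `{1, -2}`, because `static_cast` truncates toward zero. In an ONNX Runtime kernel the switch would sit inside a single kernel's `Compute`. A lite custom op appears to derive its declared output type from the `Compute` signature, so this may require a kernel whose output type is not fixed at compile time. Check that against the custom-op documentation for your ORT version.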