
RFC Spline Generalization

Clark McGrew edited this page Jan 25, 2023 · 1 revision

RFC: Generalized function calculation interface

Dial calculations require weights to be calculated from the parameter value. This is currently done almost exclusively using TSpline3 to interpolate between knot points. Going forward, we will need different types of interpolation, particularly different types of splines, as well as the ability to wrap other functions (e.g. oscillation weights).

This RFC covers an abstract base class (ABC) for weight calculations that can be used by dials.

Desiderata

  1. Choose a generic name that captures what is done, not how it is done. Rationale: This is calculating a function weight. We could use Spline to stick with current T2K jargon, but that describes a method of implementation, not what is being done. Other possible base names might be Calculator, Function, etc.

  2. Have low overhead in the calculation (be fast). The ABC should be a very light wrapper around the calculations. Rationale: The dials use the weight calculations heavily, so they need to be as fast as possible.

  3. Make the interface easy to use for one dimension, but support multiple dimensions. Rationale: The dials are currently only one-dimensional, but there are use cases where we need multiple dimensions. A simple example is calculating neutrino oscillation weights for an analysis.

  4. Provide an optional interface to calculate the gradient (or derivative). Rationale: Derivatives are easy to calculate for most interpolation functions, and the structure of the GUNDAM likelihood makes it fairly easy to calculate the full gradient if the dial derivatives are available. Adding a gradient interface future-proofs the API.

  5. Do not use a "switchyard" in the class to choose the type of spline to be used. Rationale: The calculation needs to be quick, and a switchyard would add an extra layer of branching to every evaluation.

  6. Derived classes support different types of calculations (e.g. TSpline3, uniform splines, monotonic splines, other weighting functions). Rationale: Each derived class can use a different method, and a dial can choose the type of calculation during construction. This means that during calculation, the correct virtual method will be "directly" called.

  7. Use a factory pattern to create the objects based on input that would be saved in a YAML file. That probably means the spline object is created by a factory class with a static method that takes a string to pick the class of the created object. Rationale: GUNDAM relies on YAML for configuration, so a string interface will fit more directly with the GUNDAM way of doing things.

  8. Simplify having a (mostly) unified implementation of spline calculations on CPU and GPU. Rationale and Commentary: As GPU-based calculations become more common, we need to make it as easy as possible to use both GPU- and CPU-based calculations for the same fit. One issue is that a GPU can't easily implement complex classes like TSpline3, but a spline implementation written for a CUDA GPU can easily be ported to the CPU using the same source code. This desideratum doesn't change the base class, but can be seen as a justification for moving to an ABC with derived instantiation classes. Some of the instantiation classes could be designed to work easily on both the CPU and GPU.
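A minimal sketch of how desiderata 5 and 6 fit together, assuming the hypothetical names `WeightBase`, `LinearWeight`, `ConstantWeight`, and `Dial` (none of these come from GUNDAM). The dial receives its calculator once, at construction, so each weight evaluation is a single virtual call with no branch on the spline type.

```cpp
#include <memory>
#include <utility>

// Hypothetical minimal interface; the real class name is still undecided.
class WeightBase {
public:
    virtual ~WeightBase() = default;
    virtual double Eval(double x) const = 0;
};

// Hypothetical calculator: weight = 1 + slope * x.
class LinearWeight : public WeightBase {
public:
    explicit LinearWeight(double slope) : fSlope(slope) {}
    double Eval(double x) const override { return 1.0 + fSlope * x; }
private:
    double fSlope;
};

// Hypothetical calculator that ignores the parameter value.
class ConstantWeight : public WeightBase {
public:
    explicit ConstantWeight(double weight) : fWeight(weight) {}
    double Eval(double) const override { return fWeight; }
private:
    double fWeight;
};

// The dial picks its calculator once; the hot path has no switch on the
// calculation type, only one virtual call.
class Dial {
public:
    explicit Dial(std::unique_ptr<WeightBase> calc) : fCalc(std::move(calc)) {}
    double Weight(double parameter) const { return fCalc->Eval(parameter); }
private:
    std::unique_ptr<WeightBase> fCalc;
};
```

Adding a new kind of calculation means adding a derived class; the `Dial` code does not change.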

Proposed class signatures

This assumes the classes will be named Function (chosen as a very generic placeholder, but maybe not a good choice). The method signatures follow the conventions in TSpline3 (e.g. the evaluation method is named Eval, not eval or evaluate). An alternative is to follow the typical GUNDAM naming convention, which is camelCase with a leading lowercase letter (e.g. getDim instead of GetDim, and eval instead of Eval).

Factory class: FunctionFactory

This class has the job of creating specific instances of derived Function objects.

public method: Function* FunctionFactory::Create(std::string name,...)

Creates a new Function object and returns a pointer to it, with ownership passed to the caller. The string should be descriptive and could be saved as a field in the dial definition in the YAML. This method should probably be static. While the current specification is variadic, more thought is needed on how to pass the necessary initialization information (overloading may be more appropriate). For example, a spline needs the TGraph containing its knots, but different types of Function objects may need different types of information. For most spline or interpolation based objects, the creation could be done with Create("name",graph), where graph is a pointer to a TGraph object. A counterexample is the calculation of oscillation weights, where no extra information is needed.
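A possible shape for the factory, sketched under several assumptions: the knots are passed as a hypothetical `Knots` struct rather than a `TGraph` (so the example builds without ROOT), `LinearInterpolation` is a hypothetical derived class, and an unknown name throws `std::invalid_argument`. The creators live in a string-keyed map, so the string lookup happens only when an object is created, never during evaluation.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a TGraph: strictly increasing knot positions x
// and matching values y (at least two knots assumed).
struct Knots {
    std::vector<double> x;
    std::vector<double> y;
};

// Minimal base class for this sketch.
class Function {
public:
    virtual ~Function() = default;
    virtual double Eval(double x) const = 0;
    virtual std::string GetName() const = 0;
};

// Hypothetical derived class: piecewise-linear interpolation that returns
// the end values outside the knot range.
class LinearInterpolation : public Function {
public:
    explicit LinearInterpolation(const Knots& knots) : fKnots(knots) {}
    double Eval(double x) const override {
        const std::vector<double>& xs = fKnots.x;
        const std::vector<double>& ys = fKnots.y;
        if (x <= xs.front()) return ys.front();
        if (x >= xs.back()) return ys.back();
        // Index of the first knot above x; x lies in [xs[i-1], xs[i]).
        const std::size_t i = static_cast<std::size_t>(
            std::upper_bound(xs.begin(), xs.end(), x) - xs.begin());
        const double t = (x - xs[i - 1]) / (xs[i] - xs[i - 1]);
        return ys[i - 1] + t * (ys[i] - ys[i - 1]);
    }
    std::string GetName() const override { return "linear"; }
private:
    Knots fKnots;
};

class FunctionFactory {
public:
    using Creator = std::function<Function*(const Knots&)>;

    // Returns a new object; ownership passes to the caller.
    static Function* Create(const std::string& name, const Knots& knots) {
        const auto& registry = Registry();
        const auto it = registry.find(name);
        if (it == registry.end())
            throw std::invalid_argument("FunctionFactory::Create -- unknown name: " + name);
        return it->second(knots);
    }

    // Lets a new derived class be added without editing the factory.
    static void Register(const std::string& name, Creator creator) {
        Registry()[name] = std::move(creator);
    }

private:
    static std::map<std::string, Creator>& Registry() {
        static std::map<std::string, Creator> registry{
            {"linear", [](const Knots& k) -> Function* { return new LinearInterpolation(k); }},
        };
        return registry;
    }
};
```

A map of creators only works if they share one signature; the differing initialization needs described above are one reason overloads of `Create` may end up being preferable.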

Abstract base class for function calculation: FunctionBase

This defines the interface used to calculate the function values. Any method not implemented in the derived class should throw a std::domain_error("FunctionBase::Method -- not implemented"). Alternatively, the methods could be pure virtual, but that requires the derived class to implement all methods.
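A sketch of the non-pure-virtual option, using a hypothetical `StepFunction` that implements only `Eval`. The base-class defaults throw `std::domain_error`, so a derived class compiles without implementing every method and the missing method is reported at run time.

```cpp
#include <stdexcept>

// Base class whose defaults throw instead of being pure virtual.
class FunctionBase {
public:
    virtual ~FunctionBase() = default;
    virtual double Eval(double) const {
        throw std::domain_error("FunctionBase::Eval -- not implemented");
    }
    virtual double Derivative(double) const {
        throw std::domain_error("FunctionBase::Derivative -- not implemented");
    }
};

// Hypothetical derived class: implements Eval but not Derivative.
class StepFunction : public FunctionBase {
public:
    explicit StepFunction(double threshold) : fThreshold(threshold) {}
    double Eval(double x) const override { return x < fThreshold ? 0.0 : 1.0; }
private:
    double fThreshold;
};
```

With pure virtual methods, `StepFunction` would not compile until it also defined `Derivative`; the trade-off is a compile-time error versus a run-time exception.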

public method: double FunctionBase::Eval(double x) const

Calculate the function value for x. This matches the TSpline3 interface so it makes changing the existing code simpler. It's only valid for one dimensional functions and should throw a std::domain_error("Not a 1D function") when called for a multi-dimensional function.

public method: double FunctionBase::Eval(const double* x) const

Provided so that multidimensional functions can be calculated. The x parameter would typically be an array, but it is declared as a pointer so that one-dimensional functions can also be called by passing the address of a single value.

const double x = 3.0;
FunctionBase* obj = <created-somehow>;  // placeholder for however the object is created
double v = obj->Eval(&x);
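One way to relate the two `Eval` overloads, sketched with hypothetical derived classes `Quadratic1D` and `Product2D`: the base class implements `Eval(double)` by forwarding to `Eval(const double*)` and throws `std::domain_error("Not a 1D function")` when `GetDim()` is not 1. This forwarding default is an assumption, not part of the proposal.

```cpp
#include <stdexcept>

class FunctionBase {
public:
    virtual ~FunctionBase() = default;
    virtual int GetDim() const = 0;
    virtual double Eval(const double* x) const = 0;
    // Assumed default: the scalar form forwards to the array form,
    // and is refused for multi-dimensional functions.
    virtual double Eval(double x) const {
        if (GetDim() != 1) throw std::domain_error("Not a 1D function");
        return Eval(&x);
    }
};

// Hypothetical 1D function: w(x) = 1 + x^2.
class Quadratic1D : public FunctionBase {
public:
    using FunctionBase::Eval;  // keep Eval(double) visible next to the override
    int GetDim() const override { return 1; }
    double Eval(const double* x) const override { return 1.0 + x[0] * x[0]; }
};

// Hypothetical 2D function: w(x0, x1) = x0 * x1.
class Product2D : public FunctionBase {
public:
    using FunctionBase::Eval;
    int GetDim() const override { return 2; }
    double Eval(const double* x) const override { return x[0] * x[1]; }
};
```

Two C++ details matter here. A derived class that overrides `Eval(const double*)` hides the inherited `Eval(double)` unless it adds `using FunctionBase::Eval;`. Also, a call such as `Eval(0)` is ambiguous, because the literal `0` converts equally well to `double` and to a null pointer; write `Eval(0.0)`.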

public method: void FunctionBase::Gradient(const double* x, double* grad) const

Calculate the gradient (derivative) of the function at x. The result is written to grad, which must have room for GetDim() values.
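A sketch of `Gradient` for a hypothetical two-dimensional function f(x0, x1) = x0^2 * x1, together with a central-difference check (`MaxGradientError` is a hypothetical helper, not part of the proposed interface). A check like this is a cheap way to validate the analytic derivatives a derived class provides.

```cpp
#include <cmath>
#include <vector>

class FunctionBase {
public:
    virtual ~FunctionBase() = default;
    virtual int GetDim() const = 0;
    virtual double Eval(const double* x) const = 0;
    // grad must have room for GetDim() values.
    virtual void Gradient(const double* x, double* grad) const = 0;
};

// Hypothetical example: f(x0, x1) = x0^2 * x1.
class Example2D : public FunctionBase {
public:
    int GetDim() const override { return 2; }
    double Eval(const double* x) const override { return x[0] * x[0] * x[1]; }
    void Gradient(const double* x, double* grad) const override {
        grad[0] = 2.0 * x[0] * x[1];  // df/dx0
        grad[1] = x[0] * x[0];        // df/dx1
    }
};

// Compares the analytic gradient with central differences and returns the
// largest absolute difference over all components.
double MaxGradientError(const FunctionBase& f, const double* x, double h = 1e-5) {
    const int n = f.GetDim();
    std::vector<double> analytic(n), point(x, x + n);
    f.Gradient(x, analytic.data());
    double worst = 0.0;
    for (int i = 0; i < n; ++i) {
        const double saved = point[i];
        point[i] = saved + h;
        const double up = f.Eval(point.data());
        point[i] = saved - h;
        const double down = f.Eval(point.data());
        point[i] = saved;
        const double numeric = (up - down) / (2.0 * h);
        worst = std::fmax(worst, std::fabs(numeric - analytic[i]));
    }
    return worst;
}
```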

public method: double FunctionBase::Derivative(double x) const

Provided to match the TSpline3 implementation for one dimensional objects. It's only valid for one dimensional functions and should throw a std::domain_error("Not a 1D function") when called for a multi-dimensional function.
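If `Gradient` is implemented, the base class could supply a default `Derivative` in terms of it. This is a sketch under that assumption, with a hypothetical `Cube` class for f(x) = x^3.

```cpp
#include <stdexcept>

class FunctionBase {
public:
    virtual ~FunctionBase() = default;
    virtual int GetDim() const = 0;
    virtual void Gradient(const double* x, double* grad) const = 0;
    // Assumed default: a 1D derivative is the single gradient component.
    virtual double Derivative(double x) const {
        if (GetDim() != 1) throw std::domain_error("Not a 1D function");
        double grad = 0.0;
        Gradient(&x, &grad);
        return grad;
    }
};

// Hypothetical 1D function f(x) = x^3, so f'(x) = 3 x^2.
class Cube : public FunctionBase {
public:
    int GetDim() const override { return 1; }
    void Gradient(const double* x, double* grad) const override {
        grad[0] = 3.0 * x[0] * x[0];
    }
};
```

A derived class with a cheaper direct formula can still override `Derivative`.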

public method: int FunctionBase::GetDim() const

Return the dimensionality of the function (for now, usually 1). This is the number of input parameters.

public method: std::string FunctionBase::GetName() const

Return the name of the function calculator.

Use case examples

Set up and call a TSpline3

TGraph *graph = <created-someplace>;
std::unique_ptr<Function> func(FunctionFactory::Create("TSpline3", graph));
double value = func->Eval(1.0);

Set up and call a "monotonic" spline

TGraph *graph = <created-someplace>;
std::unique_ptr<Function> func(FunctionFactory::Create("monotonic", graph));
double value = func->Eval(1.0);
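The RFC does not specify which monotonic scheme is meant. One common choice is monotone cubic Hermite interpolation with Fritsch-Carlson tangents, sketched here as a standalone hypothetical `MonotonicSpline` class that takes `std::vector` knots instead of a `TGraph`. Returning the end value outside the knot range is an assumption; in the proposed design this class would derive from the function base class and be created by `FunctionFactory::Create("monotonic", graph)`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Monotone cubic Hermite spline with Fritsch-Carlson tangent limiting.
// Knots must be strictly increasing in x; at least two are required.
class MonotonicSpline {
public:
    MonotonicSpline(std::vector<double> x, std::vector<double> y)
        : fX(std::move(x)), fY(std::move(y)), fM(fX.size()) {
        const std::size_t n = fX.size();
        if (n < 2 || fY.size() != n)
            throw std::invalid_argument("MonotonicSpline: need >= 2 matching knots");
        std::vector<double> d(n - 1);  // secant slope of each segment
        for (std::size_t k = 0; k + 1 < n; ++k)
            d[k] = (fY[k + 1] - fY[k]) / (fX[k + 1] - fX[k]);
        // Initial tangents: end secants; interior average, or zero at extrema.
        fM[0] = d[0];
        fM[n - 1] = d[n - 2];
        for (std::size_t k = 1; k + 1 < n; ++k)
            fM[k] = (d[k - 1] * d[k] > 0.0) ? 0.5 * (d[k - 1] + d[k]) : 0.0;
        // Limit tangents so that each segment stays monotonic.
        for (std::size_t k = 0; k + 1 < n; ++k) {
            if (d[k] == 0.0) { fM[k] = fM[k + 1] = 0.0; continue; }
            const double a = fM[k] / d[k];
            const double b = fM[k + 1] / d[k];
            const double s = a * a + b * b;
            if (s > 9.0) {
                const double tau = 3.0 / std::sqrt(s);
                fM[k] = tau * a * d[k];
                fM[k + 1] = tau * b * d[k];
            }
        }
    }

    double Eval(double x) const {
        if (x <= fX.front()) return fY.front();
        if (x >= fX.back()) return fY.back();
        const std::size_t k = static_cast<std::size_t>(
            std::upper_bound(fX.begin(), fX.end(), x) - fX.begin()) - 1;
        const double h = fX[k + 1] - fX[k];
        const double t = (x - fX[k]) / h;
        const double t2 = t * t, t3 = t2 * t;
        // Cubic Hermite basis functions on [0, 1].
        const double h00 = 2.0 * t3 - 3.0 * t2 + 1.0;
        const double h10 = t3 - 2.0 * t2 + t;
        const double h01 = -2.0 * t3 + 3.0 * t2;
        const double h11 = t3 - t2;
        return h00 * fY[k] + h10 * h * fM[k] + h01 * fY[k + 1] + h11 * h * fM[k + 1];
    }

private:
    std::vector<double> fX, fY, fM;  // knot positions, values, tangents
};
```

With knots (0, 0), (1, 0), (2, 1), (3, 1), a natural cubic spline through the same points dips below 0 on the first segment (to -0.125 at x = 0.5), while this version stays within [0, 1] and never decreases.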

Set up and call an oscillation weight calculation

std::unique_ptr<Function> func(FunctionFactory::Create("NeutrinoOscillations"));
double x[]{Enu,dm12,dm13,t12,t23,t13,dcp}; // illustrative only.
double value = func->Eval(x);
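To make the multi-parameter case concrete, here is a sketch of a much simpler calculation than the three-flavor example above: the two-flavor vacuum survival probability, P = 1 - sin^2(2 theta) sin^2(1.267 dm2 L / E), with dm2 in eV^2, L in km, and E in GeV. The class name `TwoFlavorSurvival`, the parameter ordering {E, dm2, theta}, and fixing the baseline at construction are all assumptions made for illustration.

```cpp
#include <cmath>

// Sketch only: two-flavor vacuum survival probability, not the full
// three-flavor calculation. Input x = {E [GeV], dm2 [eV^2], theta [rad]};
// the baseline L [km] is fixed at construction.
class TwoFlavorSurvival {
public:
    explicit TwoFlavorSurvival(double baselineKm) : fBaseline(baselineKm) {}
    int GetDim() const { return 3; }
    double Eval(const double* x) const {
        const double energy = x[0], dm2 = x[1], theta = x[2];
        const double s2t = std::sin(2.0 * theta);
        const double s = std::sin(1.267 * dm2 * fBaseline / energy);
        return 1.0 - s2t * s2t * s * s;
    }
private:
    double fBaseline;
};
```

Passing all parameters through one `const double*` keeps the call signature the same as in the spline case, which is what lets such a calculation sit behind the same base class.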