RFC Spline Generalization
Dial calculations require the calculation of weights based on the parameter value. This is currently being done almost exclusively using TSpline3 to interpolate between knot points. Going forward we are going to need different types of interpolation, particularly different types of splines, but also the ability to wrap other functions (e.g. oscillation weights).
This RFC covers an abstract base class (ABC) for weight calculations that can be used by dials.
-
Choose a generic name that captures what is done, not how it is done. Rationale: This is calculating a function weight. We could use Spline to stick with current T2K jargon, but that is a method of implementation, not what is being done. Other possible base names might be Calculator, Function, etc.
-
Have low overhead in the calculation (be fast). The ABC should be a very light wrapper around the calculations. Rationale: The dials use the weight calculations heavily, so they need to be as fast as possible.
-
Easy to use for one dimension, but support multiple dimensions. Rationale: The dials are currently only one-dimensional, but there are use cases where we need multiple dimensions. A simple example is calculating neutrino oscillation weights for some analysis.
-
Provide an optional interface to calculate the gradient (or derivative). Rationale: Derivatives are easy to calculate for most interpolation functions. The structure of the GUNDAM likelihood makes it fairly easy to calculate the full gradient if the dial derivatives can be calculated. Adding a gradient interface future-proofs the API.
-
Do not use a "switchyard" in the class to choose the type of spline to be used. Rationale: The calculation needs to be quick, and a switchyard will add a separate layer of branching.
-
Derived classes support different types of calculations (e.g. TSpline3, uniform splines, monotonic splines, other weighting functions). Rationale: Each derived class can use a different method, and a dial can choose the type of calculation during construction. This means that during calculation, the correct virtual method will be "directly" called.
-
Use a factory pattern to create the objects based on input that would be saved in a YAML file. That probably means that the spline class gets created by a static factory class that takes a string to pick the class of the created object. Rationale: GUNDAM relies on YAML for configuration, so a string interface will fit more directly with the GUNDAM way of doing things.
-
Simplify having a (mostly) unified implementation of spline calculations on CPU and GPU. Rationale and commentary: As GPU-based calculations become more common, we need to make it as easy as possible to use both GPU- and CPU-based calculations of the same fit. One of the issues is that a GPU can't easily implement a complex class like TSpline3, but a CUDA GPU spline implementation can be easily ported to the CPU using the same source code. This particular desideratum doesn't change the base class, but can be seen as a justification for moving to an ABC with derived instantiation classes. Some of the instantiation classes could be designed to easily work on both the CPU and GPU.
This supposes that the classes will be named Function (chosen as a very generic placeholder, but maybe not a good choice). The function signatures are chosen to follow the conventions in TSpline3 (e.g. the evaluation method is named Eval, not eval, or evaluate). An alternative is to follow the typical GUNDAM naming convention which is camelCase with a leading lower case letter (e.g. getDim instead of GetDim, and eval instead of Eval).
This class has the job of creating specific instances of derived Function objects.
Create a pointer to a new Function object with the ownership passed to the caller. The string should be descriptive and could be saved as a field in the dial definition in the YAML. This method should probably be static. While the current specification is variadic, more thought will be needed to figure out how to pass the necessary initialization information (overloading may be more appropriate). An example where more information is needed is passing the TGraph containing the knots for a spline, but different types of Function objects may need different types of information. For most spline- or interpolation-based objects, the creation could be done with Create("name", graph), where graph is a pointer to a TGraph object. A counterexample is the calculation of oscillation weights, where no extra information will be needed.
This defines the interface used to calculate the function values. Any method not implemented in the derived class should throw a std::domain_error("FunctionBase::Method -- not implemented"). Alternatively, the methods could be pure virtual, but that requires the derived class to implement all methods.
Calculate the function value for x. This matches the TSpline3 interface so it makes changing the existing code simpler. It's only valid for one dimensional functions and should throw a std::domain_error("Not a 1D function") when called for a multi-dimensional function.
Provided so we can calculate multidimensional functions. The x parameter would typically be an array, but is left as a pointer so that one-dimensional functions can also be called, by passing the address of a single value:
const double x = 3.0;
FunctionBase* obj = <created-somehow>;
double v = obj->Eval(&x);
Calculate the gradient (derivative) of the function.
Provided to match the TSpline3 implementation for one dimensional objects. It's only valid for one dimensional functions and should throw a std::domain_error("Not a 1D function") when called for a multi-dimensional function.
Return the dimensionality of the function (for now, usually 1). This is the number of input parameters.
Return the name of the function calculator.
TGraph *graph = <created-someplace>;
std::unique_ptr<Function> func = FunctionFactory::Create("TSpline3",graph);
double value = func->Eval(1.0);
TGraph *graph = <created-someplace>;
std::unique_ptr<Function> func = FunctionFactory::Create("monotonic",graph);
double value = func->Eval(1.0);
std::unique_ptr<Function> func = FunctionFactory::Create("NeutrinoOscillations");
double x[]{Enu,dm12,dm13,t12,t23,t13,dcp}; // illustrative only.
double value = func->Eval(x);