Skip to content

Commit ef31e35

Browse files
committed
Add attribute builder accessors for fits
1 parent a756490 commit ef31e35

File tree

6 files changed

+70
-11
lines changed

6 files changed

+70
-11
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Unreleased
44

55
### Added
6+
- Fit accessors with Attribute
67

78
### Changed
89
- Upgrade tensorflow version to 1.0.0

examples/src/main/kotlin/space/kscience/kmath/fit/chiSquared.kt

+8-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ import space.kscience.kmath.expressions.autodiff
1414
import space.kscience.kmath.expressions.symbol
1515
import space.kscience.kmath.operations.asIterable
1616
import space.kscience.kmath.operations.toList
17-
import space.kscience.kmath.optimization.*
17+
import space.kscience.kmath.optimization.minimize
18+
import space.kscience.kmath.optimization.optimizeWith
19+
import space.kscience.kmath.optimization.result
20+
import space.kscience.kmath.optimization.resultValue
1821
import space.kscience.kmath.random.RandomGenerator
1922
import space.kscience.kmath.real.DoubleVector
2023
import space.kscience.kmath.real.map
@@ -79,9 +82,10 @@ suspend fun main() {
7982
val result = chi2.optimizeWith(
8083
CMOptimizer,
8184
mapOf(a to 1.5, b to 0.9, c to 1.0),
82-
) {
83-
FunctionOptimizationTarget(OptimizationDirection.MINIMIZE)
84-
}
85+
attributesBuilder = {
86+
minimize()
87+
}
88+
)
8589

8690
//display a page with plot and numerical results
8791
val page = Plotly.page {

examples/src/main/kotlin/space/kscience/kmath/fit/qowFit.kt

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package space.kscience.kmath.fit
77

88
import kotlinx.html.br
99
import kotlinx.html.h3
10-
import space.kscience.attributes.Attributes
1110
import space.kscience.kmath.data.XYErrorColumnarData
1211
import space.kscience.kmath.distributions.NormalDistribution
1312
import space.kscience.kmath.expressions.Symbol
@@ -65,7 +64,9 @@ suspend fun main() {
6564
QowOptimizer,
6665
Double.autodiff,
6766
mapOf(a to 0.9, b to 1.2, c to 2.0, e to 1.0, d to 1.0, e to 0.0),
68-
attributes = Attributes(OptimizationParameters, listOf(a, b, c, d))
67+
attributesBuilder = {
68+
freeParameters(a, b, c, d)
69+
},
6970
) { arg ->
7071
//bind variables to autodiff context
7172
val a by binding

kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/FunctionOptimization.kt

+13-4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ public enum class OptimizationDirection {
2222

2323
public object FunctionOptimizationTarget : OptimizationAttribute<OptimizationDirection>
2424

25+
public fun AttributesBuilder<FunctionOptimization<*>>.maximize() {
26+
FunctionOptimizationTarget(OptimizationDirection.MAXIMIZE)
27+
}
28+
29+
public fun AttributesBuilder<FunctionOptimization<*>>.minimize() {
30+
FunctionOptimizationTarget(OptimizationDirection.MINIMIZE)
31+
}
32+
33+
2534
public class FunctionOptimization<T>(
2635
public val expression: DifferentiableExpression<T>,
2736
override val attributes: Attributes,
@@ -74,11 +83,11 @@ public fun <T> FunctionOptimization<T>.withAttributes(
7483
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
7584
optimizer: Optimizer<T, FunctionOptimization<T>>,
7685
startingPoint: Map<Symbol, T>,
77-
modifier: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
86+
attributesBuilder: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
7887
): FunctionOptimization<T> {
7988
val problem = FunctionOptimization(this) {
8089
startAt(startingPoint)
81-
modifier()
90+
attributesBuilder()
8291
}
8392
return optimizer.optimize(problem)
8493
}
@@ -93,11 +102,11 @@ public val <T> FunctionOptimization<T>.resultValue: T
93102
public suspend fun <T> DifferentiableExpression<T>.optimizeWith(
94103
optimizer: Optimizer<T, FunctionOptimization<T>>,
95104
vararg startingPoint: Pair<Symbol, T>,
96-
builder: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
105+
attributesBuilder: AttributesBuilder<FunctionOptimization<T>>.() -> Unit = {},
97106
): FunctionOptimization<T> {
98107
val problem = FunctionOptimization<T>(this) {
99108
startAt(mapOf(*startingPoint))
100-
builder()
109+
attributesBuilder()
101110
}
102111
return optimizer.optimize(problem)
103112
}

kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/OptimizationProblem.kt

+7-1
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,15 @@ public object OptimizationLog : OptimizationAttribute<Loggable>
5858
*/
5959
public object OptimizationParameters : OptimizationAttribute<List<Symbol>>
6060

61+
public fun AttributesBuilder<OptimizationProblem<*>>.freeParameters(vararg symbols: Symbol) {
62+
OptimizationParameters(symbols.asList())
63+
}
64+
6165
/**
6266
* Maximum allowed number of iterations
6367
*/
6468
public object OptimizationIterations : OptimizationAttribute<Int>
6569

66-
70+
public fun AttributesBuilder<OptimizationProblem<*>>.iterations(iterations: Int) {
71+
OptimizationIterations(iterations)
72+
}

kmath-optimization/src/commonMain/kotlin/space/kscience/kmath/optimization/XYFit.kt

+38
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,24 @@ public suspend fun XYColumnarData<Double, Double, Double>.fitWith(
137137
return optimizer.optimize(problem)
138138
}
139139

140+
public suspend fun XYColumnarData<Double, Double, Double>.fitWith(
141+
optimizer: Optimizer<Double, XYFit>,
142+
modelExpression: DifferentiableExpression<Float64>,
143+
startingPoint: Map<Symbol, Double>,
144+
attributesBuilder: AttributesBuilder<XYFit>.() -> Unit,
145+
xSymbol: Symbol = Symbol.x,
146+
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
147+
pointWeight: PointWeight = PointWeight.byYSigma,
148+
): XYFit = fitWith(
149+
optimizer = optimizer,
150+
modelExpression = modelExpression,
151+
startingPoint = startingPoint,
152+
attributes = Attributes<XYFit>(attributesBuilder),
153+
xSymbol = xSymbol,
154+
pointToCurveDistance = pointToCurveDistance,
155+
pointWeight = pointWeight
156+
)
157+
140158
/**
141159
* Fit given data with a model provided as an expression
142160
*/
@@ -166,6 +184,26 @@ public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
166184
)
167185
}
168186

187+
public suspend fun <I : Any, A> XYColumnarData<Double, Double, Double>.fitWith(
188+
optimizer: Optimizer<Double, XYFit>,
189+
processor: AutoDiffProcessor<Double, I, A>,
190+
startingPoint: Map<Symbol, Double>,
191+
attributesBuilder: AttributesBuilder<XYFit>.() -> Unit,
192+
xSymbol: Symbol = Symbol.x,
193+
pointToCurveDistance: PointToCurveDistance = PointToCurveDistance.byY,
194+
pointWeight: PointWeight = PointWeight.byYSigma,
195+
model: A.(I) -> I,
196+
): XYFit where A : ExtendedField<I>, A : ExpressionAlgebra<Double, I> = fitWith(
197+
optimizer = optimizer,
198+
processor = processor,
199+
startingPoint = startingPoint,
200+
attributes = Attributes<XYFit>(attributesBuilder),
201+
xSymbol = xSymbol,
202+
pointToCurveDistance = pointToCurveDistance,
203+
pointWeight = pointWeight,
204+
model = model
205+
)
206+
169207
/**
170208
* Compute chi squared value for completed fit. Return null for incomplete fit
171209
*/

0 commit comments

Comments
 (0)