Skip to content

Commit 899cec0

Browse files
Merge pull request #20 from utkarsh530/solverstats
MATLAB solver stats
2 parents 43579fd + 6e63c4b commit 899cec0

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "MATLABDiffEq"
22
uuid = "e2752cbe-bcf4-5895-8727-84ebc14a76bd"
3-
version = "0.3.1"
3+
version = "0.3.2"
44

55
[deps]
66
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/MATLABDiffEq.jl

+20-3
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ function DiffEqBase.__solve(
6262

6363
eval_string("options = odeset('RelTol',reltol,'AbsTol',abstol);")
6464
algstr = string(typeof(alg).name.name)
65-
#algstr = replace(string(typeof(alg)),"MATLABDiffEq.","")
66-
eval_string("[t,u] = $(algstr)(diffeqf,tspan,u0,options);")
65+
eval_string("mxsol = $(algstr)(diffeqf,tspan,u0,options);")
66+
eval_string("mxsolstats = struct(mxsol.stats);")
67+
solstats = get_variable(:mxsolstats)
68+
eval_string("t = mxsol.x;")
6769
ts = jvector(get_mvariable(:t))
70+
eval_string("u = mxsol.y';")
6871
timeseries_tmp = jarray(get_mvariable(:u))
6972

7073
# Reshape the result if needed
@@ -77,8 +80,22 @@ function DiffEqBase.__solve(
7780
timeseries = timeseries_tmp
7881
end
7982

83+
destats = buildDEStats(solstats)
84+
8085
DiffEqBase.build_solution(prob,alg,ts,timeseries,
81-
timeseries_errors = timeseries_errors)
86+
timeseries_errors = timeseries_errors,destats = destats)
87+
end
88+
89+
function buildDEStats(solverstats::Dict)
90+
91+
destats = DiffEqBase.DEStats(0)
92+
destats.nf = if (haskey(solverstats, "nfevals")) solverstats["nfevals"] else 0 end
93+
destats.nreject = if (haskey(solverstats, "nfailed")) solverstats["nfailed"] else 0 end
94+
destats.naccept = if (haskey(solverstats, "nsteps")) solverstats["nsteps"] else 0 end
95+
destats.nsolve = if (haskey(solverstats, "nsolves")) solverstats["nsolves"] else 0 end
96+
destats.njacs = if (haskey(solverstats, "npds")) solverstats["npds"] else 0 end
97+
destats.nw = if (haskey(solverstats, "ndecomps")) solverstats["ndecomps"] else 0 end
98+
destats
8299
end
83100

84101
end # module

0 commit comments

Comments
 (0)