From bad8d35df795d3ab92e3e784c34dc83b2c6cc1e9 Mon Sep 17 00:00:00 2001 From: Christian Haack Date: Fri, 24 Jan 2025 18:54:44 +0100 Subject: [PATCH] refactor scattering functions to use Distributions.rand and update MockMediumProperties structure --- Manifest.toml | 2 +- src/scattering.jl | 23 +++++++++++++---------- test/runtests.jl | 31 +++++++++++++++---------------- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 8e52692..7f0fcc3 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.11.2" manifest_format = "2.0" -project_hash = "557cf6a38deb8fdca5c3514188b58ef73cf03842" +project_hash = "4663a998cc3e22049c19f32db65bbb7245db1898" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] diff --git a/src/scattering.jl b/src/scattering.jl index efe0615..455ff85 100644 --- a/src/scattering.jl +++ b/src/scattering.jl @@ -1,6 +1,5 @@ -import Distributions -using Distributions: Sampleable, Univariate, Continuous -using Random +using Distributions +using Random: AbstractRNG using Polynomials: fit, Polynomial, ImmutablePolynomial export AbstractScatteringFunction @@ -18,7 +17,7 @@ export WavelengthIndependentScatteringModel abstract type AbstractScatteringFunction <: Sampleable{Univariate, Continuous} end -rand(rng::AbstractRNG, s::AbstractScatteringFunction) = _not_implemented(s) +Distributions.rand(rng::AbstractRNG, s::AbstractScatteringFunction) = _not_implemented(s) """ @@ -48,7 +47,7 @@ function _hg_scattering_func(rng::AbstractRNG, g::T) where {T <: Real} return clamp(costheta, T(-1), T(1)) end -Base.rand(rng::AbstractRNG, s::HenyeyGreenStein) = _hg_scattering_func(rng, s.g) +Distributions.rand(rng::AbstractRNG, s::HenyeyGreenStein) = _hg_scattering_func(rng, s.g) """ SimplifiedLiu{T} @@ -74,7 +73,7 @@ function sl_scattering_func(rng::AbstractRNG, g::T) where {T <: Real} return clamp(costheta, T(-1), T(1)) end -Base.rand(rng::AbstractRNG, s::SimplifiedLiu) = sl_scattering_func(rng, s.g) +Distributions.rand(rng::AbstractRNG, s::SimplifiedLiu) = sl_scattering_func(rng, s.g) """ @@ -86,7 +85,7 @@ struct PolynomialScatteringFunction{T, P <: ImmutablePolynomial{T}} <: AbstractS poly::P end -function Base.rand(rng::AbstractRNG, s::PolynomialScatteringFunction{T}) where {T} +function Distributions.rand(rng::AbstractRNG, s::PolynomialScatteringFunction{T}) where {T} eta = Base.rand(rng, T) return clamp(s.poly(eta), T(-1), T(1)) end @@ -152,7 +151,7 @@ MixedHGES(g, b, fraction) = TwoComponentScatteringFunction(HenyeyGreenStein(g), MixedHGSL(g, fraction) = TwoComponentScatteringFunction(HenyeyGreenStein(g), SimplifiedLiu(g), fraction) -function Base.rand(rng::AbstractRNG, s::TwoComponentScatteringFunction) +function Distributions.rand(rng::AbstractRNG, s::TwoComponentScatteringFunction) choice = Base.rand(rng, Float64) if choice < s.fraction return Base.rand(rng, s.f1) @@ -194,7 +193,7 @@ get_scattering_function(model::AbstractScatteringModel) = _not_implemented(model function sample_scattering_function(model::AbstractScatteringModel) func = get_scattering_function(model) - return rand(func) + return Distributions.rand(func) end """ @@ -221,6 +220,8 @@ function scattering_length(model::KopelevichScatteringModel, wavelength::Real) ) end +get_scattering_function(model::KopelevichScatteringModel) = model.scattering_function + """ WavelengthIndependentScatteringModel{T, F<:AbstractScatteringFunction} <: AbstractScatteringModel @@ -234,4 +235,6 @@ end function scattering_length(model::WavelengthIndependentScatteringModel, wavelength::Real) return model.scattering_length -end \ No newline at end of file +end + +get_scattering_function(model::WavelengthIndependentScatteringModel) = model.scattering_function \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5b597b4..86a4c13 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using CherenkovMediumBase using CherenkovMediumBase: es_scattering, es_scattering_integral, es_scattering_cumulative using Test using Random +using StaticArrays Random.seed!(1234) @@ -34,28 +35,26 @@ struct MockMediumProperties <: MediumProperties pressure::Float64 dispersion_model::QuanFryDispersion scattering_model::KopelevichScatteringModel - MockMediumProperties(salinity, temperature, pressure) = new(salinity, temperature, pressure, QuanFryDispersion(salinity, temperature, pressure), KopelevichScatteringModel(HenyeyGreenStein(0.8), 0.1, 0.2)) + absoption_model::InterpolatedAbsorptionModel + function MockMediumProperties(salinity, temperature, pressure) + return new( + salinity, + temperature, + pressure, + QuanFryDispersion(salinity, temperature, pressure), + KopelevichScatteringModel(HenyeyGreenStein(0.8), 0.1, 0.2), + InterpolatedAbsorptionModel(SA[1., 2], SA[3., 4]) + ) + end end +CherenkovMediumBase.get_absorption_model(medium::MockMediumProperties) = medium.absoption_model +CherenkovMediumBase.get_scattering_model(medium::MockMediumProperties) = medium.scattering_model +CherenkovMediumBase.get_dispersion_model(medium::MockMediumProperties) = medium.dispersion_model CherenkovMediumBase.pressure(medium::MockMediumProperties) = medium.pressure CherenkovMediumBase.temperature(medium::MockMediumProperties) = medium.temperature CherenkovMediumBase.material_density(medium::MockMediumProperties) = 3. -CherenkovMediumBase.absorption_length(medium::MockMediumProperties, wavelength) = 6. -CherenkovMediumBase.sample_scattering_function(medium::MockMediumProperties) = rand(medium.scattering_model.scattering_function) - - -function CherenkovMediumBase.phase_refractive_index(medium::MockMediumProperties, wavelength) - return phase_refractive_index(medium.dispersion_model, wavelength) -end - -function CherenkovMediumBase.dispersion(medium::MockMediumProperties, wavelength) - return dispersion(medium.dispersion_model, wavelength) -end - -function CherenkovMediumBase.scattering_length(medium::MockMediumProperties, wavelength) - return scattering_length(medium.scattering_model, wavelength) -end @testset "CherenkovMediumBase.jl" begin