Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions src/accel/LocalOpticalTrackOffload.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
//------------------------------- -*- C++ -*- -------------------------------//
// Copyright Celeritas contributors: see top-level COPYRIGHT file for details
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file accel/LocalOpticalTrackOffload.cc
//---------------------------------------------------------------------------//
#include "LocalOpticalTrackOffload.hh"

#include <G4EventManager.hh>
#include <G4MTRunManager.hh>

#include "corecel/sys/ScopedProfiling.hh"
#include "geocel/GeantUtils.hh"
#include "celeritas/global/CoreParams.hh"
#include "celeritas/optical/CoreParams.hh"
#include "celeritas/optical/CoreState.hh"
#include "celeritas/optical/Transporter.hh"

#include "SetupOptions.hh"
#include "SharedParams.hh"

namespace celeritas
{
//---------------------------------------------------------------------------//
/*!
*
*/
LocalOpticalTrackOffload::LocalOpticalTrackOffload(SetupOptions const& options,
SharedParams& params)
{
CELER_VALIDATE(params.mode() == SharedParams::Mode::enabled,
<< "cannot create local optical track offload when "
"Celeritas "
"offloading is disabled");

// Check the thread ID and MT model
validate_geant_threading(params.Params()->max_streams());

// Save a pointer to the optical transporter
transport_ = params.optical_transporter();
CELER_ASSERT(transport_);

CELER_ASSERT(transport_->params());
auto const& optical_params = *transport_->params();

// Number of optical tracks to buffer before offloading
auto const& capacity = options.optical->capacity;
auto_flush_ = capacity.tracks;

auto stream_id = id_cast<StreamId>(get_geant_thread_id());

// Allocate thread-local state data
auto memspace = celeritas::device() ? MemSpace::device : MemSpace::host;
if (memspace == MemSpace::device)
{
state_ = std::make_shared<optical::CoreState<MemSpace::device>>(
optical_params, stream_id, capacity.tracks);
}
else
{
state_ = std::make_shared<optical::CoreState<MemSpace::host>>(
optical_params, stream_id, capacity.tracks);
}

// Allocate auxiliary data
if (params.Params()->aux_reg())
{
state_->aux() = std::make_shared<AuxStateVec>(
*params.Params()->aux_reg(), memspace, stream_id, capacity.tracks);
}

CELER_ENSURE(*this);
}

//---------------------------------------------------------------------------//
/*!
* Initialize with options and shared data.
*/
void LocalOpticalTrackOffload::Initialize(SetupOptions const& options,
SharedParams& params)
{
*this = LocalOpticalTrackOffload(options, params);
}

//---------------------------------------------------------------------------//
/*!
* Set the event ID and reseed the Celeritas RNG at the start of an event.
*/
void LocalOpticalTrackOffload::InitializeEvent(int id)
{
CELER_EXPECT(*this);
CELER_EXPECT(id >= 0);

event_id_ = id_cast<UniqueEventId>(id);

if (!(G4Threading::IsMultithreadedApplication()
&& G4MTRunManager::SeedOncePerCommunication()))
{
// Since Geant4 schedules events dynamically, reseed the Celeritas RNGs
// using the Geant4 event ID for reproducibility. This guarantees that
// an event can be reproduced given the event ID.
state_->reseed(transport_->params()->rng(), id_cast<UniqueEventId>(id));
}
}

//---------------------------------------------------------------------------//
/*!
* Buffer distribution data for generating optical photons.
*/
void LocalOpticalTrackOffload::Push(G4Track const& g4track)
{
CELER_EXPECT(*this);
TrackData init;

// Sanity check: this path is meant for optical photons
CELER_EXPECT(g4track.GetDefinition());
CELER_EXPECT(g4track.GetDefinition()->GetParticleName() == "opticalphoton");

// Energy: convert Geant4 energy [MeV] to Celeritas MevEnergy
init.energy = units::MevEnergy{g4track.GetTotalEnergy() / CLHEP::MeV};

// Position: Geant4 uses mm; Celeritas uses cm
auto const& pos = g4track.GetPosition();
init.position
= Real3{pos.x() / CLHEP::cm, pos.y() / CLHEP::cm, pos.z() / CLHEP::cm};

auto const& dir = g4track.GetMomentumDirection();
init.direction = Real3{dir.x(), dir.y(), dir.z()};

// Polarization: directly from G4
auto const& pol = g4track.GetPolarization();
init.polarization = Real3{pol.x(), pol.y(), pol.z()};

// Time: Geant4 uses ns; Celeritas uses seconds
init.time = g4track.GetGlobalTime() / CLHEP::s;

ScopedProfiling profile_this{"push"};

buffer_.push_back(init);
pending_tracks_++;

if (pending_tracks_ >= auto_flush_)
{
this->Flush();
}
}
//---------------------------------------------------------------------------//
/*!
* Generate and transport optical photons from the buffered distribution data.
*/
void LocalOpticalTrackOffload::Flush()
{
CELER_EXPECT(*this);

if (buffer_.empty())
{
return;
}

ScopedProfiling profile_this("flush");

//! \todo Duplicated in \c LocalTransporter
if (event_manager_ || !event_id_)
{
if (CELER_UNLIKELY(!event_manager_))
{
// Save the event manager pointer, thereby marking that
// *subsequent* events need to have their IDs checked as well
event_manager_ = G4EventManager::GetEventManager();
CELER_ASSERT(event_manager_);
}

G4Event const* event = event_manager_->GetConstCurrentEvent();
CELER_ASSERT(event);
if (event_id_ != id_cast<UniqueEventId>(event->GetEventID()))
{
// The event ID has changed: reseed it
this->InitializeEvent(event->GetEventID());
}
}
CELER_ASSERT(event_id_);

if (celeritas::device())
{
CELER_LOG_LOCAL(debug)
<< "Transporting " << pending_tracks_
<< " optical track from event " << event_id_.unchecked_get()
<< " with Celeritas";
}
// Inject buffered tracks into optical state

state_->insert_primaries(make_span(buffer_));

pending_tracks_ = 0;
buffer_.clear();

// Generate optical photons and transport to completion
(*transport_)(*state_);
}

//---------------------------------------------------------------------------//
/*!
* Get the accumulated action times.
*/
auto LocalOpticalTrackOffload::GetActionTime() const -> MapStrDbl
{
CELER_EXPECT(*this);
return transport_->get_action_times(*state_->aux());
}
//---------------------------------------------------------------------------//
/*!
* Clear local data.
*/
void LocalOpticalTrackOffload::Finalize()
{
CELER_EXPECT(*this);

CELER_VALIDATE(buffer_.empty(),
<< pending_tracks_ << " optical tracks were not flushed");
// Reset all data
*this = {};

CELER_ENSURE(!*this);
}

//---------------------------------------------------------------------------//
} // namespace celeritas
90 changes: 90 additions & 0 deletions src/accel/LocalOpticalTrackOffload.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
//------------------------------- -*- C++ -*- -------------------------------//
// Copyright Celeritas contributors: see top-level COPYRIGHT file for details
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file accel/LocalOpticalTrackOffload.hh
//---------------------------------------------------------------------------//
#pragma once

#include "corecel/Types.hh"
#include "celeritas/Types.hh"
#include "celeritas/optical/TrackInitializer.hh"
#include "celeritas/optical/Transporter.hh"
#include "accel/TrackOffloadInterface.hh"

#include "TrackOffloadInterface.hh"

class G4EventManager;

namespace celeritas
{
namespace optical
{
class CoreStateBase;
class Transporter;
} // namespace optical

struct SetupOptions;
class SharedParams;

//---------------------------------------------------------------------------//
/*!
* Brief class description.
*
* Optional detailed class description, and possibly example usage:
* \code
LocalOpticalTrackOffload ...;
\endcode
*/
class LocalOpticalTrackOffload final : public TrackOffloadInterface
{
public:
using TrackData = optical::TrackInitializer;
// Construct in an invalid state
LocalOpticalTrackOffload() = default;

// Construct with shared (across threads) params
LocalOpticalTrackOffload(SetupOptions const& options, SharedParams& params);

//!@{
//! \name Type aliases
void Initialize(SetupOptions const&, SharedParams&) final;

// Set the event ID and reseed the Celeritas RNG at the start of an event
void InitializeEvent(int) final;

// Transport all buffered tracks to completion
void Flush() final;

// Clear local data and return to an invalid state
void Finalize() final;

// Whether the class instance is initialized
bool Initialized() const final { return static_cast<bool>(state_); }
// Offload optical distribution data to Celeritas
void Push(G4Track const&) final;
// Number of buffered tracks
size_type GetBufferSize() const final { return pending_tracks_; }

// Get accumulated action times
MapStrDbl GetActionTime() const final;
//!@}

private:
// Transport pending optical tracks
std::shared_ptr<optical::Transporter> transport_;
// Thread-local state data
std::shared_ptr<optical::CoreStateBase> state_;

std::vector<TrackData> buffer_;
size_type pending_tracks_{};
// Number of photons to buffer before offloading
size_type auto_flush_{};

// Current event ID or manager for obtaining it
UniqueEventId event_id_;
G4EventManager* event_manager_{nullptr};
};

//---------------------------------------------------------------------------//
} // namespace celeritas
6 changes: 3 additions & 3 deletions src/accel/LocalTransporter.hh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "celeritas/Types.hh"
#include "celeritas/phys/Primary.hh"

#include "LocalOffloadInterface.hh"
#include "TrackOffloadInterface.hh"

class G4EventManager;
class G4Track;
Expand Down Expand Up @@ -52,7 +52,7 @@ class StepperInterface;
*
* \todo Rename \c LocalOffload or something?
*/
class LocalTransporter final : public LocalOffloadInterface
class LocalTransporter final : public TrackOffloadInterface
{
public:
// Construct in an invalid state
Expand Down Expand Up @@ -88,7 +88,7 @@ class LocalTransporter final : public LocalOffloadInterface
//!@}

// Offload this track
void Push(G4Track&);
void Push(G4Track const&) override;

// Set the event ID and reseed the Celeritas RNG (remove in v0.6)
[[deprecated]] void SetEventId(int id) { this->InitializeEvent(id); }
Expand Down
3 changes: 2 additions & 1 deletion src/accel/SetupOptions.hh
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ struct OpticalSetupOptions
inp::OpticalGenerator generator;
//! Limit on number of optical step iterations before aborting
size_type max_step_iters{inp::TrackingLimits::unlimited};

bool offload_tracks{false};
};

//---------------------------------------------------------------------------//
Expand Down Expand Up @@ -191,7 +193,6 @@ struct SetupOptions

//!@{
//! \name Optical photon options

std::optional<OpticalSetupOptions> optical;
//!@}

Expand Down
35 changes: 35 additions & 0 deletions src/accel/TrackOffloadInterface.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//------------------------------- -*- C++ -*- -------------------------------//
// Copyright Celeritas contributors: see top-level COPYRIGHT file for details
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
//---------------------------------------------------------------------------//
//! \file accel/TrackOffloadInterface.hh
//---------------------------------------------------------------------------//
#pragma once

#include "LocalOffloadInterface.hh"

class G4Track;

namespace celeritas
{
//---------------------------------------------------------------------------//
/*!
* Brief class description.
*
* Optional detailed class description, and possibly example usage:
* \code
TrackOffloadInterface ...;
\endcode
*/
class TrackOffloadInterface : public LocalOffloadInterface
{
public:
// Construct with defaults
virtual ~TrackOffloadInterface() = default;

// Push a full Geant4 track to Celeritas
virtual void Push(G4Track const&) = 0;
};

//---------------------------------------------------------------------------//
} // namespace celeritas
Loading
Loading