Skip to content

Commit e3885b4

Browse files
committed
Add RAGPlotter class
1 parent 0f311b4 commit e3885b4

File tree

1 file changed

+127
-0
lines changed

1 file changed

+127
-0
lines changed

stonesoup/plotter.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22
from abc import ABC, abstractmethod
33
from collections.abc import Collection, Iterable
4+
from dataclasses import dataclass
45
from datetime import datetime, timedelta
56
from enum import IntEnum
67
from itertools import chain
@@ -24,12 +25,16 @@
2425
go = None
2526

2627
from .base import Base, Property
28+
from .dataassociator import Associator
29+
from .metricgenerator import MetricGenerator
30+
from .metricgenerator.manager import MultiManager
2731
from .models.base import LinearModel, Model
2832
from .types import detection
2933
from .types.array import StateVector
3034
from .types.groundtruth import GroundTruthPath
3135
from .types.metric import SingleTimeMetric
3236
from .types.state import State, StateMutableSequence
37+
from .types.track import Track
3338
from .types.update import Update
3439

3540

@@ -3579,3 +3584,125 @@ def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping
35793584

35803585
frame.data = data_
35813586
frame.traces = traces_
3587+
3588+
3589+
@dataclass
3590+
class RAG:
3591+
r"""Dataclass to store the cutoff values for Red-Amber-Green scoring.
3592+
Values are given as a distance to the metric target value.
3593+
Green is scored if :math:`x \leq` :attr:`GREEN`.
3594+
Amber is scored if :attr:`Green` :math:`< x \leq` :attr:`AMBER`.
3595+
Red is scored if :math:`x >` :attr:`AMBER`.
3596+
"""
3597+
GREEN: float
3598+
AMBER: float
3599+
3600+
3601+
class RAGPlotterly(Plotterly):
3602+
"""
3603+
Plotterly plotter to display tracks according to their performance in a given metric.
3604+
"""
3605+
colours = {0: "black",
3606+
1: "red",
3607+
2: "yellow",
3608+
3: "green"}
3609+
3610+
def __init__(self, metric_name: str, target_value: float,
3611+
rag_boundaries: RAG, *args, **kwargs):
3612+
super().__init__(*args, **kwargs)
3613+
self.metric_name = metric_name
3614+
self.target_value = target_value
3615+
self.rag_boundaries = rag_boundaries
3616+
3617+
@staticmethod
3618+
def generate_metrics(tracks: set[Track], truths: set[GroundTruthPath], associator: Associator,
3619+
metric: MetricGenerator) -> dict:
3620+
"""Method to produce a set of metrics between each track and truth pair
3621+
3622+
Parameters
3623+
----------
3624+
tracks : set[Track]
3625+
truths : set[GroundTruthPath]
3626+
associator : Associator
3627+
Associator used to narrow down track, truth pairs.
3628+
metric : MetricGenerator
3629+
The base metric type with which to produce metrics.
3630+
3631+
Returns
3632+
-------
3633+
dict
3634+
Calculated metrics for track, truth pairs.
3635+
"""
3636+
associations, _ = associator.associated_and_unassociated_tracks(tracks, truths)
3637+
metric_generators = []
3638+
for track in tracks:
3639+
for association in associations:
3640+
if track in association.objects:
3641+
truth = next(iter(association.objects - {track}))
3642+
metric_generators.append(metric(generator_name=(track.id, truth.id),
3643+
tracks_key=track.id,
3644+
truths_key=truth.id))
3645+
metric_manager = MultiManager(metric_generators, associator)
3646+
metric_data = {sms.id: {sms} for sms in [*tracks, *truths]}
3647+
metric_manager.add_data(metric_data)
3648+
metrics = metric_manager.generate_metrics()
3649+
return metrics
3650+
3651+
def get_rag_from_value(self, values: list[float]) -> int:
3652+
"""Method to produce a Red-Amber-Green score based on a given set of metric values.
3653+
3654+
Parameters
3655+
----------
3656+
values : list[float]
3657+
A list of metric values for a given track
3658+
3659+
Returns
3660+
-------
3661+
int
3662+
A score ranging from 1: Red to 3: Green.
3663+
"""
3664+
values = sorted(values, key=lambda x: abs(x-self.target_value))
3665+
value = values[0]
3666+
if abs(value - self.target_value) < self.rag_boundaries.GREEN:
3667+
return 3
3668+
if abs(value - self.target_value) < self.rag_boundaries.AMBER:
3669+
return 2
3670+
return 1
3671+
3672+
def plot_tracks(self, tracks, mapping, metrics, colours=None):
3673+
if colours is None:
3674+
colours = self.colours
3675+
for track in tracks:
3676+
colour_dict = {}
3677+
metric_sets = [metric_set[self.metric_name].value
3678+
for metric_id, metric_set in metrics.items()
3679+
if track.id in metric_id]
3680+
if len(metric_sets) == 0:
3681+
colour_dict = {state.timestamp: 0 for state in track}
3682+
for metric_states in zip(*metric_sets):
3683+
values = [metric_state.value for metric_state in metric_states]
3684+
times = {metric_state.timestamp for metric_state in metric_states}
3685+
assert len(times) == 1
3686+
time = next(iter(times))
3687+
colour_dict[time] = self.get_rag_from_value(values)
3688+
3689+
tracklet = Track(id=track.id)
3690+
tracklet_colour = None
3691+
for state in track.states:
3692+
tracklet.append(state)
3693+
state_colour = colour_dict[state.timestamp]
3694+
if tracklet_colour is None:
3695+
pass
3696+
elif tracklet_colour == state_colour:
3697+
pass
3698+
elif tracklet_colour != state_colour:
3699+
super().plot_tracks(tracklet, mapping, label=track.id,
3700+
marker=dict(color=colours[tracklet_colour]))
3701+
tracklet = Track(id=track.id)
3702+
tracklet.append(state)
3703+
3704+
tracklet.append(state)
3705+
tracklet_colour = state_colour
3706+
if len(tracklet.states) > 0:
3707+
super().plot_tracks(tracklet, mapping, label=track.id,
3708+
marker=dict(color=colours[tracklet_colour]))

0 commit comments

Comments
 (0)