|
1 | 1 | import warnings |
2 | 2 | from abc import ABC, abstractmethod |
3 | 3 | from collections.abc import Collection, Iterable |
| 4 | +from dataclasses import dataclass |
4 | 5 | from datetime import datetime, timedelta |
5 | 6 | from enum import IntEnum |
6 | 7 | from itertools import chain |
|
24 | 25 | go = None |
25 | 26 |
|
26 | 27 | from .base import Base, Property |
| 28 | +from .dataassociator import Associator |
| 29 | +from .metricgenerator import MetricGenerator |
| 30 | +from .metricgenerator.manager import MultiManager |
27 | 31 | from .models.base import LinearModel, Model |
28 | 32 | from .types import detection |
29 | 33 | from .types.array import StateVector |
30 | 34 | from .types.groundtruth import GroundTruthPath |
31 | 35 | from .types.metric import SingleTimeMetric |
32 | 36 | from .types.state import State, StateMutableSequence |
| 37 | +from .types.track import Track |
33 | 38 | from .types.update import Update |
34 | 39 |
|
35 | 40 |
|
@@ -3579,3 +3584,125 @@ def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping |
3579 | 3584 |
|
3580 | 3585 | frame.data = data_ |
3581 | 3586 | frame.traces = traces_ |
| 3587 | + |
| 3588 | + |
| 3589 | +@dataclass |
| 3590 | +class RAG: |
| 3591 | + r"""Dataclass to store the cutoff values for Red-Amber-Green scoring. |
| 3592 | + Values are given as a distance to the metric target value. |
| 3593 | + Green is scored if :math:`x \leq` :attr:`GREEN`. |
| 3594 | + Amber is scored if :attr:`Green` :math:`< x \leq` :attr:`AMBER`. |
| 3595 | + Red is scored if :math:`x >` :attr:`AMBER`. |
| 3596 | + """ |
| 3597 | + GREEN: float |
| 3598 | + AMBER: float |
| 3599 | + |
| 3600 | + |
| 3601 | +class RAGPlotterly(Plotterly): |
| 3602 | + """ |
| 3603 | + Plotterly plotter to display tracks according to their performance in a given metric. |
| 3604 | + """ |
| 3605 | + colours = {0: "black", |
| 3606 | + 1: "red", |
| 3607 | + 2: "yellow", |
| 3608 | + 3: "green"} |
| 3609 | + |
| 3610 | + def __init__(self, metric_name: str, target_value: float, |
| 3611 | + rag_boundaries: RAG, *args, **kwargs): |
| 3612 | + super().__init__(*args, **kwargs) |
| 3613 | + self.metric_name = metric_name |
| 3614 | + self.target_value = target_value |
| 3615 | + self.rag_boundaries = rag_boundaries |
| 3616 | + |
| 3617 | + @staticmethod |
| 3618 | + def generate_metrics(tracks: set[Track], truths: set[GroundTruthPath], associator: Associator, |
| 3619 | + metric: MetricGenerator) -> dict: |
| 3620 | + """Method to produce a set of metrics between each track and truth pair |
| 3621 | +
|
| 3622 | + Parameters |
| 3623 | + ---------- |
| 3624 | + tracks : set[Track] |
| 3625 | + truths : set[GroundTruthPath] |
| 3626 | + associator : Associator |
| 3627 | + Associator used to narrow down track, truth pairs. |
| 3628 | + metric : MetricGenerator |
| 3629 | + The base metric type with which to produce metrics. |
| 3630 | +
|
| 3631 | + Returns |
| 3632 | + ------- |
| 3633 | + dict |
| 3634 | + Calculated metrics for track, truth pairs. |
| 3635 | + """ |
| 3636 | + associations, _ = associator.associated_and_unassociated_tracks(tracks, truths) |
| 3637 | + metric_generators = [] |
| 3638 | + for track in tracks: |
| 3639 | + for association in associations: |
| 3640 | + if track in association.objects: |
| 3641 | + truth = next(iter(association.objects - {track})) |
| 3642 | + metric_generators.append(metric(generator_name=(track.id, truth.id), |
| 3643 | + tracks_key=track.id, |
| 3644 | + truths_key=truth.id)) |
| 3645 | + metric_manager = MultiManager(metric_generators, associator) |
| 3646 | + metric_data = {sms.id: {sms} for sms in [*tracks, *truths]} |
| 3647 | + metric_manager.add_data(metric_data) |
| 3648 | + metrics = metric_manager.generate_metrics() |
| 3649 | + return metrics |
| 3650 | + |
| 3651 | + def get_rag_from_value(self, values: list[float]) -> int: |
| 3652 | + """Method to produce a Red-Amber-Green score based on a given set of metric values. |
| 3653 | +
|
| 3654 | + Parameters |
| 3655 | + ---------- |
| 3656 | + values : list[float] |
| 3657 | + A list of metric values for a given track |
| 3658 | +
|
| 3659 | + Returns |
| 3660 | + ------- |
| 3661 | + int |
| 3662 | + A score ranging from 1: Red to 3: Green. |
| 3663 | + """ |
| 3664 | + values = sorted(values, key=lambda x: abs(x-self.target_value)) |
| 3665 | + value = values[0] |
| 3666 | + if abs(value - self.target_value) < self.rag_boundaries.GREEN: |
| 3667 | + return 3 |
| 3668 | + if abs(value - self.target_value) < self.rag_boundaries.AMBER: |
| 3669 | + return 2 |
| 3670 | + return 1 |
| 3671 | + |
| 3672 | + def plot_tracks(self, tracks, mapping, metrics, colours=None): |
| 3673 | + if colours is None: |
| 3674 | + colours = self.colours |
| 3675 | + for track in tracks: |
| 3676 | + colour_dict = {} |
| 3677 | + metric_sets = [metric_set[self.metric_name].value |
| 3678 | + for metric_id, metric_set in metrics.items() |
| 3679 | + if track.id in metric_id] |
| 3680 | + if len(metric_sets) == 0: |
| 3681 | + colour_dict = {state.timestamp: 0 for state in track} |
| 3682 | + for metric_states in zip(*metric_sets): |
| 3683 | + values = [metric_state.value for metric_state in metric_states] |
| 3684 | + times = {metric_state.timestamp for metric_state in metric_states} |
| 3685 | + assert len(times) == 1 |
| 3686 | + time = next(iter(times)) |
| 3687 | + colour_dict[time] = self.get_rag_from_value(values) |
| 3688 | + |
| 3689 | + tracklet = Track(id=track.id) |
| 3690 | + tracklet_colour = None |
| 3691 | + for state in track.states: |
| 3692 | + tracklet.append(state) |
| 3693 | + state_colour = colour_dict[state.timestamp] |
| 3694 | + if tracklet_colour is None: |
| 3695 | + pass |
| 3696 | + elif tracklet_colour == state_colour: |
| 3697 | + pass |
| 3698 | + elif tracklet_colour != state_colour: |
| 3699 | + super().plot_tracks(tracklet, mapping, label=track.id, |
| 3700 | + marker=dict(color=colours[tracklet_colour])) |
| 3701 | + tracklet = Track(id=track.id) |
| 3702 | + tracklet.append(state) |
| 3703 | + |
| 3704 | + tracklet.append(state) |
| 3705 | + tracklet_colour = state_colour |
| 3706 | + if len(tracklet.states) > 0: |
| 3707 | + super().plot_tracks(tracklet, mapping, label=track.id, |
| 3708 | + marker=dict(color=colours[tracklet_colour])) |
0 commit comments