Skip to content

Commit 804ed17

Browse files
committed
Add docstrings, type hints, and modernize code
1 parent c016c28 commit 804ed17

29 files changed

Lines changed: 4301 additions & 1727 deletions

.python-version

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
>=3.10
1+
>=3.10, <3.15

histogram.py

Lines changed: 17 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def gtk_run(closure):
6969
Args:
7070
closure Ignored parameter
7171
"""
72-
tools.warning("GTK 3.0 is unavailable: %s" % (err,))
72+
tools.warning(f"GTK 3.0 is unavailable: {err}")
7373

7474
# ---------------------------------------------------------------------------- #
7575
# Data frame columns selection helper
@@ -91,7 +91,7 @@ def select(data, *only_columns):
9191
if len(only_columns) == 0:
9292
return data
9393
# Intelligent selection
94-
columns = list()
94+
columns = []
9595
for only_column in only_columns:
9696
only_column = only_column.lower()
9797
for column in data.columns:
@@ -136,7 +136,7 @@ def to_string(x):
136136
Converted data to string
137137
"""
138138
if type(x) is float:
139-
return "%e" % x
139+
return f"{x:e}"
140140
return str(x).strip()
141141

142142
def __init__(self, data, title="Display data"):
@@ -149,7 +149,7 @@ def __init__(self, data, title="Display data"):
149149
# Make and fill list store
150150
store = Gtk.ListStore(*([str] * (len(data.columns) + 1)))
151151
for row in data.itertuples():
152-
store.append(list(self.to_string(x) for x in row))
152+
store.append([self.to_string(x) for x in row])
153153
# Make the associated tree view
154154
view = Gtk.TreeView(store)
155155
columns = list(data.columns)
@@ -196,17 +196,15 @@ def __init__(self, path_results):
196196
# Ensure the directory exists
197197
if not path_results.exists():
198198
raise tools.UserException(
199-
"Result directory %r cannot be accessed or does not exist"
200-
% str(path_results)
199+
f"Result directory {str(path_results)!r} cannot be accessed or does not exist"
201200
)
202201
# Load configuration string
203202
path_config = path_results / "config"
204203
try:
205204
data_config = path_config.read_text().strip()
206205
except Exception as err:
207206
tools.warning(
208-
"Result directory %r: unable to read configuration (%s)"
209-
% (str(path_results), err)
207+
f"Result directory {str(path_results)!r}: unable to read configuration ({err})"
210208
)
211209
data_config = None
212210
# Load configuration json
@@ -216,8 +214,7 @@ def __init__(self, path_results):
216214
data_json = json.load(fd)
217215
except Exception as err:
218216
tools.warning(
219-
"Result directory %r: unable to read JSON configuration (%s)"
220-
% (str(path_results), err)
217+
f"Result directory {str(path_results)!r}: unable to read JSON configuration ({err})"
221218
)
222219
data_json = None
223220
# Load training data
@@ -229,8 +226,7 @@ def __init__(self, path_results):
229226
data_study.index.name = "Step number"
230227
except Exception as err:
231228
tools.warning(
232-
"Result directory %r: unable to read training data (%s)"
233-
% (str(path_results), err)
229+
f"Result directory {str(path_results)!r}: unable to read training data ({err})"
234230
)
235231
data_study = None
236232
# Load evaluation data
@@ -240,8 +236,7 @@ def __init__(self, path_results):
240236
data_eval.index.name = "Step number"
241237
except Exception as err:
242238
tools.warning(
243-
"Result directory %r: unable to read evaluation data (%s)"
244-
% (str(path_results), err)
239+
f"Result directory {str(path_results)!r}: unable to read evaluation data ({err})"
245240
)
246241
data_eval = None
247242
# Merge data frames
@@ -284,8 +279,7 @@ def display(self, *only_columns, name=None):
284279
display(
285280
self.get(*only_columns),
286281
title=(
287-
"Session data%s for %r"
288-
% (" (subset)" if len(only_columns) > 0 else "", self.name)
282+
"Session data{} for {!r}".format(" (subset)" if len(only_columns) > 0 else "", self.name)
289283
),
290284
)
291285
# Return self to enable chaining
@@ -343,7 +337,7 @@ def compute_epoch(self):
343337
}.get(dataset_name)
344338
if training_size is None:
345339
tools.warning(
346-
"Unknown dataset %r, cannot compute the epoch number" % dataset_name
340+
f"Unknown dataset {dataset_name!r}, cannot compute the epoch number"
347341
)
348342
return self
349343
self.data[column_name] = self.data["Training point count"] / training_size
@@ -465,17 +459,15 @@ def include(self, data, *cols, errs=None, lalp=1.0, ccnt=None):
465459
data = data.data
466460
elif not isinstance(data, pandas.DataFrame):
467461
raise RuntimeError(
468-
"Expected a Session or DataFrame for 'data', got a %r"
469-
% tools.fullqual(type(data))
462+
f"Expected a Session or DataFrame for 'data', got a {tools.fullqual(type(data))!r}"
470463
)
471464
# Get the x-axis values
472465
if self._idx is None:
473466
x = data.index.to_numpy()
474467
else:
475468
if self._idx not in data:
476469
raise RuntimeError(
477-
"No column named %r to use as index in the given session/dataframe"
478-
% (self._idx,)
470+
f"No column named {self._idx!r} to use as index in the given session/dataframe"
479471
)
480472
x = data[self._idx].to_numpy()
481473
# Select semantic: empty list = select all
@@ -541,17 +533,15 @@ def include_single(self, data, key, col, err=None, lalp=1.0, ccnt=None):
541533
data = data.data
542534
elif not isinstance(data, pandas.DataFrame):
543535
raise RuntimeError(
544-
"Expected a Session or DataFrame for 'data', got a %r"
545-
% tools.fullqual(type(data))
536+
f"Expected a Session or DataFrame for 'data', got a {tools.fullqual(type(data))!r}"
546537
)
547538
# Get the x-axis values
548539
if self._idx is None:
549540
x = data.index.to_numpy()
550541
else:
551542
if self._idx not in data:
552543
raise RuntimeError(
553-
"No column named %r to use as index in the given session/dataframe"
554-
% (self._idx,)
544+
f"No column named {self._idx!r} to use as index in the given session/dataframe"
555545
)
556546
x = data[self._idx].to_numpy()
557547
# Pick a new line style and color
@@ -641,15 +631,13 @@ def generator_sum(gen):
641631
if zlabel is not None:
642632
if self._tax is None:
643633
tools.warning(
644-
"No secondary y-axis found, but its label %r was provided"
645-
% (zlabel,)
634+
f"No secondary y-axis found, but its label {zlabel!r} was provided"
646635
)
647636
else:
648637
self._tax.set_ylabel(zlabel)
649638
elif self._tax is not None:
650639
tools.warning(
651-
"No label provided for the secondary y-axis; using label %r from the primary"
652-
% (ylabel,)
640+
f"No label provided for the secondary y-axis; using label {ylabel!r} from the primary"
653641
)
654642
self._tax.set_ylabel(ylabel)
655643
self._ax.set_xlim(left=xmin, right=xmax)

krum/aggregators/__init__.py

Lines changed: 95 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -10,51 +10,91 @@
1010
# @section DESCRIPTION
1111
#
1212
# Loading of the local modules.
13-
#
14-
# Each rule MUST support taking any named arguments, possibly ignoring them.
15-
# The parameters MUST all be passed as their keyword arguments.
16-
# The reserved argument names, and their interface, are the following:
17-
# · gradients: Non-empty list of gradients to aggregate
18-
# · f : Number of Byzantine gradients to support
19-
# · model : Model (duck-typing 'experiments.Model') with valid default dataset and loss set
20-
# The rule, given "valid" parameter(s), MUST NOT return a tensor that is a reference to any tensor given as parameter.
21-
#
22-
# Each rule MUST provide a "check" member function, taking the same arguments as the rule itself.
23-
# The "check" member function returns 'None' when the parameters are valid,
24-
# or an explanatory string when the parameters are not valid.
25-
# The check member function MUST NOT modify the given parameters.
26-
#
27-
# Once registered, the check member function will be available as member "check".
28-
# The raw function and a wrapped checking the input/output of the raw function
29-
# will respectively be available as members "unchecked" and "checked".
30-
# Which of these two functions is called by default depends whether debug mode is enabled.
3113
###
3214

15+
"""
16+
Gradient aggregation rules (GARs) for Byzantine-resilient distributed learning.
17+
18+
Each rule combines a keyword-only aggregation function with a validation
19+
function and optional metadata used by the training and experiment scripts.
20+
21+
Contract
22+
--------
23+
24+
Each aggregation rule MUST:
25+
26+
1. Accept its arguments by keyword only, tolerating (and possibly ignoring) any extra named arguments
27+
2. Accept the reserved parameter ``gradients`` (non-empty list of gradients)
28+
3. Accept the reserved parameter ``f`` (number of Byzantine gradients to tolerate)
29+
4. Accept the reserved parameter ``model`` (model with configured defaults)
30+
5. NOT return a tensor that aliases any input tensor
31+
32+
Each rule MUST provide a ``check`` function that validates parameters and
33+
returns ``None`` when valid, or a user-facing error message otherwise.
34+
35+
The module exposes three variants for each rule:
36+
37+
- ``rule``: The default version (checked in debug mode, unchecked in release)
38+
- ``rule.checked``: Always validates parameters
39+
- ``rule.unchecked``: Skips validation (faster in production)
40+
41+
Additional metadata available on each rule:
42+
43+
- ``rule.check``: The validation function
44+
- ``rule.upper_bound``: Theoretical bound on stddev/norm ratio (if available)
45+
- ``rule.influence``: Attack acceptance ratio (if available)
46+
"""
47+
3348
import pathlib
49+
from collections.abc import Callable
3450

35-
from krum import tools
51+
import tools
52+
import torch
3653

3754
# ---------------------------------------------------------------------------- #
3855
# Automated GAR loader
3956

4057

41-
def make_gar(unchecked, check, upper_bound=None, influence=None):
42-
"""GAR wrapper helper.
43-
Args:
44-
unchecked Associated function (see module description)
45-
check Parameter validity check function
46-
upper_bound Compute the theoretical upper bound on the ratio non-Byzantine standard deviation / norm to use this aggregation rule: (n, f, d) -> float
47-
influence Attack acceptation ratio function
48-
Returns:
49-
Wrapped GAR
58+
def make_gar(
59+
unchecked: Callable,
60+
check: Callable,
61+
upper_bound: Callable | None = None,
62+
influence: Callable | None = None,
63+
) -> Callable:
64+
"""
65+
Wrap an unchecked GAR with validation and metadata.
66+
67+
Parameters
68+
----------
69+
unchecked : callable
70+
Aggregation function implementing the rule without parameter checks.
71+
check : callable
72+
Validation function. It must return ``None`` when parameters are valid,
73+
or an error message otherwise.
74+
upper_bound : callable, optional
75+
Function computing the theoretical upper bound on the ratio between
76+
non-Byzantine standard deviation and gradient norm. The expected
77+
signature is ``(n, f, d) -> float``.
78+
influence : callable, optional
79+
Function computing the accepted Byzantine-gradient ratio for a given
80+
set of honest and attack gradients.
81+
82+
Returns
83+
-------
84+
callable
85+
Checked or unchecked GAR selected according to ``__debug__``. The
86+
returned callable is annotated with ``check``, ``checked``,
87+
``unchecked``, ``upper_bound``, and ``influence`` attributes.
5088
"""
5189

5290
# Closure wrapping the call with checks
5391
def checked(**kwargs):
5492
# Check parameter validity
5593
message = check(**kwargs)
5694
if message is not None:
57-
raise tools.UserException(f"Aggregation rule {name!r} cannot be used with the given parameters: {message}")
95+
raise tools.UserException(
96+
f"Aggregation rule {name!r} cannot be used with the given parameters: {message}"
97+
)
5898
# Aggregation (hard to assert return value, duck-typing is allowed...)
5999
return unchecked(**kwargs)
60100

@@ -70,26 +110,42 @@ def checked(**kwargs):
70110
return func
71111

72112

73-
def register(name, unchecked, check, upper_bound=None, influence=None):
74-
"""Simple registration-wrapper helper.
75-
Args:
76-
name GAR name
77-
unchecked Associated function (see module description)
78-
check Parameter validity check function
79-
upper_bound Compute the theoretical upper bound on the ratio non-Byzantine standard deviation / norm to use this aggregation rule: (n, f, d) -> float
80-
influence Attack acceptation ratio function
113+
def register(
114+
name: str,
115+
unchecked: Callable,
116+
check: Callable,
117+
upper_bound: Callable | None = None,
118+
influence: Callable | None = None,
119+
) -> None:
120+
"""
121+
Register a gradient aggregation rule.
122+
123+
Parameters
124+
----------
125+
name : str
126+
User-visible GAR name.
127+
unchecked : callable
128+
Aggregation function implementing the rule without parameter checks.
129+
check : callable
130+
Validation function associated with ``unchecked``.
131+
upper_bound : callable, optional
132+
Function computing the rule's theoretical upper bound.
133+
influence : callable, optional
134+
Function computing the accepted Byzantine-gradient ratio.
81135
"""
82136
global gars
83137
# Check if name already in use
84138
if name in gars:
85139
tools.warning(f"Unable to register {name!r} GAR: name already in use")
86140
return
87141
# Export the selected function with the associated name
88-
gars[name] = make_gar(unchecked, check, upper_bound=upper_bound, influence=influence)
142+
gars[name] = make_gar(
143+
unchecked, check, upper_bound=upper_bound, influence=influence
144+
)
89145

90146

91147
# Registered rules (mapping name -> aggregation rule)
92-
gars = dict()
148+
gars = {}
93149

94150
# Load all local modules
95151
with tools.Context("aggregators", None):

0 commit comments

Comments
 (0)