Skip to content

Commit 90457ca

Browse files
authored
Merge pull request #171 from microsoft/typo-fixes
Collection of small typo fixes
2 parents 2354d8c + cb653e9 commit 90457ca

File tree

8 files changed

+19
-15
lines changed

8 files changed

+19
-15
lines changed

aurora/model/aurora.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def __init__(
124124
drop_rate (float, optional): Drop-out rate.
125125
drop_path (float, optional): Drop-path rate.
126126
enc_depth (int, optional): Number of Perceiver blocks in the encoder.
127-
dec_depth (int, optioanl): Number of Perceiver blocks in the decoder.
127+
dec_depth (int, optional): Number of Perceiver blocks in the decoder.
128128
dec_mlp_ratio (float, optional): Hidden dim. to embedding dim. ratio for MLPs in the
129129
decoder. The embedding dimensionality here is different, which is why this is a
130130
separate parameter.
@@ -266,7 +266,7 @@ def forward(self, batch: Batch) -> Batch:
266266
"""Forward pass.
267267
268268
Args:
269-
batch (:class:`Batch`): Batch to run the model on.
269+
batch (:class:`aurora.Batch`): Batch to run the model on.
270270
271271
Returns:
272272
:class:`Batch`: Prediction for the batch.
@@ -472,7 +472,7 @@ def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor])
472472
473473
If a checkpoint was trained with a larger `max_history_size` than the current model,
474474
this function will assert fail to prevent loading the checkpoint. This is to
475-
prevent loading a checkpoint which will likely cause the checkpoint to degrade is
475+
prevent loading a checkpoint which will likely cause the checkpoint to degrade its
476476
performance.
477477
478478
This implementation copies weights from the checkpoint to the model and fills zeros

aurora/model/decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def forward(
176176
177177
Args:
178178
x (torch.Tensor): Backbone output of shape `(B, L, D)`.
179-
batch (:class:`aurora.batch.Batch`): Batch to make predictions for.
179+
batch (:class:`aurora.Batch`): Batch to make predictions for.
180180
patch_res (tuple[int, int, int]): Patch resolution
181181
lead_time (timedelta): Lead time.
182182

aurora/model/encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def forward(self, batch: Batch, lead_time: timedelta) -> torch.Tensor:
199199
"""Perform encoding.
200200
201201
Args:
202-
batch (:class:`.Batch`): Batch to encode.
202+
batch (:class:`aurora.Batch`): Batch to encode.
203203
lead_time (timedelta): Lead time.
204204
205205
Returns:

aurora/model/lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(
2121
r: int = 4,
2222
alpha: int = 1,
2323
dropout: float = 0.0,
24-
):
24+
) -> None:
2525
"""Initialise.
2626
2727
Args:
@@ -75,7 +75,7 @@ def __init__(
7575
dropout: float = 0.0,
7676
max_steps: int = 40,
7777
mode: LoRAMode = "single",
78-
):
78+
) -> None:
7979
"""Initialise.
8080
8181
Args:

aurora/model/swin3d.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def forward(
146146
mask (torch.Tensor, optional): Attention mask of floating points in the range
147147
`[-inf, 0)` with shape of `(nW, ws, ws)`, where `nW` is the number of windows,
148148
and `ws` is the window size (i.e. total tokens inside the window).
149+
rollout_step (int, optional): Roll-out step. Defaults to `0`.
149150
150151
Returns:
151152
torch.Tensor: Output of shape `(nW*B, N, C)`.
@@ -198,8 +199,8 @@ def window_partition_3d(x: torch.Tensor, ws: tuple[int, int, int]) -> torch.Tens
198199
"""Partition into windows.
199200
200201
Args:
201-
x: (torch.Tensor): Input tensor of shape `(B, C, H, W, D)`.
202-
ws: (tuple[int, int, int]): A 3D window size `(Wc, Wh, Ww)`.
202+
x (torch.Tensor): Input tensor of shape `(B, C, H, W, D)`.
203+
ws (tuple[int, int, int]): A 3D window size `(Wc, Wh, Ww)`.
203204
204205
Returns:
205206
torch.Tensor: Partitioning of shape `(num_windows*B, Wc, Wh, Ww, D)`.
@@ -318,7 +319,8 @@ def compute_3d_shifted_window_mask(
318319
H (int): Height of the image.
319320
W (int): Width of the image.
320321
ws (tuple[int, int, int]): Window sizes of the form `(Wc, Wh, Ww)`.
321-
ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`
322+
ss (tuple[int, int, int]): Shift sizes of the form `(Sc, Sh, Sw)`.
323+
device (torch.device): Device of the mask.
322324
dtype (torch.dtype, optional): Data type of the mask. Defaults to `torch.bfloat16`.
323325
warped (bool): If `True`, assume that the left and right sides of the image are connected.
324326
Defaults to `True`.
@@ -768,7 +770,8 @@ def __init__(
768770
lora_mode: LoRAMode = "single",
769771
use_lora: bool = False,
770772
) -> None:
771-
"""
773+
"""Initialise.
774+
772775
Args:
773776
embed_dim (int): Patch embedding dimension. Default to `96`.
774777
encoder_depths (tuple[int, ...]): Number of blocks in each encoder layer. Defaults to

aurora/model/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def unpatchify(x: torch.Tensor, V: int, H: int, W: int, P: int) -> torch.Tensor:
2424
V (int): Number of variables.
2525
H (int): Number of latitudes.
2626
W (int): Number of longitudes.
27+
P (int): Patch size.
2728
2829
Returns:
2930
torch.Tensor: Unpatchified representation of shape `(B, V, C, H, W)`.

aurora/rollout.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ def rollout(model: Aurora, batch: Batch, steps: int) -> Generator[Batch, None, N
1515
"""Perform a roll-out to make long-term predictions.
1616
1717
Args:
18-
model (:class:`aurora.model.aurora.Aurora`): The model to roll out.
19-
batch (:class:`aurora.batch.Batch`): The batch to start the roll-out from.
18+
model (:class:`aurora.Aurora`): The model to roll out.
19+
batch (:class:`aurora.Batch`): The batch to start the roll-out from.
2020
steps (int): The number of roll-out steps.
2121
2222
Yields:
23-
:class:`aurora.batch.Batch`: The prediction after every step.
23+
:class:`aurora.Batch`: The prediction after every step.
2424
"""
2525
# We will need to concatenate data, so ensure that everything is already of the right form.
2626
batch = model.batch_transform_hook(batch) # This might modify the available variables.

aurora/tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def step(self, batch: Batch) -> None:
156156
"""Track the next step.
157157
158158
Args:
159-
batch (:class:`aurora.batch.Batch`): Prediction.
159+
batch (:class:`aurora.Batch`): Prediction.
160160
"""
161161
# Check that there is only one prediction. We don't support batched tracking.
162162
if len(batch.metadata.time) != 1:

0 commit comments

Comments
 (0)