Skip to content

Commit a8f1b2a

Browse files
committed
Restore deleted comments and fix code wrapping format
- Restore docstrings for _resolve_recursive and _check_recursive in core.py - Fix _walk_config_fields signature: close paren on its own line - Restore explanatory comments in UpdateDiffSingerTranscriptionsCallback - Remove "Round 2" section header from FIXES_SUMMARY.md
1 parent 8054346 commit a8f1b2a

2 files changed

Lines changed: 15 additions & 4 deletions

File tree

inference/callbacks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def on_predict_batch_end(
6464
pl_module: lightning.pytorch.LightningModule,
6565
outputs: dict[str, torch.Tensor],
6666
batch: dict[str, Any],
67-
*args, **kwargs
67+
*args, **kwargs,
6868
) -> None:
6969
for i in range(batch["size"]):
7070
key = self._get_key(batch, i)
@@ -77,7 +77,7 @@ def on_predict_batch_end(
7777
self._flush_key(key, logger_fn=trainer.progress_bar_callback.print)
7878

7979
def on_predict_epoch_end(
80-
self, trainer: lightning.pytorch.Trainer, *args, **kwargs
80+
self, trainer: lightning.pytorch.Trainer, *args, **kwargs,
8181
) -> None:
8282
for key in list(self._counters):
8383
self._flush_key(key, logger_fn=trainer.progress_bar_callback.print)
@@ -293,11 +293,13 @@ def _process_item(self, key: str, batch: dict, outputs: dict, i: int) -> None:
293293
item = self.index_map[key][name]
294294
if self.use_wb:
295295
if self.uv_note_cond == "follow":
296+
# When "follow", defer v/uv to align_notes_to_words; use raw pitch for all notes.
296297
note_seq = [
297298
librosa.midi_to_note(midi, unicode=False, cents=True)
298299
for midi in note_midi
299300
]
300301
else: # "predict"
302+
# When "predict", apply presence-based "rest" now so alignment preserves them.
301303
note_seq = [
302304
librosa.midi_to_note(midi, unicode=False, cents=True) if vuv else "rest"
303305
for midi, vuv in zip(note_midi, note_vuv)
@@ -322,6 +324,7 @@ def _process_item(self, key: str, batch: dict, outputs: dict, i: int) -> None:
322324
apply_word_uv=(self.uv_note_cond == "follow"),
323325
)
324326
else:
327+
# No alignment: apply presence-based "rest" directly from model output.
325328
note_seq = [
326329
librosa.midi_to_note(score, unicode=False, cents=True) if pres else "rest"
327330
for score, pres in zip(note_midi, note_vuv)

lib/config/core.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ def validate_fields_by_scope(self) -> "ConfigBaseModel":
7373
delattr(self, name)
7474
return self
7575

76-
def _walk_config_fields(self, current: "ConfigBaseModel", context: ConfigOperationContext,
77-
leaf_fn):
76+
def _walk_config_fields(
77+
self, current: "ConfigBaseModel", context: ConfigOperationContext,
78+
leaf_fn,
79+
):
7880
for field_name, field_info in type(current).model_fields.items():
7981
field_scope = current.__field_scopes__.get(field_name)
8082
if field_scope is not None and not field_scope & context.scope:
@@ -94,21 +96,27 @@ def _walk_config_fields(self, current: "ConfigBaseModel", context: ConfigOperati
9496
context.current_path.pop()
9597

9698
def _resolve_recursive(self, current: "ConfigBaseModel", context: ConfigOperationContext):
99+
"""Recursively resolve all dynamic expressions in the config."""
100+
97101
def _resolve_leaf(current, field_name, field_info, value, context):
98102
expr = field_info.json_schema_extra.get('dynamic_expr')
99103
if expr:
100104
if isinstance(expr, ConfigOperationBase):
101105
context.current_value = value
102106
expr = expr.resolve(context)
103107
setattr(current, field_name, expr)
108+
104109
self._walk_config_fields(current, context, _resolve_leaf)
105110

106111
def _check_recursive(self, current: "ConfigBaseModel", context: ConfigOperationContext):
112+
"""Recursively check all dynamic expressions in the config."""
113+
107114
def _check_leaf(current, field_name, field_info, value, context):
108115
check = field_info.json_schema_extra.get('dynamic_check')
109116
if check:
110117
context.current_value = value
111118
check.run(context)
119+
112120
self._walk_config_fields(current, context, _check_leaf)
113121

114122
def _process_nested(self, f, scope: int = 0, path: str = None):

0 commit comments

Comments
 (0)