Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/speechlm2/s2s_duplex_stt_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def inference(cfg):
input_roles=cfg.data.input_roles,
output_roles=cfg.data.output_roles,
include_turn_metadata=True, # Enable detailed turn metadata for validation
model_cfg=model_config,
force_align_user_text=False,
early_interruption_prob=0.0,
)
datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)

Expand Down
17 changes: 15 additions & 2 deletions examples/speechlm2/s2s_duplex_stt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def train(cfg):
with trainer.init_module():
model = DuplexSTTModel(OmegaConf.to_container(cfg, resolve=True))

dataset = DuplexS2SDataset(
train_dataset = DuplexS2SDataset(
tokenizer=model.tokenizer,
frame_length=cfg.data.frame_length,
source_sample_rate=cfg.data.source_sample_rate,
Expand All @@ -63,7 +63,20 @@ def train(cfg):
cfg=cfg.data,
model_cfg=cfg.model,
)
datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=dataset)
val_dataset = DuplexS2SDataset(
tokenizer=model.tokenizer,
frame_length=cfg.data.frame_length,
source_sample_rate=cfg.data.source_sample_rate,
target_sample_rate=cfg.data.target_sample_rate,
input_roles=cfg.data.input_roles,
output_roles=cfg.data.output_roles,
aug_by_swap_role=cfg.data.get("aug_by_swap_role", False),
cfg=cfg.data,
model_cfg=cfg.model,
force_align_user_text=False,
early_interruption_prob=0.0,
)
datamodule = DataModule(cfg.data, tokenizer=model.tokenizer, dataset=train_dataset, val_dataset=val_dataset)

trainer.fit(model, datamodule)

Expand Down
11 changes: 9 additions & 2 deletions nemo/collections/speechlm2/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ class DataModule(LightningDataModule):
The data sampling is controlled by Lhotse samplers rather than the dataset.
"""

def __init__(self, cfg, tokenizer: TokenizerSpec, dataset: torch.utils.data.Dataset) -> None:
def __init__(
self,
cfg,
tokenizer: TokenizerSpec,
dataset: torch.utils.data.Dataset,
val_dataset: torch.utils.data.Dataset = None,
) -> None:
super().__init__()
self.cfg = cfg
with open_dict(self.cfg):
Expand All @@ -68,6 +74,7 @@ def __init__(self, cfg, tokenizer: TokenizerSpec, dataset: torch.utils.data.Data
getattr(self.cfg, k).force_map_dataset = True
self.tokenizer = tokenizer
self.dataset = dataset
self.val_dataset = val_dataset if val_dataset is not None else dataset

def train_dataloader(self):
if "train_ds" not in self.cfg:
Expand Down Expand Up @@ -121,7 +128,7 @@ def _build_test_dataloader(self, cfg: DictConfig) -> torch.utils.data.DataLoader
config=cfg,
global_rank=self._get_dp_rank(),
world_size=self._get_world_size(),
dataset=self.dataset,
dataset=self.val_dataset,
tokenizer=self.tokenizer,
)

Expand Down
Loading
Loading