diff --git a/src/tplr/comms.py b/src/tplr/comms.py index ae2194bc..55d2289a 100644 --- a/src/tplr/comms.py +++ b/src/tplr/comms.py @@ -1670,10 +1670,10 @@ async def gather( current_window=window, ) - aggregated_state_dict = {} - valid_uids = [] - skipped_uids = [] # Retain UIDs that are skipped. - global_steps = [] + aggregated_state_dict: dict[str, list[torch.Tensor]] = {} + valid_uids: list[int] = [] + skipped_uids: list[int] = [] + global_steps: list[int] = [] # Ensure deterministic order across processes/ranks uids = sorted(uids) @@ -1703,204 +1703,236 @@ async def gather( f"{tplr.P(window, tplr.T() - download_start)} Downloaded peer gradients <--" ) process_start = tplr.T() - for uid, response in zip(uids, batch_responses): - received_compressed_params = set() - - if isinstance(response, Exception): - tplr.log_with_context( - level="debug", - message=f"Error from UID {uid}: {str(response)}", - current_window=window, - ) - skipped_uids.append(uid) - continue - if response is None: - tplr.logger.info(f"Skipped UID {uid} - gradient not found.") - skipped_uids.append(uid) - continue - try: - # This is where get response uses the step - response = cast(CommsGetResult, response) - state_dict_resp, global_step_resp = ( - response.data, - response.global_step, - ) - tplr.logger.debug( - f"Received state dict and global step {global_step_resp} from UID {uid}" - ) - except (TypeError, ValueError) as e: - tplr.log_with_context( - level="debug", - message=f"Invalid response from UID {uid}: {e}", - current_window=window, - ) - skipped_uids.append(uid) - continue - - if state_dict_resp is None: - tplr.logger.debug(f"Empty state dict from UID {uid}") - skipped_uids.append(uid) - continue + # We don't need gradients for any of this + with torch.no_grad(): + for uid, response in zip(uids, batch_responses): + received_compressed_params: set[str] = set() + + # ----------------- error / None handling ----------------- + if isinstance(response, Exception): + tplr.log_with_context( + level="debug", + message=f"Error from UID {uid}: {str(response)}", + current_window=window, + ) + skipped_uids.append(uid) + continue - # ---------- Begin Compressed Indices and Values Check ---------- - valid_response = True - for param_name, tensor in state_dict_resp.items(): - received_compressed_params.add(param_name) - - # ---------------------------------------------------------- - # (1) Validate quantisation parameters themselves - # ---------------------------------------------------------- - if param_name.endswith("quant_params"): - shift, scale, offset, lookup, dtype = tensor - if ( - (not torch.isfinite(shift)) - or isinstance(scale, float) - and ( - not math.isfinite(scale) - or abs(scale) < 1e-12 - or abs(scale) > 1e4 - ) - ): - tplr.logger.warning( - f"Bad quant‑params in {param_name} from UID {uid}; " - f"shift={shift}, scale={scale}" - ) - valid_response = False - break - if torch.is_tensor(lookup) and ( - not torch.isfinite(lookup).all() - ): - tplr.logger.warning( - f"Lookup table contains non‑finite values in {param_name} " - f"from UID {uid}" - ) - valid_response = False - break + if response is None: + tplr.logger.info(f"Skipped UID {uid} - gradient not found.") + skipped_uids.append(uid) + continue - if param_name.endswith("idxs"): - base_name = param_name[:-4] - totalk_value = totalks.get(base_name) - if totalk_value is None: - tplr.logger.warning( - f"Missing totalk for parameter {base_name} from UID {uid}, skipping UID." 
- ) - valid_response = False - break - # totalks stores integers, not tensors - totalk = ( - totalk_value - if isinstance(totalk_value, int) - else totalk_value.numel() + try: + response = cast(CommsGetResult, response) + state_dict_resp, global_step_resp = ( + response.data, + response.global_step, ) - # Get corresponding vals tensor for 12-bit unpacking - vals_tensor = state_dict_resp.get(base_name + "vals", None) - try: - self.check_compressed_indices( - param_name, - tensor, - totalk, - allowed_topk=self.hparams.topk_compression, - vals=vals_tensor, - ) - except Exception as e: - tplr.logger.warning( - f"Compressed indices check failed for parameter {param_name} from UID {uid}: {e}" - ) - valid_response = False - break - # Check if values are valid (not NaN, not Inf) - validate without dequantizing - elif param_name.endswith("vals"): - # Only move to device for validation if needed - if tensor.dtype == torch.uint8: - # For quantized values, do a quick check on the raw bytes - if tensor.nelement() == 0: + tplr.logger.debug( + f"Received state dict and global step {global_step_resp} from UID {uid}" + ) + except (TypeError, ValueError) as e: + tplr.log_with_context( + level="debug", + message=f"Invalid response from UID {uid}: {e}", + current_window=window, + ) + skipped_uids.append(uid) + continue + + if state_dict_resp is None: + tplr.logger.debug(f"Empty state dict from UID {uid}") + skipped_uids.append(uid) + continue + + # ---------- Begin Compressed Indices and Values Check ---------- + valid_response = True + + for param_name, tensor in state_dict_resp.items(): + received_compressed_params.add(param_name) + + # (1) Validate quantisation parameters themselves + if param_name.endswith("quant_params"): + shift, scale, offset, lookup, dtype = tensor + if ( + (not torch.isfinite(shift)) + or ( + isinstance(scale, float) + and ( + not math.isfinite(scale) + or abs(scale) < 1e-12 + or abs(scale) > 1e4 + ) + ) + or ( + torch.is_tensor(scale) + and ( + not torch.isfinite(scale).all() + or scale.abs().max() < 1e-12 + or scale.abs().max() > 1e4 + ) + ) + ): tplr.logger.warning( - f"Empty tensor in {param_name} from UID {uid}, skipping" + f"Bad quant-params in {param_name} from UID {uid}; " + f"shift={shift}, scale={scale}" ) valid_response = False break - else: - # For non-quantized tensors, check for NaN/Inf - tensor_to_check = tensor.to(device) - if ( - torch.isnan(tensor_to_check).any() - or torch.isinf(tensor_to_check).any() + if torch.is_tensor(lookup) and ( + not torch.isfinite(lookup).all() ): tplr.logger.warning( - f"NaN/Inf in {param_name} from UID {uid}, skipping" + f"Lookup table contains non-finite values in {param_name} " + f"from UID {uid}" ) valid_response = False break - # Clean up temporary tensor - del tensor_to_check - - # ------------------------------------------------------ - # (2) Only validate quantization params exist, don't dequantize - # ------------------------------------------------------ - qparams = state_dict_resp.get( - param_name[:-4] + "quant_params", None - ) - if qparams is None and tensor.dtype == torch.uint8: - tplr.logger.warning( - f"Missing quant_params for quantized {param_name} from UID {uid}" + + if param_name.endswith("idxs"): + base_name = param_name[:-4] + totalk_value = totalks.get(base_name) + if totalk_value is None: + tplr.logger.warning( + f"Missing totalk for parameter {base_name} from UID {uid}, skipping UID." 
+ ) + valid_response = False + break + + # totalks stores integers, not tensors + totalk = ( + totalk_value + if isinstance(totalk_value, int) + else totalk_value.numel() ) - valid_response = False - break - missing_params = ( - expected_compressed_params - received_compressed_params - ) - if missing_params: - tplr.logger.warning( - f"UID {uid} missing compressed parameters: {missing_params}, skipping UID." - ) - valid_response = False + vals_tensor = state_dict_resp.get( + base_name + "vals", None + ) + try: + self.check_compressed_indices( + param_name, + tensor, + totalk, + allowed_topk=self.hparams.topk_compression, + vals=vals_tensor, + ) + except Exception as e: + tplr.logger.warning( + f"Compressed indices check failed for parameter {param_name} from UID {uid}: {e}" + ) + valid_response = False + break - # If any check failed, skip this UID entirely - if not valid_response: - tplr.logger.info( - f"Skipping UID {uid} due to validation failures" + # Check if values are valid (not NaN, not Inf) - validate without dequantizing + elif param_name.endswith("vals"): + if tensor.dtype == torch.uint8: + # For quantized values, do a quick check on the raw bytes + if tensor.nelement() == 0: + tplr.logger.warning( + f"Empty tensor in {param_name} from UID {uid}, skipping" + ) + valid_response = False + break + else: + # For non-quantized tensors, check for NaN/Inf + # Avoid unnecessary copies: only move if needed + target_device = torch.device(device) + + if tensor.device.type == target_device: + tensor_to_check = tensor + needs_delete = False + else: + tensor_to_check = tensor.to( + target_device, non_blocking=True + ) + needs_delete = True + + if ( + torch.isnan(tensor_to_check).any() + or torch.isinf(tensor_to_check).any() + ): + tplr.logger.warning( + f"NaN/Inf in {param_name} from UID {uid}, skipping" + ) + valid_response = False + + if needs_delete: + del tensor_to_check + + if not valid_response: + break + + # (2) Only validate quantization params exist, don't dequantize + qparams = state_dict_resp.get( + param_name[:-4] + "quant_params", None + ) + if qparams is None and tensor.dtype == torch.uint8: + tplr.logger.warning( + f"Missing quant_params for quantized {param_name} from UID {uid}" + ) + valid_response = False + break + + missing_params = ( + expected_compressed_params - received_compressed_params ) - skipped_uids.append(uid) - continue - # ---------- End Compressed Indices and Values Check ---------- - - # Process tensors - keep everything quantized to save memory - for param_name, tensor in state_dict_resp.items(): - # 1️⃣ Indices are kept as‑is ----------------------------------------- - if param_name.endswith("idxs"): - aggregated_state_dict.setdefault(param_name, []).append( - tensor - ) - # Handle 12-bit packed format (uint8 tensor) - metrics["download_bytes"] += ( - tensor.element_size() * tensor.nelement() + if missing_params: + tplr.logger.warning( + f"UID {uid} missing compressed parameters: {missing_params}, skipping UID." 
) + valid_response = False - # 2️⃣ Values → keep quantized, store with quant_params --------------- - elif param_name.endswith("vals"): - # Keep values quantized - just store the raw tensor - aggregated_state_dict.setdefault(param_name, []).append( - tensor # Keep original dtype (uint8 if quantized) - ) - metrics["download_bytes"] += ( - tensor.element_size() * tensor.nelement() + if not valid_response: + tplr.logger.info( + f"Skipping UID {uid} due to validation failures" ) + skipped_uids.append(uid) + # Drop references early for this UID + del state_dict_resp + del response + continue + # ---------- End Compressed Indices and Values Check ---------- + + # ----------------- Aggregation (still under no_grad) ----------------- + for param_name, tensor in state_dict_resp.items(): + if param_name.endswith("idxs"): + aggregated_state_dict.setdefault(param_name, []).append( + tensor + ) + metrics["download_bytes"] += ( + tensor.element_size() * tensor.nelement() + ) - # 3️⃣ Store quantization parameters for later use -------------------- - elif param_name.endswith("quant_params"): - aggregated_state_dict.setdefault(param_name, []).append( - tensor - ) + elif param_name.endswith("vals"): + aggregated_state_dict.setdefault(param_name, []).append( + tensor + ) + metrics["download_bytes"] += ( + tensor.element_size() * tensor.nelement() + ) + + elif param_name.endswith("quant_params"): + aggregated_state_dict.setdefault(param_name, []).append( + tensor + ) - valid_uids.append(uid) - global_steps.append(global_step_resp) + valid_uids.append(uid) + global_steps.append(global_step_resp) + + # Explicitly drop UID-local references now that we've copied + del state_dict_resp + del response tplr.logger.info( f"{tplr.P(window, tplr.T() - process_start)} Processed peer gradients <--" ) + # Drop the entire batch_responses list to free references + del batch_responses + except Exception as e: tplr.logger.error(f"Error processing uid batch: {str(e)}") diff --git a/src/tplr/dcp_checkpoint.py b/src/tplr/dcp_checkpoint.py index 682877a3..cdcef161 100644 --- a/src/tplr/dcp_checkpoint.py +++ b/src/tplr/dcp_checkpoint.py @@ -753,6 +753,10 @@ async def download_distributed( local_dir = self._local_dir(layout) world, r = _world(), _rank() + + # ensure all ranks agreed on window and created local_dir + _barrier(process_group) + # Try highest-staked bucket first, then own bucket = await self._choose_read_bucket( prefer_highest_staked=prefer_highest_staked @@ -871,7 +875,9 @@ async def download_and_load( ) -> tuple[int, int] | None: local_dir = await ( self.download_distributed( - window=window, prefer_highest_staked=prefer_highest_staked + window=window, + prefer_highest_staked=prefer_highest_staked, + process_group=process_group, # ensure barriers use the same PG as load() ) if shared_fs else self.download_all( @@ -880,10 +886,61 @@ async def download_and_load( ) if local_dir is None: return None - sidecar = json.loads((local_dir / "extra_metadata.json").read_text()) + + # make sure *all* ranks have finished downloading / mop-up + _barrier(process_group) + + sidecar_path = local_dir / "extra_metadata.json" + + # retry on FileNotFoundError / transient IO errors + async def _read_sidecar(): + return await asyncio.to_thread(sidecar_path.read_text) + + try: + sidecar_text = await _retry_n( + _read_sidecar, + desc=f"read sidecar {sidecar_path}", + attempts=5, + delay_s=0.5, + ) + read_ok = True + except Exception as e: + tplr.logger.error( + f"[DCP][download-and-load] failed to read sidecar at " + f"{_safe(sidecar_path)}: 
{_safe(e)}" + ) + read_ok = False + sidecar_text = "" + + # --- Synchronize success/failure across ranks --- + if dist.is_available() and dist.is_initialized(): + ok_tensor = torch.tensor([1 if read_ok else 0], dtype=torch.int32) + # always keep on CPU for safety + dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN) + read_ok = bool(ok_tensor.item()) + + if not read_ok: + # All ranks return consistently + return None + + try: + sidecar = json.loads(sidecar_text) + except Exception as e: + tplr.logger.error( + f"[DCP][download-load] corrupted sidecar JSON at " + f"{_safe(sidecar_path)}: {_safe(e)}" + ) + return None w = int(sidecar["window"]) global_step = int(sidecar.get("global_step", -1)) + + # Barrier before DCP load (safety belt) + _barrier(process_group) + self.load_local(model=model, window=w, process_group=process_group) + + # Barrier after load to ensure all ranks finish loading + _barrier(process_group) return w, global_step async def check_checkpoint_exists( diff --git a/tests/conftest.py b/tests/conftest.py index 4ef73b47..9729d847 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,28 +9,28 @@ def pytest_configure(config): import os # Mock R2 bucket access for testing -os.environ.setdefault("R2_AGGREGATOR_ACCOUNT_ID", "mock-account-id") -os.environ.setdefault("R2_AGGREGATOR_BUCKET_NAME", "mock-bucket-name") -os.environ.setdefault("R2_AGGREGATOR_READ_ACCESS_KEY_ID", "mock-read-key-id") -os.environ.setdefault("R2_AGGREGATOR_READ_SECRET_ACCESS_KEY", "mock-read-secret-key") +# Use direct assignment to override empty strings from CI +os.environ["R2_AGGREGATOR_ACCOUNT_ID"] = "mock-account-id" +os.environ["R2_AGGREGATOR_BUCKET_NAME"] = "mock-bucket-name" +os.environ["R2_AGGREGATOR_READ_ACCESS_KEY_ID"] = "mock-read-key-id" +os.environ["R2_AGGREGATOR_READ_SECRET_ACCESS_KEY"] = "mock-read-secret-key" +os.environ["R2_AGGREGATOR_WRITE_ACCESS_KEY_ID"] = "mock-write-key-id" +os.environ["R2_AGGREGATOR_WRITE_SECRET_ACCESS_KEY"] = "mock-write-secret-key" # Also set other required variables from config.py -os.environ.setdefault("R2_GRADIENTS_ACCOUNT_ID", "mock-gradients-account-id") -os.environ.setdefault("R2_GRADIENTS_BUCKET_NAME", "mock-gradients-bucket-name") -os.environ.setdefault("R2_GRADIENTS_READ_ACCESS_KEY_ID", "mock-gradients-read-key-id") -os.environ.setdefault( - "R2_GRADIENTS_READ_SECRET_ACCESS_KEY", "mock-gradients-read-secret-key" -) -os.environ.setdefault("R2_GRADIENTS_WRITE_ACCESS_KEY_ID", "mock-gradients-write-key-id") -os.environ.setdefault( - "R2_GRADIENTS_WRITE_SECRET_ACCESS_KEY", "mock-gradients-write-secret-key" -) -os.environ.setdefault("R2_DATASET_ACCOUNT_ID", "mock-dataset-account-id") -os.environ.setdefault("R2_DATASET_BUCKET_NAME", "mock-dataset-bucket-name") -os.environ.setdefault("R2_DATASET_READ_ACCESS_KEY_ID", "mock-dataset-read-key-id") -os.environ.setdefault( - "R2_DATASET_READ_SECRET_ACCESS_KEY", "mock-dataset-read-secret-key" -) +os.environ["R2_GRADIENTS_ACCOUNT_ID"] = "mock-gradients-account-id" +os.environ["R2_GRADIENTS_BUCKET_NAME"] = "mock-gradients-bucket-name" +os.environ["R2_GRADIENTS_READ_ACCESS_KEY_ID"] = "mock-gradients-read-key-id" +os.environ["R2_GRADIENTS_READ_SECRET_ACCESS_KEY"] = "mock-gradients-read-secret-key" +os.environ["R2_GRADIENTS_WRITE_ACCESS_KEY_ID"] = "mock-gradients-write-key-id" +os.environ["R2_GRADIENTS_WRITE_SECRET_ACCESS_KEY"] = "mock-gradients-write-secret-key" +os.environ["R2_DATASET_ACCOUNT_ID"] = "mock-dataset-account-id" +os.environ["R2_DATASET_BUCKET_NAME"] = "mock-dataset-bucket-name" 
+os.environ["R2_DATASET_READ_ACCESS_KEY_ID"] = "mock-dataset-read-key-id" +os.environ["R2_DATASET_READ_SECRET_ACCESS_KEY"] = "mock-dataset-read-secret-key" +os.environ["R2_DATASET_WRITE_ACCESS_KEY_ID"] = "mock-dataset-write-key-id" +os.environ["R2_DATASET_WRITE_SECRET_ACCESS_KEY"] = "mock-dataset-write-secret-key" +os.environ["DATASET_BINS_PATH"] = "/mock/dataset/bins/path" import pytest from unittest.mock import Mock, patch