diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 19be67d..2a148bb 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -725,7 +725,12 @@ def gather_metas(self, checkpoint_name: str): if not self._global_device_uuids: global_device_uuids.append(metas_buckets.device_uuid) if metas_buckets.memory_buffer_metas_list: - self._current_global_parameter_metas[i] = metas_buckets + # Values of _current_global_parameter_metas should be MemoryBufferMetaList, but metas_buckets is a DataToGather, + # so convert it to a MemoryBufferMetaList before storing it. + self._current_global_parameter_metas[i] = MemoryBufferMetaList( + memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list, + p2p_store_addr=metas_buckets.p2p_store_addr, + ) num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list) if not self._all_hosts: self._all_hosts = all_hosts