Skip to content

Commit

Permalink
[Refactor] Put values, lengths and offsets of NJTs together in storage
Browse files: browse the repository at this point in the history
ghstack-source-id: d27b2ce9ada200e531d3e8dfbe462e58217334ba
Pull Request resolved: #1023
  • Loading branch information
vmoens committed Oct 16, 2024
1 parent 321f662 commit cfa618a
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 102 deletions.
86 changes: 64 additions & 22 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import copyreg
import queue
from multiprocessing.reduction import ForkingPickler

import torch
Expand All @@ -21,7 +22,17 @@


def _rebuild_tensordict_files(flat_key_values, metadata_dict, is_shared: bool = False):
_nt_values_and_keys = queue.Queue()
_nt_lengths = queue.Queue()
_nt_offsets = queue.Queue()

def from_metadata(metadata=metadata_dict, prefix=None):
metadata = dict(metadata)

_ = metadata.pop("njt_values_start", None)
_ = metadata.pop("njt_lengths_start", None)
_ = metadata.pop("njt_offsets_start", None)

non_tensor = metadata.pop("non_tensors")
leaves = metadata.pop("leaves")
cls = metadata.pop("cls")
Expand All @@ -36,22 +47,21 @@ def from_metadata(metadata=metadata_dict, prefix=None):
total_key = (key,) if prefix is None else prefix + (key,)
if total_key[-1].startswith("<NJT>"):
nested_values = flat_key_values[total_key]
nested_lengths = None
total_key = total_key[:-1] + total_key[-1].replace("<NJT>", "")
_nt_values_and_keys.put((nested_values, total_key))
continue
if total_key[-1].startswith("<NJT_LENGTHS>"):
nested_lengths = flat_key_values[total_key]
_nt_lengths.put(nested_lengths)
continue
elif total_key[-1].startswith("<NJT_OFFSETS"):
offsets = flat_key_values[total_key]
key = key.replace("<NJT_OFFSETS>", "")
value = torch.nested.nested_tensor_from_jagged(
nested_values, offsets=offsets, lengths=nested_lengths
)
del nested_values
del nested_lengths
_nt_offsets.put(offsets)
continue
else:
value = flat_key_values[total_key]
d[key] = value

for k, v in metadata.items():
# Each remaining key is a tuple pointing to a sub-tensordict
d[k] = from_metadata(
Expand All @@ -64,7 +74,18 @@ def from_metadata(metadata=metadata_dict, prefix=None):
# result._is_shared = is_shared
return result

return from_metadata()
result = from_metadata()
# Then assign the nested tensors
while not _nt_values_and_keys.empty():
vals, key = _nt_values_and_keys.get()
lengths = _nt_lengths.get()
offsets = _nt_offsets.get()
value = torch.nested.nested_tensor_from_jagged(
vals, offsets=offsets, lengths=lengths
)
result._set_tuple(key, value, inplace=False, validated=True)

return result


def _rebuild_tensordict_files_shared(flat_key_values, metadata_dict):
Expand All @@ -75,9 +96,18 @@ def _rebuild_tensordict_files_consolidated(
metadata,
storage,
):
_nt_values_and_keys = queue.Queue()
_nt_lengths = queue.Queue()
_nt_offsets = queue.Queue()

def from_metadata(metadata=metadata, prefix=None):
consolidated = {"storage": storage, "metadata": metadata}
metadata = dict(metadata)

_ = metadata.pop("njt_values_start", None)
_ = metadata.pop("njt_lengths_start", None)
_ = metadata.pop("njt_offsets_start", None)

non_tensor = metadata.pop("non_tensors")
leaves = metadata.pop("leaves")
cls = metadata.pop("cls")
Expand All @@ -101,33 +131,45 @@ def from_metadata(metadata=metadata, prefix=None):
if key.startswith("<NJT>"):
raise RuntimeError
elif key.startswith("<NJT_VALUES>"):
nested_values = value
nested_lengths = None
key = key.replace("<NJT_VALUES>", "")
if prefix:
total_key = prefix + (key,)
else:
total_key = (key,)
_nt_values_and_keys.put((value, total_key))
continue
elif key.startswith("<NJT_LENGTHS>"):
nested_lengths = value
_nt_lengths.put(value)
continue
elif key.startswith("<NJT_OFFSETS>"):
from torch.nested._internal.nested_tensor import NestedTensor

offsets = value
value = NestedTensor(
nested_values, offsets=offsets, lengths=nested_lengths
)
key = key.replace("<NJT_OFFSETS>", "")
_nt_offsets.put(value)
if _nt_offsets.qsize() > _nt_lengths.qsize():
_nt_lengths.put(None)
continue
d[key] = value
for k, v in metadata.items():
for key, val in metadata.items():
# Each remaining key is a tuple pointing to a sub-tensordict
d[k] = from_metadata(
v, prefix=prefix + (k,) if prefix is not None else (k,)
d[key] = from_metadata(
val, prefix=prefix + (key,) if prefix is not None else (key,)
)
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
if is_locked:
result = result.lock_()
result._consolidated = consolidated
return result

return from_metadata()
result = from_metadata()
# Then assign the nested tensors
while not _nt_values_and_keys.empty():
vals, key = _nt_values_and_keys.get()
lengths = _nt_lengths.get()
offsets = _nt_offsets.get()
value = torch.nested.nested_tensor_from_jagged(
vals, offsets=offsets, lengths=lengths
)
result._set_tuple(key, value, inplace=False, validated=True)

return result


def _make_td(cls, state):
Expand Down
Loading

0 comments on commit cfa618a

Please sign in to comment.