Skip to content

Commit 3ca50c7

Browse files
padr31copybara-github
authored andcommitted
Move track_data proto conversion functions to track_data_utils.
PiperOrigin-RevId: 796397700 Change-Id: I80043b8fa5ef330fc884372be5ddeb76725de0df
1 parent fc6a711 commit 3ca50c7

6 files changed

Lines changed: 504 additions & 438 deletions

File tree

src/alphagenome/data/track_data.py

Lines changed: 1 addition & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,15 @@
1515

1616
"""Track data container analogous to AnnData."""
1717

18-
from collections.abc import Iterable, Sequence
18+
from collections.abc import Sequence
1919
import copy
2020
import dataclasses
2121
import enum
2222
from typing import Any, Union
2323

24-
from alphagenome import tensor_utils
2524
from alphagenome import typing
2625
from alphagenome.data import genome
2726
from alphagenome.data import ontology
28-
from alphagenome.protos import dna_model_pb2
29-
from alphagenome.protos import tensor_pb2
3027
from jaxtyping import Bool, Float32, Int32 # pylint: disable=g-multiple-import, g-importing-member
3128
import numpy as np
3229
import pandas as pd
@@ -686,39 +683,6 @@ def __sub__(self, other: 'TrackData') -> 'TrackData':
686683
uns=self.uns,
687684
)
688685

689-
def to_protos(
690-
self,
691-
*,
692-
bytes_per_chunk: int = 0,
693-
compression_type: tensor_pb2.CompressionType = (
694-
tensor_pb2.CompressionType.COMPRESSION_TYPE_NONE
695-
),
696-
) -> tuple[dna_model_pb2.TrackData, Sequence[tensor_pb2.TensorChunk]]:
697-
"""Serializes `TrackData` to protobuf messages.
698-
699-
Args:
700-
bytes_per_chunk: The maximum number of bytes per tensor chunk.
701-
compression_type: The compression type to use for the tensor chunks.
702-
703-
Returns:
704-
A tuple containing the `TrackData` protobuf message and a sequence of
705-
`TensorChunk` protobuf messages.
706-
"""
707-
tensor, chunks = tensor_utils.pack_tensor(
708-
self.values,
709-
bytes_per_chunk=bytes_per_chunk,
710-
compression_type=compression_type,
711-
)
712-
return (
713-
dna_model_pb2.TrackData(
714-
values=tensor,
715-
metadata=metadata_to_proto(self.metadata).metadata,
716-
resolution=self.resolution,
717-
interval=self.interval.to_proto() if self.interval else None,
718-
),
719-
chunks,
720-
)
721-
722686

723687
def concat(
724688
track_datas: Sequence[TrackData],
@@ -869,199 +833,3 @@ def interleave(
869833
interval=track_datas[0].interval,
870834
uns={'num_interleaved_trackdatas': len(track_datas)},
871835
)
872-
873-
874-
def metadata_to_proto(
875-
metadata: TrackMetadata,
876-
) -> dna_model_pb2.TracksMetadata:
877-
"""Converts track metadata to a `TracksMetadata` protobuf message.
878-
879-
Args:
880-
metadata: A pandas DataFrame containing track metadata.
881-
882-
Returns:
883-
A `TracksMetadata` protobuf message.
884-
"""
885-
names = metadata['name']
886-
default_values = [None] * len(names)
887-
888-
columns = zip(
889-
metadata['name'],
890-
metadata['strand'],
891-
metadata.get('ontology_curie', default_values),
892-
metadata.get('biosample_type', default_values),
893-
metadata.get('biosample_name', default_values),
894-
metadata.get('biosample_life_stage', default_values),
895-
metadata.get('transcription_factor', default_values),
896-
metadata.get('histone_mark', default_values),
897-
metadata.get('gtex_tissue', default_values),
898-
metadata.get('Assay title', default_values),
899-
metadata.get('data_source', default_values),
900-
metadata.get('genetically_modified', default_values),
901-
metadata.get('endedness', default_values),
902-
metadata.get('nonzero_mean', default_values),
903-
strict=True,
904-
)
905-
906-
metadata_protos = []
907-
for (
908-
name,
909-
strand,
910-
ontology_curie,
911-
biosample_type,
912-
biosample_name,
913-
biosample_life_stage,
914-
transcription_factor,
915-
histone_mark,
916-
gtex_tissue,
917-
assay,
918-
data_source,
919-
genetically_modified,
920-
endedness,
921-
nonzero_mean,
922-
) in columns:
923-
if biosample_type:
924-
biosample = dna_model_pb2.Biosample(
925-
name=biosample_name,
926-
type=dna_model_pb2.BiosampleType.Value(
927-
f'BIOSAMPLE_TYPE_{biosample_type.upper()}'
928-
),
929-
stage=biosample_life_stage,
930-
)
931-
else:
932-
biosample = None
933-
if endedness is not None:
934-
match endedness:
935-
case 'paired':
936-
endedness = dna_model_pb2.Endedness.ENDEDNESS_PAIRED
937-
case 'single':
938-
endedness = dna_model_pb2.Endedness.ENDEDNESS_SINGLE
939-
case _:
940-
raise ValueError(f'Unknown endedness: {endedness}')
941-
942-
metadata_protos.append(
943-
dna_model_pb2.TrackMetadata(
944-
name=name,
945-
strand=genome.Strand.from_str(strand).to_proto()
946-
if strand
947-
else None,
948-
ontology_term=ontology.from_curie(ontology_curie).to_proto()
949-
if ontology_curie
950-
else None,
951-
biosample=biosample,
952-
transcription_factor_code=transcription_factor
953-
if isinstance(transcription_factor, str)
954-
else None,
955-
histone_mark_code=histone_mark
956-
if isinstance(histone_mark, str)
957-
else None,
958-
gtex_tissue=gtex_tissue,
959-
assay=assay,
960-
data_source=data_source,
961-
genetically_modified=genetically_modified,
962-
endedness=endedness,
963-
nonzero_mean=nonzero_mean,
964-
)
965-
)
966-
967-
return dna_model_pb2.TracksMetadata(metadata=metadata_protos)
968-
969-
970-
def metadata_from_proto(
971-
proto: dna_model_pb2.TracksMetadata,
972-
) -> TrackMetadata:
973-
"""Creates track metadata from a `TracksMetadata` protobuf message.
974-
975-
Args:
976-
proto: A `TracksMetadata` protobuf message.
977-
978-
Returns:
979-
A pandas DataFrame containing track metadata.
980-
"""
981-
metadata = []
982-
for track_proto in proto.metadata:
983-
track_metadata = {
984-
'name': track_proto.name,
985-
'strand': str(genome.Strand.from_proto(track_proto.strand)),
986-
}
987-
988-
if track_proto.HasField('assay'):
989-
track_metadata['Assay title'] = track_proto.assay
990-
991-
if track_proto.HasField('ontology_term'):
992-
track_metadata['ontology_curie'] = ontology.from_proto(
993-
track_proto.ontology_term
994-
).ontology_curie
995-
996-
if track_proto.HasField('biosample'):
997-
track_metadata['biosample_name'] = track_proto.biosample.name
998-
track_metadata['biosample_type'] = (
999-
dna_model_pb2.BiosampleType.Name(track_proto.biosample.type)
1000-
.removeprefix('BIOSAMPLE_TYPE_')
1001-
.lower()
1002-
)
1003-
if track_proto.biosample.HasField('stage'):
1004-
track_metadata['biosample_life_stage'] = track_proto.biosample.stage
1005-
1006-
if track_proto.HasField('transcription_factor_code'):
1007-
track_metadata['transcription_factor'] = (
1008-
track_proto.transcription_factor_code
1009-
)
1010-
1011-
if track_proto.HasField('histone_mark_code'):
1012-
track_metadata['histone_mark'] = track_proto.histone_mark_code
1013-
1014-
if track_proto.HasField('gtex_tissue'):
1015-
track_metadata['gtex_tissue'] = track_proto.gtex_tissue
1016-
1017-
if track_proto.HasField('data_source'):
1018-
track_metadata['data_source'] = track_proto.data_source
1019-
1020-
if track_proto.HasField('endedness'):
1021-
track_metadata['endedness'] = (
1022-
dna_model_pb2.Endedness.Name(track_proto.endedness)
1023-
.removeprefix('ENDEDNESS_')
1024-
.lower()
1025-
)
1026-
1027-
if track_proto.HasField('genetically_modified'):
1028-
track_metadata['genetically_modified'] = track_proto.genetically_modified
1029-
1030-
if track_proto.HasField('nonzero_mean'):
1031-
track_metadata['nonzero_mean'] = track_proto.nonzero_mean
1032-
1033-
metadata.append(track_metadata)
1034-
if metadata:
1035-
return pd.DataFrame(metadata)
1036-
else:
1037-
return pd.DataFrame(columns=['name', 'strand'])
1038-
1039-
1040-
def from_protos(
1041-
proto: dna_model_pb2.TrackData,
1042-
chunks: Iterable[tensor_pb2.TensorChunk] = (),
1043-
*,
1044-
interval: genome.Interval | None = None,
1045-
) -> TrackData:
1046-
"""Creates a `TrackData` object from protobuf messages.
1047-
1048-
Args:
1049-
proto: A `TrackData` protobuf message.
1050-
chunks: A sequence of `TensorChunk` protobuf messages.
1051-
interval: Optional `Interval` object representing the genomic region
1052-
containing the tracks. Only used if the proto does not have an interval.
1053-
1054-
Returns:
1055-
A `TrackData` object.
1056-
"""
1057-
metadata = metadata_from_proto(
1058-
dna_model_pb2.TracksMetadata(metadata=proto.metadata)
1059-
)
1060-
1061-
values = tensor_utils.unpack_proto(proto.values, chunks)
1062-
values = tensor_utils.upcast_floating(values)
1063-
resolution = proto.resolution if proto.HasField('resolution') else 1
1064-
if proto.HasField('interval'):
1065-
interval = genome.Interval.from_proto(proto.interval)
1066-
1067-
return TrackData(values, metadata, resolution=resolution, interval=interval)

0 commit comments

Comments
 (0)