|
15 | 15 |
|
16 | 16 | """Track data container analogous to AnnData.""" |
17 | 17 |
|
18 | | -from collections.abc import Iterable, Sequence |
| 18 | +from collections.abc import Sequence |
19 | 19 | import copy |
20 | 20 | import dataclasses |
21 | 21 | import enum |
22 | 22 | from typing import Any, Union |
23 | 23 |
|
24 | | -from alphagenome import tensor_utils |
25 | 24 | from alphagenome import typing |
26 | 25 | from alphagenome.data import genome |
27 | 26 | from alphagenome.data import ontology |
28 | | -from alphagenome.protos import dna_model_pb2 |
29 | | -from alphagenome.protos import tensor_pb2 |
30 | 27 | from jaxtyping import Bool, Float32, Int32 # pylint: disable=g-multiple-import, g-importing-member |
31 | 28 | import numpy as np |
32 | 29 | import pandas as pd |
@@ -686,39 +683,6 @@ def __sub__(self, other: 'TrackData') -> 'TrackData': |
686 | 683 | uns=self.uns, |
687 | 684 | ) |
688 | 685 |
|
689 | | - def to_protos( |
690 | | - self, |
691 | | - *, |
692 | | - bytes_per_chunk: int = 0, |
693 | | - compression_type: tensor_pb2.CompressionType = ( |
694 | | - tensor_pb2.CompressionType.COMPRESSION_TYPE_NONE |
695 | | - ), |
696 | | - ) -> tuple[dna_model_pb2.TrackData, Sequence[tensor_pb2.TensorChunk]]: |
697 | | - """Serializes `TrackData` to protobuf messages. |
698 | | -
|
699 | | - Args: |
700 | | - bytes_per_chunk: The maximum number of bytes per tensor chunk. |
701 | | - compression_type: The compression type to use for the tensor chunks. |
702 | | -
|
703 | | - Returns: |
704 | | - A tuple containing the `TrackData` protobuf message and a sequence of |
705 | | - `TensorChunk` protobuf messages. |
706 | | - """ |
707 | | - tensor, chunks = tensor_utils.pack_tensor( |
708 | | - self.values, |
709 | | - bytes_per_chunk=bytes_per_chunk, |
710 | | - compression_type=compression_type, |
711 | | - ) |
712 | | - return ( |
713 | | - dna_model_pb2.TrackData( |
714 | | - values=tensor, |
715 | | - metadata=metadata_to_proto(self.metadata).metadata, |
716 | | - resolution=self.resolution, |
717 | | - interval=self.interval.to_proto() if self.interval else None, |
718 | | - ), |
719 | | - chunks, |
720 | | - ) |
721 | | - |
722 | 686 |
|
723 | 687 | def concat( |
724 | 688 | track_datas: Sequence[TrackData], |
@@ -869,199 +833,3 @@ def interleave( |
869 | 833 | interval=track_datas[0].interval, |
870 | 834 | uns={'num_interleaved_trackdatas': len(track_datas)}, |
871 | 835 | ) |
872 | | - |
873 | | - |
874 | | -def metadata_to_proto( |
875 | | - metadata: TrackMetadata, |
876 | | -) -> dna_model_pb2.TracksMetadata: |
877 | | - """Converts track metadata to a `TracksMetadata` protobuf message. |
878 | | -
|
879 | | - Args: |
880 | | - metadata: A pandas DataFrame containing track metadata. |
881 | | -
|
882 | | - Returns: |
883 | | - A `TracksMetadata` protobuf message. |
884 | | - """ |
885 | | - names = metadata['name'] |
886 | | - default_values = [None] * len(names) |
887 | | - |
888 | | - columns = zip( |
889 | | - metadata['name'], |
890 | | - metadata['strand'], |
891 | | - metadata.get('ontology_curie', default_values), |
892 | | - metadata.get('biosample_type', default_values), |
893 | | - metadata.get('biosample_name', default_values), |
894 | | - metadata.get('biosample_life_stage', default_values), |
895 | | - metadata.get('transcription_factor', default_values), |
896 | | - metadata.get('histone_mark', default_values), |
897 | | - metadata.get('gtex_tissue', default_values), |
898 | | - metadata.get('Assay title', default_values), |
899 | | - metadata.get('data_source', default_values), |
900 | | - metadata.get('genetically_modified', default_values), |
901 | | - metadata.get('endedness', default_values), |
902 | | - metadata.get('nonzero_mean', default_values), |
903 | | - strict=True, |
904 | | - ) |
905 | | - |
906 | | - metadata_protos = [] |
907 | | - for ( |
908 | | - name, |
909 | | - strand, |
910 | | - ontology_curie, |
911 | | - biosample_type, |
912 | | - biosample_name, |
913 | | - biosample_life_stage, |
914 | | - transcription_factor, |
915 | | - histone_mark, |
916 | | - gtex_tissue, |
917 | | - assay, |
918 | | - data_source, |
919 | | - genetically_modified, |
920 | | - endedness, |
921 | | - nonzero_mean, |
922 | | - ) in columns: |
923 | | - if biosample_type: |
924 | | - biosample = dna_model_pb2.Biosample( |
925 | | - name=biosample_name, |
926 | | - type=dna_model_pb2.BiosampleType.Value( |
927 | | - f'BIOSAMPLE_TYPE_{biosample_type.upper()}' |
928 | | - ), |
929 | | - stage=biosample_life_stage, |
930 | | - ) |
931 | | - else: |
932 | | - biosample = None |
933 | | - if endedness is not None: |
934 | | - match endedness: |
935 | | - case 'paired': |
936 | | - endedness = dna_model_pb2.Endedness.ENDEDNESS_PAIRED |
937 | | - case 'single': |
938 | | - endedness = dna_model_pb2.Endedness.ENDEDNESS_SINGLE |
939 | | - case _: |
940 | | - raise ValueError(f'Unknown endedness: {endedness}') |
941 | | - |
942 | | - metadata_protos.append( |
943 | | - dna_model_pb2.TrackMetadata( |
944 | | - name=name, |
945 | | - strand=genome.Strand.from_str(strand).to_proto() |
946 | | - if strand |
947 | | - else None, |
948 | | - ontology_term=ontology.from_curie(ontology_curie).to_proto() |
949 | | - if ontology_curie |
950 | | - else None, |
951 | | - biosample=biosample, |
952 | | - transcription_factor_code=transcription_factor |
953 | | - if isinstance(transcription_factor, str) |
954 | | - else None, |
955 | | - histone_mark_code=histone_mark |
956 | | - if isinstance(histone_mark, str) |
957 | | - else None, |
958 | | - gtex_tissue=gtex_tissue, |
959 | | - assay=assay, |
960 | | - data_source=data_source, |
961 | | - genetically_modified=genetically_modified, |
962 | | - endedness=endedness, |
963 | | - nonzero_mean=nonzero_mean, |
964 | | - ) |
965 | | - ) |
966 | | - |
967 | | - return dna_model_pb2.TracksMetadata(metadata=metadata_protos) |
968 | | - |
969 | | - |
970 | | -def metadata_from_proto( |
971 | | - proto: dna_model_pb2.TracksMetadata, |
972 | | -) -> TrackMetadata: |
973 | | - """Creates track metadata from a `TracksMetadata` protobuf message. |
974 | | -
|
975 | | - Args: |
976 | | - proto: A `TracksMetadata` protobuf message. |
977 | | -
|
978 | | - Returns: |
979 | | - A pandas DataFrame containing track metadata. |
980 | | - """ |
981 | | - metadata = [] |
982 | | - for track_proto in proto.metadata: |
983 | | - track_metadata = { |
984 | | - 'name': track_proto.name, |
985 | | - 'strand': str(genome.Strand.from_proto(track_proto.strand)), |
986 | | - } |
987 | | - |
988 | | - if track_proto.HasField('assay'): |
989 | | - track_metadata['Assay title'] = track_proto.assay |
990 | | - |
991 | | - if track_proto.HasField('ontology_term'): |
992 | | - track_metadata['ontology_curie'] = ontology.from_proto( |
993 | | - track_proto.ontology_term |
994 | | - ).ontology_curie |
995 | | - |
996 | | - if track_proto.HasField('biosample'): |
997 | | - track_metadata['biosample_name'] = track_proto.biosample.name |
998 | | - track_metadata['biosample_type'] = ( |
999 | | - dna_model_pb2.BiosampleType.Name(track_proto.biosample.type) |
1000 | | - .removeprefix('BIOSAMPLE_TYPE_') |
1001 | | - .lower() |
1002 | | - ) |
1003 | | - if track_proto.biosample.HasField('stage'): |
1004 | | - track_metadata['biosample_life_stage'] = track_proto.biosample.stage |
1005 | | - |
1006 | | - if track_proto.HasField('transcription_factor_code'): |
1007 | | - track_metadata['transcription_factor'] = ( |
1008 | | - track_proto.transcription_factor_code |
1009 | | - ) |
1010 | | - |
1011 | | - if track_proto.HasField('histone_mark_code'): |
1012 | | - track_metadata['histone_mark'] = track_proto.histone_mark_code |
1013 | | - |
1014 | | - if track_proto.HasField('gtex_tissue'): |
1015 | | - track_metadata['gtex_tissue'] = track_proto.gtex_tissue |
1016 | | - |
1017 | | - if track_proto.HasField('data_source'): |
1018 | | - track_metadata['data_source'] = track_proto.data_source |
1019 | | - |
1020 | | - if track_proto.HasField('endedness'): |
1021 | | - track_metadata['endedness'] = ( |
1022 | | - dna_model_pb2.Endedness.Name(track_proto.endedness) |
1023 | | - .removeprefix('ENDEDNESS_') |
1024 | | - .lower() |
1025 | | - ) |
1026 | | - |
1027 | | - if track_proto.HasField('genetically_modified'): |
1028 | | - track_metadata['genetically_modified'] = track_proto.genetically_modified |
1029 | | - |
1030 | | - if track_proto.HasField('nonzero_mean'): |
1031 | | - track_metadata['nonzero_mean'] = track_proto.nonzero_mean |
1032 | | - |
1033 | | - metadata.append(track_metadata) |
1034 | | - if metadata: |
1035 | | - return pd.DataFrame(metadata) |
1036 | | - else: |
1037 | | - return pd.DataFrame(columns=['name', 'strand']) |
1038 | | - |
1039 | | - |
1040 | | -def from_protos( |
1041 | | - proto: dna_model_pb2.TrackData, |
1042 | | - chunks: Iterable[tensor_pb2.TensorChunk] = (), |
1043 | | - *, |
1044 | | - interval: genome.Interval | None = None, |
1045 | | -) -> TrackData: |
1046 | | - """Creates a `TrackData` object from protobuf messages. |
1047 | | -
|
1048 | | - Args: |
1049 | | - proto: A `TrackData` protobuf message. |
1050 | | - chunks: A sequence of `TensorChunk` protobuf messages. |
1051 | | - interval: Optional `Interval` object representing the genomic region |
1052 | | - containing the tracks. Only used if the proto does not have an interval. |
1053 | | -
|
1054 | | - Returns: |
1055 | | - A `TrackData` object. |
1056 | | - """ |
1057 | | - metadata = metadata_from_proto( |
1058 | | - dna_model_pb2.TracksMetadata(metadata=proto.metadata) |
1059 | | - ) |
1060 | | - |
1061 | | - values = tensor_utils.unpack_proto(proto.values, chunks) |
1062 | | - values = tensor_utils.upcast_floating(values) |
1063 | | - resolution = proto.resolution if proto.HasField('resolution') else 1 |
1064 | | - if proto.HasField('interval'): |
1065 | | - interval = genome.Interval.from_proto(proto.interval) |
1066 | | - |
1067 | | - return TrackData(values, metadata, resolution=resolution, interval=interval) |
0 commit comments