Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@ exclude: '^tutorials/input'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v6.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.4.2
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.1.0
hooks:
- id: black
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
rev: 7.3.0
hooks:
- id: flake8
exclude: ^tutorials
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
rev: v1.19.1
hooks:
- id: mypy
exclude: ^tutorials
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
rev: 8.0.1
hooks:
- id: isort
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

**FEMR** is a Python package for manipulating longitudinal EHR data for machine learning, with a focus on supporting the creation of foundation models and verifying their [presumed benefits](https://hai.stanford.edu/news/how-foundation-models-can-advance-ai-healthcare) in healthcare. Such a framework is needed given the [current state of large language models in healthcare](https://hai.stanford.edu/news/shaky-foundations-foundation-models-healthcare) and the need for better evaluation frameworks.

The currently supported foundation models is [MOTOR](https://arxiv.org/abs/2301.03150).
The currently supported foundation models is [MOTOR](https://arxiv.org/abs/2301.03150).

(Users who want to train auto-regressive CLMBR-style models should use [FEMR 0.1.16](https://github.com/som-shahlab/femr/releases/tag/0.1.16) or https://github.com/som-shahlab/hf_ehr)

Expand Down
9 changes: 6 additions & 3 deletions src/femr/featurizers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,11 @@ def join_labels(features: Mapping[str, np.ndarray], labels: pd.DataFrame) -> Map
feature_index = 0

for label in labels.itertuples(index=False):
while ((feature_index + 1) < len(order)):
next_key = (features['subject_ids'][order[feature_index + 1]], features["feature_times"][order[feature_index + 1]])
while (feature_index + 1) < len(order):
next_key = (
features["subject_ids"][order[feature_index + 1]],
features["feature_times"][order[feature_index + 1]],
)
if next_key <= (label.subject_id, label.prediction_time):
feature_index += 1
else:
Expand All @@ -331,7 +334,7 @@ def join_labels(features: Mapping[str, np.ndarray], labels: pd.DataFrame) -> Map
and (features["subject_ids"][order[feature_index]] == label.subject_id)
and (features["feature_times"][order[feature_index]] <= label.prediction_time)
)

assert is_valid, (
f'{feature_index} {label} {features["subject_ids"][order[feature_index]]} '
+ f'{features["feature_times"][order[feature_index]]} {len(order)} {next_key}'
Expand Down
8 changes: 4 additions & 4 deletions src/femr/featurizers/featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,11 @@ def get_codes(self, code: str) -> Iterator[str]:
yield code

def get_columns(self, event: meds_reader.Event) -> Iterator[int]:
if getattr(event, 'text_value', None) is not None:
if getattr(event, "text_value", None) is not None:
k = (event.code, event.text_value[: self.characters_for_string_values])
if k in self.code_string_to_column_index:
yield self.code_string_to_column_index[k]
elif getattr(event, 'numeric_value', None) is not None:
elif getattr(event, "numeric_value", None) is not None:
if event.code in self.code_value_to_column_index:
column, quantiles = self.code_value_to_column_index[event.code]
for i, (start, end) in enumerate(zip(quantiles, quantiles[1:])):
Expand Down Expand Up @@ -276,10 +276,10 @@ def add_preprocess_data(
if self.excluded_event_filter is not None and self.excluded_event_filter(event):
continue

if getattr(event, 'text_value', None) is not None:
if getattr(event, "text_value", None) is not None:
if self.string_value_combination:
observed_string_value[(event.code, event.text_value[: self.characters_for_string_values])] += 1
elif getattr(event, 'numeric_value', None) is not None:
elif getattr(event, "numeric_value", None) is not None:
if self.numeric_value_decile:
observed_numeric_value[event.code].add(event.numeric_value)
else:
Expand Down
33 changes: 23 additions & 10 deletions src/femr/models/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,14 @@ def start_batch(self):
if self.task is not None:
self.task.start_batch()

def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length: Optional[int] = None, subsample_task_fraction: float = 1, actually_add: bool = True):
def add_subject(
self,
subject: meds_reader.Subject,
offset: int = 0,
max_length: Optional[int] = None,
subsample_task_fraction: float = 1,
actually_add: bool = True,
):
"""Add a subject to the current batch.

Note that the two optional parameters are used to add a subset of a subject to a batch.
Expand Down Expand Up @@ -174,7 +181,7 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
per_subject_ages.append((birth - birth) / datetime.timedelta(days=1))
per_subject_time_data.append([1, 0, 0, 0, 0])
per_subject_timestamps.append(event.time.replace(tzinfo=datetime.timezone.utc).timestamp())

for event in subject.events:
if event.time is None or event.time.date() <= birth.date():
continue
Expand Down Expand Up @@ -211,7 +218,6 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
for _ in range(num_added):
per_subject_label_indices.append(len(per_subject_ages) - 1)


if isinstance(self.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
assert weights is not None
per_subject_hierarchical_tokens.extend(features)
Expand Down Expand Up @@ -264,20 +270,22 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
self.time_data[start_index] = per_subject_time_data[0]
self.timestamps[start_index] = per_subject_timestamps[0]

if False: #not self.tokenizer.is_hierarchical:
if False: # not self.tokenizer.is_hierarchical:
# Easy for simple tokenizer
self.tokens.extend(per_subject_tokens[offset : offset + length_to_add])
elif isinstance(self.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
# Hierarchical tokenizer is more complex since we have to shift the indices as well
# Remember, these arrays are all designed for PyTorch EmbeddingBag

# We need to get the start and end at a particular offset
assert offset < len(per_subject_token_indices), f'Got it {len(per_subject_token_indices)} {subject.subject_id} {offset} {max_length}'
assert offset < len(
per_subject_token_indices
), f"Got it {len(per_subject_token_indices)} {subject.subject_id} {offset} {max_length}"

if offset == 0:
actual_offset = 0
actual_length = length_to_add
else:
else:
actual_offset = offset + 1
actual_length = length_to_add - 1

Expand All @@ -288,7 +296,7 @@ def add_subject(self, subject: meds_reader.Subject, offset: int = 0, max_length:
self.token_indices.append(len(self.hierarchical_tokens) + birth_end - birth_start)
self.hierarchical_tokens.extend(per_subject_hierarchical_tokens[birth_start:birth_end])
self.hierarchical_weights.extend(per_subject_hierarchical_weights[birth_start:birth_end])

internal_start = per_subject_token_indices[actual_offset]
internal_end = per_subject_token_indices[actual_offset + actual_length]

Expand Down Expand Up @@ -338,7 +346,7 @@ def get_batch_data(self):
"label_indices": np.array(self.label_indices, dtype=np.int32),
}

if False: #not self.tokenizer.is_hierarchical:
if False: # not self.tokenizer.is_hierarchical:
# For a single tokenizer, these are simple the token indices
transformer["tokens"] = np.array(self.tokens, dtype=token_dtype)
elif isinstance(self.tokenizer, femr.models.tokenizer.HierarchicalTokenizer):
Expand Down Expand Up @@ -383,7 +391,12 @@ def _batch_generator(batch_data: Tuple[np.ndarray, np.ndarray], *, creator: Batc
for i, (start, end) in enumerate(zip(offsets, offsets[1:])):
creator.start_batch()
for subject_index, offset, length, subsample_task_fraction in lengths[start:end, :]:
creator.add_subject(database[subject_index.item()], offset, length, subsample_task_fraction=float(subsample_task_fraction)/1e6)
creator.add_subject(
database[subject_index.item()],
offset,
length,
subsample_task_fraction=float(subsample_task_fraction) / 1e6,
)

result = creator.get_batch_data()
assert "task" in result, f"No task present in {lengths[start:end, :]} {i} {start} {end}"
Expand Down Expand Up @@ -510,7 +523,7 @@ def convert_dataset(
lengths_part = lengths[start:end, :]

for j, (a, b) in enumerate(batch_part):
assert a != b, f'{a} {b} {i} {j}'
assert a != b, f"{a} {b} {i} {j}"

offsets = [0] + [b - start for _, b in batch_part]

Expand Down
27 changes: 16 additions & 11 deletions src/femr/models/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
import collections
import datetime
import functools
import random
import warnings
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Set, Tuple

import meds
import meds_reader
import numpy as np
import scipy.sparse
import torch
import warnings

import femr.models.config
import femr.models.tokenizer
import femr.ontology
import femr.pat_utils
import femr.stat_utils
import random


class Task(abc.ABC):
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, labels: Sequence[meds.Label]):

self.label_map: Mapping[int, Any] = collections.defaultdict(list)
for label in labels:
self.label_map[label['subject_id']].append(label)
self.label_map[label["subject_id"]].append(label)

for k, v in self.label_map.items():
v.sort(key=lambda a: a["prediction_time"])
Expand Down Expand Up @@ -92,7 +92,7 @@ def add_event(
current_date: datetime.datetime,
next_date: Optional[datetime.datetime],
next_features: Optional[Sequence[int]] = None,
actually_add: bool = False
actually_add: bool = False,
) -> int:
has_label = False

Expand Down Expand Up @@ -170,6 +170,7 @@ def add_event(
def get_batch_data(self) -> Mapping[str, np.ndarray]:
return {"labels": np.array(self.batch_labels, dtype=np.int32)}


class SurvivalCalculator:
def __init__(
self, ontology: femr.ontology.Ontology, subject: meds_reader.Subject, code_whitelist: Optional[Set[str]] = None
Expand All @@ -181,7 +182,7 @@ def __init__(
for event in subject.events:
if event.time is None:
continue
if getattr(event, 'numeric_value', None) is not None or getattr(event, 'text_value', None) is not None:
if getattr(event, "numeric_value", None) is not None or getattr(event, "text_value", None) is not None:
continue

codes = set()
Expand Down Expand Up @@ -227,7 +228,11 @@ def _prefit_motor_map(
birth = femr.pat_utils.get_subject_birthdate(subject)

for event, next_event in zip(subject.events, subject.events[1:]):
if (event.time is None) or (event.time.date() == birth.date()) or (event.time.date() == next_event.time.date()):
if (
(event.time is None)
or (event.time.date() == birth.date())
or (event.time.date() == next_event.time.date())
):
continue

censor_time, tte = calculator.get_future_events_for_time(event.time)
Expand Down Expand Up @@ -271,7 +276,7 @@ def fit_pretraining_task_info(
num_tasks: int,
num_bins: int,
final_layer_size: int,
min_fraction: float = 1/1000,
min_fraction: float = 1 / 1000,
) -> MOTORTask:
tasks = []
for dict_entry in tokenizer.dictionary["vocab"]:
Expand Down Expand Up @@ -299,15 +304,15 @@ def fit_pretraining_task_info(
rate = frac_events / task_stats[2].mean()

if rate == 0:
# print("Ran into task of rate 0?", task, frac_events, task_stats[0], task_stats[1], task_stats[2].mean())
# print("Ran into task of rate 0?", task, frac_events, task_stats[0], task_stats[1], task_stats[2].mean())
continue

if frac_events < min_fraction:
# print("Ran into very rare task with less than 10 occurrences", task, frac_events, task_stats[0], task_stats[1], task_stats[2].mean())
continue

task_data.append((task, rate, task_stats[0], task_stats[1], task_stats[2].mean()))

return MOTORTask(task_data, time_bins, final_layer_size)

def __init__(self, pretraining_task_info: List[Tuple[str, float]], time_bins: List[float], final_layer_size: int):
Expand Down Expand Up @@ -384,7 +389,7 @@ def add_event(

if len(tte) == 0:
return 0

if not actually_add:
return 1

Expand Down
2 changes: 1 addition & 1 deletion src/femr/models/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .hierarchical_tokenizer import HierarchicalTokenizer
from .hierarchical_tokenizer import HierarchicalTokenizer
15 changes: 8 additions & 7 deletions src/femr/models/tokenizer/flat_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import meds_reader
import msgpack
import numpy as np
import pyarrow as pa
import transformers

import femr.ontology
import femr.stat_utils
import femr.pat_utils
import pyarrow as pa
import femr.stat_utils


def train_tokenizer(
Expand All @@ -36,9 +36,7 @@ def train_tokenizer(
),
)

return FlatTokenizer(
convert_statistics_to_msgpack(statistics, vocab_size, num_numeric)
)
return FlatTokenizer(convert_statistics_to_msgpack(statistics, vocab_size, num_numeric))


def agg_statistics(stats1, stats2):
Expand All @@ -61,6 +59,7 @@ def normalize_unit(unit):
else:
return None


def is_close_float(t, f):
if f is None:
return False
Expand All @@ -70,6 +69,7 @@ def is_close_float(t, f):
except:
return False


def map_statistics(
subjects: Iterator[meds_reader.Subject],
*,
Expand Down Expand Up @@ -111,7 +111,9 @@ def map_statistics(


def convert_statistics_to_msgpack(
statistics, vocab_size: int, num_numeric: int,
statistics,
vocab_size: int,
num_numeric: int,
):
vocab = []

Expand Down Expand Up @@ -163,7 +165,6 @@ def convert_statistics_to_msgpack(
"weight": weight * math.log(weight) + (1 - weight) * math.log(1 - weight),
}
vocab.append(entry)


vocab.sort(key=lambda a: a["weight"])
vocab = vocab[:vocab_size]
Expand Down
Loading