Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions nmdc_tables/table_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import requests

Check failure on line 1 in nmdc_tables/table_1.py

View workflow job for this annotation

GitHub Actions / Check code formatting

ruff (D100)

nmdc_tables/table_1.py:1:1: D100 Missing docstring in public module

Check failure on line 1 in nmdc_tables/table_1.py

View workflow job for this annotation

GitHub Actions / Check code formatting

ruff (INP001)

nmdc_tables/table_1.py:1:1: INP001 File `nmdc_tables/table_1.py` is part of an implicit namespace package. Add an `__init__.py`.
import logging

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import first

Check failure on line 6 in nmdc_tables/table_1.py

View workflow job for this annotation

GitHub Actions / Check code formatting

ruff (I001)

nmdc_tables/table_1.py:1:1: I001 Import block is un-sorted or un-formatted help: Organize imports


# Spark session used by main(); it is created (or an existing one reused) as
# soon as this module is imported.
spark = SparkSession.builder.appName("NMDC Study Pipeline").getOrCreate()

# Root URL that fetch_study() prefixes to its "/studies/<study_id>" path.
BASE_URL = "https://api.microbiomedata.org"

# Four study identifiers used for testing; main() parallelizes this list.
STUDIES = [
    "nmdc:sty-11-34xj1150",
    "nmdc:sty-11-hht5sb92",
    "nmdc:sty-11-nxrz9m96",
    "nmdc:sty-11-pzmd0x14",
]

# Exact-match overrides from a credit-role label to its normalized form.
# normalize_role() looks the stripped label up here first; a label that is not
# listed falls back to lower-casing with spaces replaced by underscores.
ROLE_MAP = {
    "Principal Investigator": "principal_investigator",
    "principal_investigator": "principal_investigator",
    "Methodology": "methodology",
    "Data curation": "data_curation",
}

# Configure the root logger to emit INFO-level messages and above.
logging.basicConfig(level=logging.INFO)


# Helper functions that normalize a role, a person ID, and an email address.
def normalize_role(role: "str | None") -> "str | None":
    """Return the snake_case form of a credit-role label.

    The label is stripped and looked up in ``ROLE_MAP``; a label without an
    explicit mapping is lower-cased and its spaces are replaced by
    underscores.

    Args:
        role: Role label such as ``"Data curation"``. May be ``None``.

    Returns:
        The normalized role, or ``None`` when ``role`` is missing, empty, or
        made up of whitespace only.
    """
    if not role:
        return None

    label = role.strip()
    if not label:
        # A whitespace-only label would otherwise normalize to "", which is
        # not a meaningful role value.
        return None

    return ROLE_MAP.get(label, label.lower().replace(" ", "_"))


def normalize_person_id(person: "dict | None") -> "str | None":
    """Build a prefixed identifier for a person record.

    An ORCID is preferred; the person's name is used only when no ORCID is
    present.

    Args:
        person: Mapping that may carry ``"orcid"`` and/or ``"name"`` keys.
            May be ``None`` or empty.

    Returns:
        ``"orcid:<orcid>"`` when an ORCID is set, otherwise ``"name:<name>"``
        when a name is set, otherwise ``None``.
    """
    if not person:
        return None

    orcid = person.get("orcid")
    if orcid:
        return f"orcid:{orcid}"

    name = person.get("name")
    if name:
        return f"name:{name}"

    return None


def normalize_email(person: "dict | None") -> "str | None":
    """Return a person's email address, trimmed and lower-cased.

    Args:
        person: Mapping that may carry an ``"email"`` key. May be ``None`` or
            empty.

    Returns:
        The cleaned address, or ``None`` when there is no person, no email,
        or the email contains only whitespace.
    """
    if not person:
        return None

    email = person.get("email")
    if not email:
        return None

    # ``or None`` keeps a whitespace-only address from surfacing as "".
    return email.strip().lower() or None


# Fetch API data for a single study
def fetch_study(study_id: str) -> "dict | None":
    """Fetch one study record from the API.

    Issues ``GET {BASE_URL}/studies/{study_id}`` with a 10-second timeout.

    Args:
        study_id: Study identifier, e.g. ``"nmdc:sty-11-34xj1150"``.

    Returns:
        The decoded JSON body (callers treat it as a dict), or ``None`` when
        the request fails, the server answers with an error status, or the
        body is not valid JSON. Failures are logged as warnings, not raised.
    """
    url = f"{BASE_URL}/studies/{study_id}"
    try:
        res = requests.get(url, timeout=10)
        res.raise_for_status()
        return res.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decoding errors on requests versions whose
        # JSONDecodeError does not derive from RequestException.
        logging.warning("Error fetching %s: %s", study_id, e)
        return None


# Extract study-person rows (one per person and role) from a study record
def extract_study_person(data: "dict | None") -> "list[tuple]":
    """Flatten a study record into (study, person, role) rows.

    One row is produced for the principal investigator (with the fixed role
    ``"principal_investigator"``) and one row for every role of every credit
    association. People for whom no person ID can be derived are skipped.

    Args:
        data: Decoded study record, or ``None``/empty when the fetch failed.

    Returns:
        A list of ``(study_id, person_id, name, email, role)`` tuples; empty
        when ``data`` is missing or names nobody identifiable.
    """
    if not data:
        return []

    entity_id = data.get("id")
    rows = []

    # Principal investigator.
    pi = data.get("principal_investigator")
    if pi:
        pid = normalize_person_id(pi)
        if pid:
            rows.append(
                (
                    entity_id,
                    pid,
                    pi.get("name"),
                    normalize_email(pi),
                    "principal_investigator",
                )
            )

    # Contributors. ``or []`` / ``or {}`` are used instead of ``.get`` defaults
    # because a key that is present with an explicit null would bypass the
    # default and make the loops below fail on ``None``.
    for assoc in data.get("has_credit_associations") or []:
        if not assoc:
            continue

        person = assoc.get("applies_to_person") or {}
        pid = normalize_person_id(person)
        if not pid:
            continue

        name = person.get("name")
        email = normalize_email(person)
        rows.extend(
            (entity_id, pid, name, email, normalize_role(role))
            for role in assoc.get("applied_roles") or []
        )

    return rows


# Schema of the study-person table: five nullable string columns, listed in
# the same order as the tuples returned by extract_study_person().
study_person_schema = (
    StructType()
    .add("study_id", StringType(), True)
    .add("person_id", StringType(), True)
    .add("name", StringType(), True)
    .add("email", StringType(), True)
    .add("role", StringType(), True)
)


def main() -> None:
    """Run the study-person pipeline end to end.

    Fetches every study in ``STUDIES`` through an RDD ``flatMap``, turns the
    extracted rows into a DataFrame, keeps one row per
    (study_id, person_id, role) with the first non-null name and email,
    prints the result for debugging, and writes it as Parquet to
    ``output/study_person`` (overwriting any earlier output).
    """
    logging.info("starting NMDC Study Pipeline")

    study_rdd = spark.sparkContext.parallelize(STUDIES)

    def process_study(study_id):
        """Fetch one study and return its study-person rows."""
        logging.info("Processing %s", study_id)
        return extract_study_person(fetch_study(study_id))

    # flatMap flattens the per-study row lists into a single RDD of rows.
    rows_rdd = study_rdd.flatMap(process_study)

    df = spark.createDataFrame(rows_rdd, schema=study_person_schema)

    # Every action on a DataFrame re-executes its lineage. Without caching,
    # show() and the Parquet write below would each re-run the HTTP fetches,
    # and could even see different API responses.
    study_person_spark = (
        df.groupBy("study_id", "person_id", "role")
        .agg(
            first("name", ignorenulls=True).alias("name"),
            first("email", ignorenulls=True).alias("email"),
        )
        .cache()
    )

    try:
        # Debug output.
        study_person_spark.show(truncate=False)
        study_person_spark.printSchema()

        # Output to Parquet.
        output_path = "output/study_person"
        study_person_spark.write.mode("overwrite").parquet(output_path)

        logging.info("Output saved to %s", output_path)
    finally:
        study_person_spark.unpersist()

if __name__ == "__main__":
main()
102 changes: 102 additions & 0 deletions nmdc_tables/table_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import time
import requests
import logging
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType


# Root URL that fetch_biosample_ids() prefixes to its
# "/data_objects/study/<study_id>" path.
BASE_URL = "https://api.microbiomedata.org"
# Directory main() writes the resulting Parquet data to.
OUTPUT_PATH = "output/study_sample"

# Study identifiers that main() parallelizes and processes.
STUDIES = [
    "nmdc:sty-11-34xj1150",
    "nmdc:sty-11-hht5sb92",
    "nmdc:sty-11-nxrz9m96",
    "nmdc:sty-11-pzmd0x14",
]


# Configure the root logger to emit INFO-level messages and above.
logging.basicConfig(level=logging.INFO)

# Schema of the study-sample table: two nullable string columns, matching the
# (study_id, biosample_id) tuples built in main().
study_sample_schema = StructType(
    [
        StructField("study_id", StringType(), True),
        StructField("biosample_id", StringType(), True),
    ]
)


def normalize_biosample_id(bid: "str | None") -> "str | None":
    """Return a biosample ID that carries the ``nmdc:`` prefix.

    Args:
        bid: Raw biosample identifier, with or without the ``nmdc:`` prefix.
            May be ``None``.

    Returns:
        The trimmed identifier, prefixed with ``nmdc:`` if it was not
        already, or ``None`` when ``bid`` is missing, empty, or whitespace
        only.
    """
    if not bid:
        return None

    # Trim first so stray whitespace neither hides an existing prefix nor
    # ends up embedded in the resulting ID.
    cleaned = bid.strip()
    if not cleaned:
        return None

    return cleaned if cleaned.startswith("nmdc:") else f"nmdc:{cleaned}"


def _collect_biosample_ids(records: list) -> "list[str]":
    """Return the sorted, de-duplicated biosample IDs found in ``records``."""
    found = set()
    for record in records:
        if not isinstance(record, dict):
            continue

        metadata = record.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}

        bid = record.get("biosample_id") or metadata.get("biosample_id")
        if not isinstance(bid, str):
            # Missing, or not a string and therefore unusable as an ID.
            continue

        # Normalize so that prefixed and unprefixed spellings of the same
        # biosample collapse into a single entry.
        normalized = normalize_biosample_id(bid)
        if normalized:
            found.add(normalized)

    # Sorted so the result is deterministic from run to run.
    return sorted(found)


def fetch_biosample_ids(study_id: str, retries: int = 3) -> "list[str]":
    """Fetch the biosample IDs linked to a study.

    Issues ``GET {BASE_URL}/data_objects/study/{study_id}`` with a 5-second
    connect and 30-second read timeout. A failed request is retried with
    exponential backoff (1s, 2s, 4s, ...).

    Args:
        study_id: Study identifier, e.g. ``"nmdc:sty-11-34xj1150"``.
        retries: Maximum number of request attempts.

    Returns:
        A sorted list of unique, ``nmdc:``-prefixed biosample IDs read from
        each record's ``biosample_id`` (or ``metadata.biosample_id``). The
        list is empty when every attempt fails, the body is not valid JSON,
        or the payload is not a list.
    """
    url = f"{BASE_URL}/data_objects/study/{study_id}"

    for attempt in range(retries):
        try:
            res = requests.get(url, timeout=(5, 30))
            res.raise_for_status()
        except requests.RequestException as e:
            logging.warning("%s attempt %d failed: %s", study_id, attempt + 1, e)
            # Back off only when another attempt will follow.
            if attempt + 1 < retries:
                time.sleep(2**attempt)
            continue

        try:
            data = res.json()
        except ValueError:
            logging.warning("%s: JSON decode failed", study_id)
            return []

        if not isinstance(data, list):
            logging.warning(
                "%s: expected a list of records, got %s",
                study_id,
                type(data).__name__,
            )
            return []

        return _collect_biosample_ids(data)

    return []


def main() -> None:
    """Run the study-sample pipeline end to end.

    Fetches the biosample IDs of every study in ``STUDIES`` through an RDD
    ``flatMap``, builds a de-duplicated (study_id, biosample_id) DataFrame,
    logs the row count, shows the first ten rows, and writes the result as
    Parquet to ``OUTPUT_PATH`` (overwriting any earlier output).
    """
    spark = SparkSession.builder.appName("NMDC_Table2_Extraction").getOrCreate()
    study_rdd = spark.sparkContext.parallelize(STUDIES, numSlices=2)

    def process_study(study_id):
        """Return the (study_id, biosample_id) rows for one study."""
        try:
            return [(study_id, bid) for bid in fetch_biosample_ids(study_id)]
        except Exception as e:
            # Task boundary: one failing study must not abort the whole job.
            logging.error("Fatal error in %s: %s", study_id, e)
            return []

    rows_rdd = study_rdd.flatMap(process_study)

    df = spark.createDataFrame(rows_rdd, schema=study_sample_schema)

    # count(), show() and write are three separate actions. Without caching,
    # each would re-execute the lineage and repeat every HTTP fetch,
    # including its retries and backoff.
    df_final = df.dropDuplicates().cache()

    try:
        total = df_final.count()
        logging.info("Total entries: %s", total)

        df_final.show(10, truncate=False)

        df_final.write.mode("overwrite").parquet(OUTPUT_PATH)

        logging.info("Saved data to %s", OUTPUT_PATH)
    finally:
        df_final.unpersist()


if __name__ == "__main__":
main()
Loading
Loading