@@ -1524,111 +1524,112 @@ task classify_virnucpro_contigs {
 
   command <<<
     set -e
+    pip install pandas --quiet --no-cache-dir
     python3 << CODE
-    import re
-    import sys
-    import pandas as pd
+import re
+import sys
+import pandas as pd
 
-    def classify_sequence(group, min_viral_proportion=0.1, min_nonviral_proportion=0.1, min_chunk_count=5):
-        group = group.copy()
-        group['delta'] = group['max_score_1'] - group['max_score_0']
-        group['confidence'] = group['delta'].abs().pow(0.5)
+def classify_sequence(group, min_viral_proportion=0.1, min_nonviral_proportion=0.1, min_chunk_count=5):
+    group = group.copy()
+    group['delta'] = group['max_score_1'] - group['max_score_0']
+    group['confidence'] = group['delta'].abs().pow(0.5)
 
-        # Weighted mean delta (confidence-weighted)
-        if group['confidence'].sum() > 0:
-            weighted_delta = (group['delta'] * group['confidence']).sum() / group['confidence'].sum()
+    # Weighted mean delta (confidence-weighted)
+    if group['confidence'].sum() > 0:
+        weighted_delta = (group['delta'] * group['confidence']).sum() / group['confidence'].sum()
+    else:
+        weighted_delta = group['delta'].mean()
+
+    # Count high-confidence viral chunks
+    confident_viral = (group['max_score_1'] > 0.8) & (group['max_score_0'] < 0.3)
+    # Count high-confidence non-viral chunks
+    confident_nonviral = (group['max_score_0'] > 0.8) & (group['max_score_1'] < 0.3)
+    # Count ambiguous chunks
+    ambiguous = (group['max_score_1'] > 0.7) & (group['max_score_0'] > 0.7)
+
+    n_chunks = len(group)
+    n_confident_viral = confident_viral.sum()
+    n_confident_nonviral = confident_nonviral.sum()
+    n_ambiguous = ambiguous.sum()
+    # exclude ambiguous chunks from denominator so they don't dilute
+    # the proportion of confident viral/nonviral chunks
+    n_effective = n_chunks - n_ambiguous
+    viral_proportion = n_confident_viral / n_effective if n_effective > 0 else 0
+    nonviral_proportion = n_confident_nonviral / n_effective if n_effective > 0 else 0
+
+    # weighted_delta is the primary signal; chunk evidence determines tier
+    if weighted_delta > 0.3:
+        call = 'Viral'
+        if n_confident_viral >= 1 and viral_proportion >= min_viral_proportion:
+            tier = 'high_confidence' if weighted_delta > 0.6 else 'moderate_confidence'
         else:
-            weighted_delta = group['delta'].mean()
-
-        # Count high-confidence viral chunks
-        confident_viral = (group['max_score_1'] > 0.8) & (group['max_score_0'] < 0.3)
-        # Count high-confidence non-viral chunks
-        confident_nonviral = (group['max_score_0'] > 0.8) & (group['max_score_1'] < 0.3)
-        # Count ambiguous chunks
-        ambiguous = (group['max_score_1'] > 0.7) & (group['max_score_0'] > 0.7)
-
-        n_chunks = len(group)
-        n_confident_viral = confident_viral.sum()
-        n_confident_nonviral = confident_nonviral.sum()
-        n_ambiguous = ambiguous.sum()
-        # Option 1: exclude ambiguous chunks from denominator so they don't dilute
-        # the proportion of confident viral/nonviral chunks
-        n_effective = n_chunks - n_ambiguous
-        viral_proportion = n_confident_viral / n_effective if n_effective > 0 else 0
-        nonviral_proportion = n_confident_nonviral / n_effective if n_effective > 0 else 0
-
-        # Option 3: weighted_delta is the primary signal; chunk evidence determines tier
-        if weighted_delta > 0.3:
-            call = 'Viral'
-            if n_confident_viral >= 1 and viral_proportion >= min_viral_proportion:
-                tier = 'high_confidence' if weighted_delta > 0.6 else 'moderate_confidence'
-            else:
-                tier = 'low_confidence'
-        elif weighted_delta < -0.3:
-            call = 'Non-viral'
-            if n_confident_nonviral >= 1 and nonviral_proportion >= min_nonviral_proportion:
-                tier = 'high_confidence' if weighted_delta < -0.6 else 'moderate_confidence'
-            else:
-                tier = 'low_confidence'
+            tier = 'low_confidence'
+    elif weighted_delta < -0.3:
+        call = 'Non-viral'
+        if n_confident_nonviral >= 1 and nonviral_proportion >= min_nonviral_proportion:
+            tier = 'high_confidence' if weighted_delta < -0.6 else 'moderate_confidence'
         else:
-            call = 'Ambiguous'
+            tier = 'low_confidence'
+    else:
+        call = 'Ambiguous'
+        tier = 'review'
+
+    # Apply low chunk count penalty
+    if n_chunks < min_chunk_count:
+        if tier in ['high_confidence', 'moderate_confidence']:
+            tier = 'low_confidence'
+        elif tier == 'low_confidence':
             tier = 'review'
+    return pd.Series({
+        'call': call,
+        'tier': tier,
+        'weighted_delta': round(weighted_delta, 3),
+        'n_chunks': n_chunks,
+        'n_confident_viral': n_confident_viral,
+        'n_confident_nonviral': n_confident_nonviral,
+        'n_ambiguous': n_ambiguous,
+        'viral_proportion': round(viral_proportion, 3),
+        'nonviral_proportion': round(nonviral_proportion, 3)
+    })
+
+df = pd.read_csv("~{virnucpro_scores_tsv}", sep='\t')
+
+required_cols = ["~{id_col}", 'max_score_0', 'max_score_1']
+missing = [c for c in required_cols if c not in df.columns]
+if missing:
+    print(f"Error: Missing required columns: {missing}", file=sys.stderr)
+    sys.exit(1)
 
-        # Apply low chunk count penalty
-        if n_chunks < min_chunk_count:
-            if tier in ['high_confidence', 'moderate_confidence']:
-                tier = 'low_confidence'
-            elif tier == 'low_confidence':
-                tier = 'review'
-        return pd.Series({
-            'call': call,
-            'tier': tier,
-            'weighted_delta': round(weighted_delta, 3),
-            'n_chunks': n_chunks,
-            'n_confident_viral': n_confident_viral,
-            'n_confident_nonviral': n_confident_nonviral,
-            'n_ambiguous': n_ambiguous,
-            'viral_proportion': round(viral_proportion, 3), # of effective (non-ambiguous) chunks
-            'nonviral_proportion': round(nonviral_proportion, 3) # of effective (non-ambiguous) chunks
-        })
-
-    df = pd.read_csv("~{virnucpro_scores_tsv}", sep='\t')
-
-    required_cols = ["~{id_col}", 'max_score_0', 'max_score_1']
-    missing = [c for c in required_cols if c not in df.columns]
-    if missing:
-        print(f"Error: Missing required columns: {missing}", file=sys.stderr)
-        sys.exit(1)
-
-    df['ID'] = df["~{id_col}"].str.replace(r'_chunk_\d+$', '', regex=True)
-    df['_group'] = df["~{id_col}"].str.extract("~{id_pattern}")
+df['ID'] = df["~{id_col}"].str.replace(r'_chunk_\d+$', '', regex=True)
+df['_group'] = df["~{id_col}"].str.extract("~{id_pattern}")
 
-    n_unmatched = df['_group'].isna().sum()
-    if n_unmatched == len(df):
-        print("Error: No valid NODE IDs extracted. Check id_pattern.", file=sys.stderr)
-        sys.exit(1)
-    elif n_unmatched > 0:
-        print(f"Warning: {n_unmatched} of {len(df)} rows did not match id_pattern and were excluded.", file=sys.stderr)
-
-    for col in ['max_score_0', 'max_score_1']:
-        if df[col].isna().any():
-            print(f"Warning: {df[col].isna().sum()} NaN values in {col}.", file=sys.stderr)
-
-    group_to_id = df.dropna(subset=['_group']).groupby('_group')['ID'].first()
-    results = df.groupby('_group').apply(
-        lambda g: classify_sequence(g, ~{min_viral_prop}, ~{min_nonviral_prop}, ~{min_chunks}),
-        include_groups=False
-    )
-    results = results.reset_index()
-    results['ID'] = results['_group'].map(group_to_id)
-    results = results.drop(columns=['_group'])
-    results = results[['ID'] + [c for c in results.columns if c != 'ID']]
-
-    def natural_sort_key(val):
-        return [int(s) if s.isdigit() else s.lower() for s in re.split(r'(\d+)', str(val))]
-    results = results.sort_values('ID', key=lambda col: col.map(natural_sort_key))
-
-    results.to_csv("~{out_filename}", sep='\t', index=False)
+n_unmatched = df['_group'].isna().sum()
+if n_unmatched == len(df):
+    print("Error: No valid NODE IDs extracted. Check id_pattern.", file=sys.stderr)
+    sys.exit(1)
+elif n_unmatched > 0:
+    print(f"Warning: {n_unmatched} of {len(df)} rows did not match id_pattern and were excluded.", file=sys.stderr)
+
+for col in ['max_score_0', 'max_score_1']:
+    if df[col].isna().any():
+        print(f"Warning: {df[col].isna().sum()} NaN values in {col}.", file=sys.stderr)
+
+group_to_id = df.dropna(subset=['_group']).groupby('_group')['ID'].first()
+results = df.groupby('_group').apply(
+    lambda g: classify_sequence(g, ~{min_viral_prop}, ~{min_nonviral_prop}, ~{min_chunks}),
+    include_groups=False
+)
+results = results.reset_index()
+results['ID'] = results['_group'].map(group_to_id)
+results = results.drop(columns=['_group'])
+results = results[['ID'] + [c for c in results.columns if c != 'ID']]
+
+def natural_sort_key(val):
+    return [int(s) if s.isdigit() else s.lower() for s in re.split(r'(\d+)', str(val))]
+results = results.sort_values('ID', key=lambda col: col.map(natural_sort_key))
+
+results.to_csv("~{out_filename}", sep='\t', index=False)
 CODE
   >>>
 
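For readers tracing the logic of the updated classifier, the sketch below (not part of the workflow) walks a hypothetical contig through the same steps. The thresholds (the 0.8/0.3 confident-chunk cutoffs, the 0.3/0.6 weighted-delta bands, and min_chunk_count = 5) come from the task above; the contig names, the example id_pattern value r'(NODE_\d+)', and all score values are invented for illustration.

import pandas as pd

# Hypothetical chunk IDs as they might appear in the scores TSV.
ids = pd.Series(['NODE_3_length_1200_cov_7.1_chunk_0',
                 'NODE_3_length_1200_cov_7.1_chunk_1'])
# Chunk-suffix stripping and group extraction, as in the task; r'(NODE_\d+)'
# stands in for the real id_pattern input.
print(ids.str.replace(r'_chunk_\d+$', '', regex=True)[0])  # NODE_3_length_1200_cov_7.1
print(ids.str.extract(r'(NODE_\d+)')[0][0])                # NODE_3

# Hypothetical per-chunk scores: both chunks look confidently viral.
scores = pd.DataFrame({'max_score_0': [0.10, 0.20],   # non-viral score per chunk
                       'max_score_1': [0.90, 0.85]})  # viral score per chunk

delta = scores['max_score_1'] - scores['max_score_0']           # [0.80, 0.65]
confidence = delta.abs().pow(0.5)                               # [~0.894, ~0.806]
weighted_delta = (delta * confidence).sum() / confidence.sum()
print(round(weighted_delta, 3))                                 # 0.729

# weighted_delta > 0.6 and both chunks clear the confident-viral cutoffs
# (max_score_1 > 0.8, max_score_0 < 0.3), so the call would be 'Viral' with a
# 'high_confidence' tier; however n_chunks = 2 < min_chunk_count = 5, so the
# low-chunk-count penalty downgrades the tier to 'low_confidence'.

The square-root confidence weighting keeps near-tied chunks from contributing much to weighted_delta, while the chunk-count penalty keeps very short contigs out of the high-confidence tiers even when every chunk agrees.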