
Commit 07f8b90

carze and claude committed
fix: fix IndentationError and missing pandas install in classify_virnucpro_contigs
- Add missing pip install for pandas
- Dedent Python code to column 0: shell heredocs do not strip leading whitespace, causing Python to raise IndentationError at module level

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent eaada37 commit 07f8b90
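A minimal sketch of the failure mode described in the commit message (illustrative only, not part of the diff): a plain <<CODE heredoc passes leading whitespace through to the interpreter, so module-level Python indented to match the surrounding shell block raises IndentationError, while the same statements dedented to column 0 run fine. Note that bash's <<- variant strips only leading tabs, not spaces, so dedenting is the simpler fix here.

# fails: the heredoc body keeps its leading spaces, so Python sees
# indented module-level statements and raises IndentationError
python3 <<CODE
    import sys
    print("never printed")
CODE

# works: the same statements dedented to column 0, as in this commit
python3 <<CODE
import sys
print("ok")
CODE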

1 file changed

Lines changed: 98 additions & 97 deletions

File tree

pipes/WDL/tasks/tasks_metagenomics.wdl

@@ -1524,111 +1524,112 @@ task classify_virnucpro_contigs {

The change adds the pandas install before the heredoc and dedents the Python body to column 0; the 97 removed lines were the same Python statements indented inside the heredoc. The command section now reads:

  command <<<
    set -e
    pip install pandas --quiet --no-cache-dir
    python3<<CODE
import re
import sys
import pandas as pd

def classify_sequence(group, min_viral_proportion=0.1, min_nonviral_proportion=0.1, min_chunk_count=5):
    group = group.copy()
    group['delta'] = group['max_score_1'] - group['max_score_0']
    group['confidence'] = group['delta'].abs().pow(0.5)

    # Weighted mean delta (confidence-weighted)
    if group['confidence'].sum() > 0:
        weighted_delta = (group['delta'] * group['confidence']).sum() / group['confidence'].sum()
    else:
        weighted_delta = group['delta'].mean()

    # Count high-confidence viral chunks
    confident_viral = (group['max_score_1'] > 0.8) & (group['max_score_0'] < 0.3)
    # Count high-confidence non-viral chunks
    confident_nonviral = (group['max_score_0'] > 0.8) & (group['max_score_1'] < 0.3)
    # Count ambiguous chunks
    ambiguous = (group['max_score_1'] > 0.7) & (group['max_score_0'] > 0.7)

    n_chunks = len(group)
    n_confident_viral = confident_viral.sum()
    n_confident_nonviral = confident_nonviral.sum()
    n_ambiguous = ambiguous.sum()
    # exclude ambiguous chunks from denominator so they don't dilute
    # the proportion of confident viral/nonviral chunks
    n_effective = n_chunks - n_ambiguous
    viral_proportion = n_confident_viral / n_effective if n_effective > 0 else 0
    nonviral_proportion = n_confident_nonviral / n_effective if n_effective > 0 else 0

    # weighted_delta is the primary signal; chunk evidence determines tier
    if weighted_delta > 0.3:
        call = 'Viral'
        if n_confident_viral >= 1 and viral_proportion >= min_viral_proportion:
            tier = 'high_confidence' if weighted_delta > 0.6 else 'moderate_confidence'
        else:
            tier = 'low_confidence'
    elif weighted_delta < -0.3:
        call = 'Non-viral'
        if n_confident_nonviral >= 1 and nonviral_proportion >= min_nonviral_proportion:
            tier = 'high_confidence' if weighted_delta < -0.6 else 'moderate_confidence'
        else:
            tier = 'low_confidence'
    else:
        call = 'Ambiguous'
        tier = 'review'

    # Apply low chunk count penalty
    if n_chunks < min_chunk_count:
        if tier in ['high_confidence', 'moderate_confidence']:
            tier = 'low_confidence'
        elif tier == 'low_confidence':
            tier = 'review'
    return pd.Series({
        'call': call,
        'tier': tier,
        'weighted_delta': round(weighted_delta, 3),
        'n_chunks': n_chunks,
        'n_confident_viral': n_confident_viral,
        'n_confident_nonviral': n_confident_nonviral,
        'n_ambiguous': n_ambiguous,
        'viral_proportion': round(viral_proportion, 3),
        'nonviral_proportion': round(nonviral_proportion, 3)
    })

df = pd.read_csv("~{virnucpro_scores_tsv}", sep='\t')

required_cols = ["~{id_col}", 'max_score_0', 'max_score_1']
missing = [c for c in required_cols if c not in df.columns]
if missing:
    print(f"Error: Missing required columns: {missing}", file=sys.stderr)
    sys.exit(1)

df['ID'] = df["~{id_col}"].str.replace(r'_chunk_\d+$', '', regex=True)
df['_group'] = df["~{id_col}"].str.extract("~{id_pattern}")

n_unmatched = df['_group'].isna().sum()
if n_unmatched == len(df):
    print("Error: No valid NODE IDs extracted. Check id_pattern.", file=sys.stderr)
    sys.exit(1)
elif n_unmatched > 0:
    print(f"Warning: {n_unmatched} of {len(df)} rows did not match id_pattern and were excluded.", file=sys.stderr)

for col in ['max_score_0', 'max_score_1']:
    if df[col].isna().any():
        print(f"Warning: {df[col].isna().sum()} NaN values in {col}.", file=sys.stderr)

group_to_id = df.dropna(subset=['_group']).groupby('_group')['ID'].first()
results = df.groupby('_group').apply(
    lambda g: classify_sequence(g, ~{min_viral_prop}, ~{min_nonviral_prop}, ~{min_chunks}),
    include_groups=False
)
results = results.reset_index()
results['ID'] = results['_group'].map(group_to_id)
results = results.drop(columns=['_group'])
results = results[['ID'] + [c for c in results.columns if c != 'ID']]

def natural_sort_key(val):
    return [int(s) if s.isdigit() else s.lower() for s in re.split(r'(\d+)', str(val))]
results = results.sort_values('ID', key=lambda col: col.map(natural_sort_key))

results.to_csv("~{out_filename}", sep='\t', index=False)
CODE
  >>>
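A small illustrative sketch (not part of the commit) of the confidence-weighted delta that drives the classifier's Viral / Non-viral / Ambiguous call above; the toy column names mirror the max_score_0 and max_score_1 columns of the scores TSV. Chunks with a larger absolute delta get more weight, so one strongly viral chunk is not cancelled one-for-one by a weakly non-viral chunk.

python3 <<CODE
import pandas as pd

# toy contig with two chunks: one strongly viral, one moderately non-viral
toy = pd.DataFrame({
    'max_score_0': [0.10, 0.90],   # non-viral score per chunk
    'max_score_1': [0.95, 0.20],   # viral score per chunk
})
toy['delta'] = toy['max_score_1'] - toy['max_score_0']   # per-chunk viral minus non-viral
toy['confidence'] = toy['delta'].abs().pow(0.5)          # weight by sqrt(|delta|)
weighted_delta = (toy['delta'] * toy['confidence']).sum() / toy['confidence'].sum()
print(round(weighted_delta, 3))  # about 0.11, below the 0.3 cutoff, so this toy contig would be called Ambiguous
CODE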
