diff --git a/.claude/skills/new-subcmd/SKILL.md b/.claude/skills/new-subcmd/SKILL.md new file mode 100644 index 0000000..8180523 --- /dev/null +++ b/.claude/skills/new-subcmd/SKILL.md @@ -0,0 +1,69 @@ +--- +name: new-subcmd +description: mdtbxに新しいargparseサブコマンドを追加する +--- + +# new-subcmd スキル + +mdtbxに新しいargparseサブコマンドを追加するスキル。 + +## ワークフロー + +### ステップ1: ユーザーへの確認 + +以下をユーザーに確認する: +1. **サブコマンド名** (例: `calc_rmsd`, `gen_topology`) — CLIで `mdtbx <名前>` として呼ぶ +2. **配置場所**: 機能に応じて選択 + - `src/build/` — 系構築 (addace, amb2gro, gen_posres等と同種) + - `src/trajectory/` — 軌跡処理 (fit, trjcat等と同種) + - `src/analysis/` — 解析 (extract_str等と同種) + - `src/cv/` — Collective Variable計算 + - `src/utils/` — 汎用ユーティリティ (mod_mdp, convert等と同種) +3. **引数**: 必須引数・オプション引数の名前と型 +4. **処理内容**: 何をするコマンドか + +### ステップ2: サブコマンドファイルの作成 + +`${SKILL_ROOT}/template.py` を参考に新ファイルを作成する。 + +必須パターン: +- `add_subcmd(subparsers)` の末尾に必ず `parser.set_defaults(func=run)` を置く +- `run(args)` が実装本体 +- ロガーは `from ..logger import generate_logger` / `LOGGER = generate_logger(__name__)` +- `argparse.ArgumentDefaultsHelpFormatter` を使う +- `src/utils/` のパーサーを使う場合は `from ..utils.atom_selection_parser import AtomSelector` のように参照する + +### ステップ3: cli.py への登録 + +`src/cli.py` に2箇所追加: + +1. **importブロック** (対応するカテゴリのブロックに追加): + ```python + from .build import # build の場合 + from .trajectory import # trajectory の場合 + from .analysis import # analysis の場合 + from .cv import # cv の場合 + from .utils import # utils の場合 + ``` + +2. 
**add_subcmdの呼び出し** (同カテゴリのブロックに追加): + ```python + .add_subcmd(subparsers) + ``` + +### ステップ4: テストファイルの作成 + +`tests/test_/test_.py` にユニットテストを追加する: +- 外部ツール不要な純粋計算関数は直接テスト +- ファイルI/Oは `tmp_path` fixture を使用 +- 軌跡が必要な場合は `conftest.py` の `trajectory_files` fixture を使用 +- PyMOL/subprocess 依存は `unittest.mock.patch` でモック + +### ステップ5: Lintチェック・テスト実行 + +```bash +pixi run r # ruff format + lint +pixi run test # 全テスト +``` + +エラーがあれば修正する。 diff --git a/.claude/skills/new-subcmd/template.py b/.claude/skills/new-subcmd/template.py new file mode 100644 index 0000000..c94d394 --- /dev/null +++ b/.claude/skills/new-subcmd/template.py @@ -0,0 +1,24 @@ +import argparse + +from ..logger import generate_logger + +LOGGER = generate_logger(__name__) + + +def add_subcmd(subparsers): + parser = subparsers.add_parser( + "SUBCMD_NAME", + help="HELP_TEXT", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + # 必須引数 + parser.add_argument("input", type=str, help="Input file") + # オプション引数 + parser.add_argument("--output", type=str, default="output.dat", help="Output file") + + parser.set_defaults(func=run) + + +def run(args): + LOGGER.info(f"input: {args.input}") + # TODO: 実装 diff --git a/.gitignore b/.gitignore index 8fdb2d0..b67259f 100644 --- a/.gitignore +++ b/.gitignore @@ -196,3 +196,6 @@ cython_debug/ .pixi !/sample/build +!/src/build + +TODO.md diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 0000000..681311e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..e1eaba8 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,122 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
+ +## Overview + +`mdtbx` はMDシミュレーション用のツールボックス。系の構築・シミュレーション実行・軌跡解析・自由エネルギー計算をサポートするCLIツール。 + +依存ツール: AMBER, PyMOL, OpenBabel, Gromacs, Gaussian16 +力場: ff14SB, TIP3P, GAFF2, Lipid21, GLYCAM06-j + +## 開発コマンド + +```bash +# 環境構築 +pixi install + +# CLIの実行 +pixi run mdtbx +pixi run gmx ... # mdtbx cmd gmx ... の短縮形 + +# テスト +pixi run test # 全テスト実行 +pixi run test-fast # 最初の失敗で停止 (-x) + +# コードフォーマット・Lint +pixi run r # ruff format + ruff check を一括実行 +pixi run ruff-format # フォーマットのみ +pixi run ruff-lint # Lintのみ + +# 更新 +pixi run update # git pull && pixi install + +# PyMOL設定 +pixi run pymolrc # ~/.pymolrcを生成 + +# JupyterLab (リモート) +pixi run jupyter_remote +``` + +## アーキテクチャ + +``` +src/ + __main__.py # エントリポイント: main() -> cli() + cli.py # argparseサブコマンドの登録・ディスパッチ + config.py # グローバル定数(水密度、Gaussian設定、MAXWARN等) + logger.py # ロガー生成ユーティリティ + utils/ # 汎用ユーティリティ(mod_mdp, convert, rmfile, cmd, shell_hook, show_mdtraj, show_npy, partial_tempering) + # ※ atom_selection_parser.py, parse_top.py はサブコマンドでなくライブラリユーティリティ + build/ # 系構築サブコマンド(addace, addnme, add_ndx, mv_crds_mol2, calc_ion_conc, centering_gro, + # find_bond, gen_am1bcc, gen_resp, gen_modres_am1bcc, gen_modres_resp, gen_posres, + # gen_distres, modeling_cf, amb2gro, build_solution, gen_temperatures) + # ※ addh.py, mutate.py, place_solvent.py は未登録(cli.pyへの追加が必要) + trajectory/ # 軌跡処理サブコマンド(fit, trjcat, pacs_trjcat, print_perf) + # ※ opt_perf.py は未登録 + analysis/ # 解析サブコマンド(extract_str, extract_ave_str) + # ※ contactmap.py, distmat.py は未登録 + cv/ # Collective Variable計算(comdist, comvec, densmap, mindist, rmsd, rmsf, xyz, pca) + +tests/ + conftest.py # 共有fixture・PyMOLモック設定 + fixtures/ # テストデータ(sample.mdp, sample.top, sample.pdb) + test_utils/ # src/utils/ のテスト + test_build/ # src/build/ のテスト + test_trajectory/ # src/trajectory/ のテスト + test_analysis/ # src/analysis/ のテスト + test_cv/ # src/cv/ のテスト + test_cli.py # 全サブコマンドのCLI登録確認 + +pymol-plugins/ + pymol_plugins/ # PyMOLプラグイン(builder, visualizer, selector等) + +example/ # 
用途別のサンプルノートブック・スクリプト +install_scripts/ # Gromacs/PLUMED等の手動インストールスクリプト +``` + +### サブコマンドの追加パターン + +各モジュール(`src/build/*.py`, `src/trajectory/*.py`, `src/analysis/*.py`, `src/cv/*.py`, `src/utils/*.py`)は以下の2関数を実装する: + +```python +def add_subcmd(subparsers): + # argparse サブコマンドの引数定義 + +def run(args): + # 実装本体 +``` + +`cli.py` に以下の2箇所を追加して登録する: + +```python +# 1. importブロック (カテゴリに応じて選択) +from .build import # 系構築 +from .trajectory import # 軌跡処理 +from .analysis import # 解析 +from .cv import # CV計算 +from .utils import # 汎用 + +# 2. add_subcmdの呼び出し +.add_subcmd(subparsers) +``` + +モジュール内で `src/utils/` のパーサーを使う場合は `..utils.` で参照する: +```python +from ..utils.atom_selection_parser import AtomSelector +from ..utils.parse_top import GromacsTopologyParser +``` + +### 設定 (`src/config.py`) + +- `MAXWARN`: grompp の最大警告数 +- `GAUSSIAN_CMD`, `STRUCTURE_OPTIMIZATION`, `SINGLE_POINT_CALCULATION`: Gaussian設定 +- 各水モデル(TIP3P/TIP4P/TIP5P/OPC)の密度・体積定数 +- 起動時に `.pixi/envs/default/bin` をPATHに追加する + +## 環境管理 + +- パッケージ管理: `pixi`(conda + pip の混在) +- Python: 3.10固定 +- `pixi.lock` で再現性を保証 +- Docker対応: `Dockerfile` でコンテナビルドも可能 diff --git a/example/build/all_atom/protein.sh b/example/build/all_atom/protein.sh index b2c8138..a3c65be 100755 --- a/example/build/all_atom/protein.sh +++ b/example/build/all_atom/protein.sh @@ -1,4 +1,30 @@ #!/bin/bash set -e +input_structure="input.pdb" +mdtbx addh -s ${input_structure} -o input_h + +mdtbx addace -s input_h.pdb -o ace + +mdtbx addnme -s ace.pdb -o ace_nme + +mdtbx find_bond -s ace_nme.pdb -o bonds.txt -op cym.pdb + +# SS-bond有無で入力PDBとpostcmdを切り替え +if [ -s bonds.txt ]; then + input_pdb="cym.pdb" + mdtbx build_vacuum \ + -i ${input_pdb} \ + -o ./ \ + --addpostcmd "$(cat bonds.txt)" +else + input_pdb="ace_nme.pdb" + mdtbx build_vacuum \ + -i ${input_pdb} \ + -o ./ +fi + +rm -f input_h.pdb ace.pdb ace_nme.pdb cym.pdb bonds.txt + +echo done diff --git a/example/build/all_atom/protein_water.sh b/example/build/all_atom/protein_water.sh index b2c8138..a2f3b17 
100755 --- a/example/build/all_atom/protein_water.sh +++ b/example/build/all_atom/protein_water.sh @@ -1,4 +1,58 @@ #!/bin/bash set -e +input_structure="input.pdb" +out_dir="${PWD}/gmx" +mdtbx addh -s ${input_structure} -o input_h + +mdtbx addace -s input_h.pdb -o ace + +mdtbx addnme -s ace.pdb -o ace_nme + +mdtbx find_bond -s ace_nme.pdb -o bonds.txt -op cym.pdb + +# SS-bond有無で入力PDBとpostcmdを切り替え +if [ -s bonds.txt ]; then + input_pdb="cym.pdb" + mdtbx build_solution \ + -i ${input_pdb} \ + -o ./ \ + --ion_conc 0.15 \ + --cation Na+ \ + --anion Cl- \ + --boxsize 100 100 100 \ + --addpostcmd "$(cat bonds.txt)" +else + input_pdb="ace_nme.pdb" + mdtbx build_solution \ + -i ${input_pdb} \ + -o ./ \ + --ion_conc 0.15 \ + --cation Na+ \ + --anion Cl- \ + --boxsize 100 100 100 +fi + +mdtbx amb2gro -p leap.parm7 -x leap.rst7 --type parmed + +mdtbx add_ndx -g gmx.gro + +mdtbx centering_gro -f gmx.gro -p gmx.top -c Protein + +mdtbx gen_posres -p gmx.top -s "protein and backbone" -o posres + +mdtbx rmfile + +mkdir ${out_dir} +mv gmx.gro ${out_dir}/ +mv gmx.top ${out_dir}/ +mv *itp ${out_dir}/ +mv *.ndx ${out_dir}/ +cp mdps/*.mdp ${out_dir}/ +cp mdrun_slurm.sh ${out_dir}/ + +rm -f leap.parm7 leap.rst7 leap.pdb gmx.pdb +rm -f input_h.pdb ace.pdb ace_nme.pdb cym.pdb bonds.txt + +echo done diff --git a/example/build/all_atom/protein_water_ligand_membrane_glycan.sh b/example/build/all_atom/protein_water_ligand_membrane_glycan.sh index b2c8138..d5f35e1 100755 --- a/example/build/all_atom/protein_water_ligand_membrane_glycan.sh +++ b/example/build/all_atom/protein_water_ligand_membrane_glycan.sh @@ -1,4 +1,66 @@ #!/bin/bash set -e +# グリカンがタンパク質と同一PDBに含まれている前提 +input_structure="prot_lig.pdb" +out_dir="${PWD}/gmx" +lig_frcmod="./LIG.frcmod" +lig_lib="./LIG.lib" +mdtbx addh -s ${input_structure} -o input_h + +mdtbx addace -s input_h.pdb -o ace + +mdtbx addnme -s ace.pdb -o ace_nme + +mdtbx find_bond -s ace_nme.pdb -o bonds.txt -op cym.pdb + +# SS-bond有無で入力PDBを切り替え +if [ -s bonds.txt ]; 
then + input_pdb="cym.pdb" +else + input_pdb="ace_nme.pdb" +fi + +# --keepligs でグリカン座標を保持、--gaff2 は使わずGLYCAM06-jをtleapでロード +mdtbx cmd packmol-memgen \ + --pdb ${input_pdb} \ + --lipids POPC:CHL1 \ + --ratio 4:1 \ + --salt \ + --salt_c Na+ \ + --salt_a Cl- \ + --saltcon 0.15 \ + --keepligs \ + --notprotonate \ + --dims 120 120 200 \ + --ffwat tip3p \ + --ffprot ff14SB \ + --fflip lipid21 \ + --ligand_param ${lig_frcmod}:${lig_lib} \ + --leapline "source leaprc.GLYCAM_06j-1" + +mdtbx amb2gro -p bilayer_${input_pdb%.pdb}_lipid.top -x bilayer_${input_pdb%.pdb}_lipid.crd --type parmed + +mdtbx add_ndx -g gmx.gro + +mdtbx centering_gro -f gmx.gro -p gmx.top -c Protein + +mdtbx gen_posres -p gmx.top -s "protein and backbone" -o posres + +mdtbx rmfile + +mkdir ${out_dir} +mv gmx.gro ${out_dir}/ +mv gmx.top ${out_dir}/ +mv *itp ${out_dir}/ +mv *.ndx ${out_dir}/ +cp mdps/*.mdp ${out_dir}/ +cp mdrun_slurm.sh ${out_dir}/ + +rm -f input_h.pdb ace.pdb ace_nme.pdb cym.pdb bonds.txt \ + bilayer_*.pdb bilayer_*.pdb_FORCED bilayer_*.crd bilayer_*.top \ + *in_EMBED.pdb *in_memembed.log \ + leap_.log packmol-memgen.log packmol.inp packmol.log gmx.pdb + +echo done diff --git a/example/build/all_atom/protein_water_membrane.sh b/example/build/all_atom/protein_water_membrane.sh index b2c8138..0e9ad39 100755 --- a/example/build/all_atom/protein_water_membrane.sh +++ b/example/build/all_atom/protein_water_membrane.sh @@ -1,4 +1,63 @@ #!/bin/bash set -e +input_structure="input.pdb" +out_dir="${PWD}/gmx" +mdtbx addh -s ${input_structure} -o input_h + +mdtbx addace -s input_h.pdb -o ace + +mdtbx addnme -s ace.pdb -o ace_nme + +mdtbx find_bond -s ace_nme.pdb -o bonds.txt -op cym.pdb + +# SS-bond有無で入力PDBを切り替え +if [ -s bonds.txt ]; then + input_pdb="cym.pdb" +else + input_pdb="ace_nme.pdb" +fi + +mdtbx cmd packmol-memgen \ + --pdb ${input_pdb} \ + --lipids POPC:CHL1 \ + --ratio 4:1 \ + --salt \ + --salt_c Na+ \ + --salt_a Cl- \ + --saltcon 0.15 \ + --keepligs \ + --notprotonate \ + --dims 120 120 
200 \ + --ffwat tip3p \ + --ffprot ff14SB \ + --fflip lipid21 + +# SS-bondがある場合はtleap時にpostcmdを追加 +# packmol-memgenが生成するtopのtleap入力はそのままでは使えないため +# amb2groで変換後の.parm7を使う +mdtbx amb2gro -p bilayer_${input_pdb%.pdb}_lipid.top -x bilayer_${input_pdb%.pdb}_lipid.crd --type parmed + +mdtbx add_ndx -g gmx.gro + +mdtbx centering_gro -f gmx.gro -p gmx.top -c Protein + +mdtbx gen_posres -p gmx.top -s "protein and backbone" -o posres + +mdtbx rmfile + +mkdir ${out_dir} +mv gmx.gro ${out_dir}/ +mv gmx.top ${out_dir}/ +mv *itp ${out_dir}/ +mv *.ndx ${out_dir}/ +cp mdps/*.mdp ${out_dir}/ +cp mdrun_slurm.sh ${out_dir}/ + +rm -f input_h.pdb ace.pdb ace_nme.pdb cym.pdb bonds.txt \ + bilayer_*.pdb bilayer_*.pdb_FORCED bilayer_*.crd bilayer_*.top \ + *in_EMBED.pdb *in_memembed.log \ + leap_.log packmol-memgen.log packmol.inp packmol.log gmx.pdb + +echo done diff --git a/example/build/all_atom/protein_water_membrane_glycan.sh b/example/build/all_atom/protein_water_membrane_glycan.sh index b2c8138..2986136 100755 --- a/example/build/all_atom/protein_water_membrane_glycan.sh +++ b/example/build/all_atom/protein_water_membrane_glycan.sh @@ -1,4 +1,63 @@ #!/bin/bash set -e +# グリカンがタンパク質と同一PDBに含まれている前提 +input_structure="input.pdb" +out_dir="${PWD}/gmx" +mdtbx addh -s ${input_structure} -o input_h + +mdtbx addace -s input_h.pdb -o ace + +mdtbx addnme -s ace.pdb -o ace_nme + +mdtbx find_bond -s ace_nme.pdb -o bonds.txt -op cym.pdb + +# SS-bond有無で入力PDBを切り替え +if [ -s bonds.txt ]; then + input_pdb="cym.pdb" +else + input_pdb="ace_nme.pdb" +fi + +# --keepligs でグリカン座標を保持、--gaff2 は使わずGLYCAM06-jをtleapでロード +mdtbx cmd packmol-memgen \ + --pdb ${input_pdb} \ + --lipids POPC:CHL1 \ + --ratio 4:1 \ + --salt \ + --salt_c Na+ \ + --salt_a Cl- \ + --saltcon 0.15 \ + --keepligs \ + --notprotonate \ + --dims 120 120 200 \ + --ffwat tip3p \ + --ffprot ff14SB \ + --fflip lipid21 \ + --leapline "source leaprc.GLYCAM_06j-1" + +mdtbx amb2gro -p bilayer_${input_pdb%.pdb}_lipid.top -x 
bilayer_${input_pdb%.pdb}_lipid.crd --type parmed + +mdtbx add_ndx -g gmx.gro + +mdtbx centering_gro -f gmx.gro -p gmx.top -c Protein + +mdtbx gen_posres -p gmx.top -s "protein and backbone" -o posres + +mdtbx rmfile + +mkdir ${out_dir} +mv gmx.gro ${out_dir}/ +mv gmx.top ${out_dir}/ +mv *itp ${out_dir}/ +mv *.ndx ${out_dir}/ +cp mdps/*.mdp ${out_dir}/ +cp mdrun_slurm.sh ${out_dir}/ + +rm -f input_h.pdb ace.pdb ace_nme.pdb cym.pdb bonds.txt \ + bilayer_*.pdb bilayer_*.pdb_FORCED bilayer_*.crd bilayer_*.top \ + *in_EMBED.pdb *in_memembed.log \ + leap_.log packmol-memgen.log packmol.inp packmol.log gmx.pdb + +echo done diff --git a/example/cv/comdist.sh b/example/cv/comdist.sh new file mode 100755 index 0000000..190f82b --- /dev/null +++ b/example/cv/comdist.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# comdist.sh +# 2つの原子グループ間の重心距離 (COM distance) を計算する +# +# 出力: .npy (shape: [n_frames], 単位: nm) +# +# 使用例: +# bash comdist.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# 基本ケース: MDtraj による計算 +# protein と リガンド間の重心距離 +# ----------------------------------------------------------------------- +mdtbx comdist \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -s1 "protein" \ + -s2 "resname LIG" \ + -o cvs/comdist.npy + +echo "comdist done -> cvs/comdist.npy" + +# ----------------------------------------------------------------------- +# Gromacs インターフェース (--gmx): 大規模系で高速 +# -s1/-s2 に ndx グループ名を指定する +# ----------------------------------------------------------------------- +# mdtbx comdist \ +# -p gmx.tpr \ +# -t ${TRAJECTORY} \ +# -s1 "Protein" \ +# -s2 "LIG" \ +# --gmx \ +# -idx gmx.ndx \ +# -o cvs/comdist_gmx.npy + +echo "All done." 
diff --git a/example/cv/comvec.sh b/example/cv/comvec.sh new file mode 100755 index 0000000..2e27d07 --- /dev/null +++ b/example/cv/comvec.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# comvec.sh +# 2つの原子グループ間の重心ベクトル (COM vector) を計算する +# +# comdist がスカラー(距離)を返すのに対し、comvec は3次元ベクトルを返す +# チャネル透過や膜挿入のような方向性のある運動の解析に適している +# +# 出力: .npy (shape: [n_frames, 3], 単位: nm) +# +# 使用例: +# bash comvec.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# 基本ケース: MDtraj による計算 +# ----------------------------------------------------------------------- +mdtbx comvec \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -s1 "protein" \ + -s2 "resname LIG" \ + -o cvs/comvec.npy + +echo "comvec done -> cvs/comvec.npy" + +# ----------------------------------------------------------------------- +# Gromacs インターフェース (--gmx) +# ----------------------------------------------------------------------- +# mdtbx comvec \ +# -p gmx.tpr \ +# -t ${TRAJECTORY} \ +# -s1 "Protein" \ +# -s2 "LIG" \ +# --gmx \ +# -idx gmx.ndx \ +# -o cvs/comvec_gmx.npy + +echo "All done." 
diff --git a/example/cv/densmap.sh b/example/cv/densmap.sh new file mode 100755 index 0000000..a78713c --- /dev/null +++ b/example/cv/densmap.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# densmap.sh +# 特定の原子群の 2D 密度マップ (ヒストグラム) を計算する +# +# 膜系でのリガンド分布や、特定平面への投影密度の可視化に使用する +# --axis: 投影する平面 (xy / xz / yz) +# --bins: 各軸のビン数 +# +# 出力: .npy (object array [counts, edges0, edges1]) +# counts : shape [bins, bins] - 各セルの頻度 +# edges0/1: shape [bins+1] - ビンの境界 +# +# 使用例: +# bash densmap.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# xy 平面への投影 (膜面内の分布) +# ----------------------------------------------------------------------- +mdtbx densmap \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -s "resname LIG" \ + --axis xy \ + --bins 100 \ + -o cvs/densmap_xy.npy + +echo "densmap xy done -> cvs/densmap_xy.npy" + +# ----------------------------------------------------------------------- +# xz 平面への投影 (膜の厚さ方向を含む断面) +# ----------------------------------------------------------------------- +mdtbx densmap \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -s "resname LIG" \ + --axis xz \ + --bins 100 \ + -o cvs/densmap_xz.npy + +echo "densmap xz done -> cvs/densmap_xz.npy" + +# ----------------------------------------------------------------------- +# Gromacs gmx densmap を使う場合 (--gmx) +# ----------------------------------------------------------------------- +# mdtbx densmap \ +# -p gmx.tpr \ +# -t ${TRAJECTORY} \ +# -s "LIG" \ +# --gmx \ +# -idx gmx.ndx \ +# -o cvs/densmap_gmx.npy + +echo "All done." 
diff --git a/example/cv/mindist.sh b/example/cv/mindist.sh new file mode 100755 index 0000000..0e6e067 --- /dev/null +++ b/example/cv/mindist.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# mindist.sh +# 2つの原子グループ間の全ペアから最小距離を計算する +# +# comdist (重心距離) と異なり、2グループの接触の有無を反映しやすい +# 結合部位での接触距離や、タンパク質-タンパク質界面の解析に使用する +# +# 出力: .npy (shape: [n_frames], 単位: nm) +# +# 使用例: +# bash mindist.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# 基本ケース: 活性部位残基とリガンド間の最小距離 +# ----------------------------------------------------------------------- +mdtbx mindist \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -s1 "resid 50 to 70" \ + -s2 "resname LIG" \ + -o cvs/mindist.npy + +echo "mindist done -> cvs/mindist.npy" + +# ----------------------------------------------------------------------- +# タンパク質-タンパク質界面の最小距離 +# ----------------------------------------------------------------------- +# mdtbx mindist \ +# -p ${TOPOLOGY} \ +# -t ${TRAJECTORY} \ +# -s1 "chainid 0" \ +# -s2 "chainid 1" \ +# -o cvs/mindist_pp.npy + +echo "All done." 
diff --git a/example/cv/pca.sh b/example/cv/pca.sh new file mode 100755 index 0000000..98f342a --- /dev/null +++ b/example/cv/pca.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# pca.sh +# 主成分分析 (PCA) で集団運動モードを抽出する +# +# 自由エネルギー地形の可視化や、PACs MD の多次元 CV として使用する +# フィット選択 (-sft/-sfr) と PCA 計算選択 (-sct/-scr) を分けられる +# +# 出力: .npy (shape: [n_frames, n_components], 単位: nm) +# +# 使用例: +# bash pca.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" +REFERENCE="ref.gro" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# 基本ケース: MDtraj + scikit-learn による PCA +# backbone で重ね合わせ後、backbone の主成分を抽出 +# ----------------------------------------------------------------------- +mdtbx pca \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -r ${REFERENCE} \ + -sft "protein and backbone" \ + -sfr "protein and backbone" \ + -sct "protein and backbone" \ + -scr "protein and backbone" \ + -n 10 \ + -o cvs/pca.npy + +echo "pca done -> cvs/pca.npy (shape: [n_frames, 10])" + +# ----------------------------------------------------------------------- +# Gromacs gmx covar/anaeig を使う場合 (--gmx) +# 大規模系や gmx との一貫性が必要な場合に使用する +# ----------------------------------------------------------------------- +# mdtbx pca \ +# -p gmx.tpr \ +# -t ${TRAJECTORY} \ +# -sft "Backbone" \ +# -sct "Backbone" \ +# --gmx \ +# -idx gmx.ndx \ +# -n 10 \ +# -o cvs/pca_gmx.npy + +echo "All done." 
diff --git a/example/cv/rmsd.sh b/example/cv/rmsd.sh new file mode 100755 index 0000000..9a27f05 --- /dev/null +++ b/example/cv/rmsd.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# rmsd.sh +# reference 構造に対する RMSD を計算する +# +# フィット選択 (-sft/-sfr) と RMSD 計算選択 (-sct/-scr) を分けられる +# 例: backbone で重ね合わせたうえで活性部位ループの RMSD を算出 +# +# 出力: .npy (shape: [n_frames], 単位: nm) +# +# 使用例: +# bash rmsd.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" +REFERENCE="ref.gro" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# backbone 全体の RMSD +# ----------------------------------------------------------------------- +mdtbx rmsd \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -r ${REFERENCE} \ + -sft "protein and backbone" \ + -sfr "protein and backbone" \ + -sct "protein and backbone" \ + -scr "protein and backbone" \ + -o cvs/rmsd_backbone.npy + +echo "rmsd backbone done -> cvs/rmsd_backbone.npy" + +# ----------------------------------------------------------------------- +# 活性部位ループの RMSD (backbone で重ね合わせ後にループ領域を計算) +# -sft/-sfr: 重ね合わせに使う選択 (グローバルフィット) +# -sct/-scr: RMSD を計算する選択 (局所的な変化を検出) +# ----------------------------------------------------------------------- +mdtbx rmsd \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -r ${REFERENCE} \ + -sft "protein and backbone" \ + -sfr "protein and backbone" \ + -sct "resid 100 to 120 and backbone" \ + -scr "resid 100 to 120 and backbone" \ + -o cvs/rmsd_loop.npy + +echo "rmsd loop done -> cvs/rmsd_loop.npy" + +echo "All done." 
diff --git a/example/cv/rmsf.sh b/example/cv/rmsf.sh new file mode 100755 index 0000000..023eff9 --- /dev/null +++ b/example/cv/rmsf.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# rmsf.sh +# 残基ごと・原子ごとの揺らぎ (RMSF) を計算する +# +# B-factor に対応する指標; 柔軟な領域や剛直な領域の同定に使用する +# --resolution residue: 残基ごとに集約 (デフォルト) +# --resolution atom : 原子ごとに出力 +# +# 出力: .npy (shape: [n_residues] or [n_atoms], 単位: nm) +# +# 使用例: +# bash rmsf.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# 残基単位の RMSF (B-factor 相当) +# ----------------------------------------------------------------------- +mdtbx rmsf \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + --selection "protein" \ + --resolution residue \ + -o cvs/rmsf_residue.npy + +echo "rmsf residue done -> cvs/rmsf_residue.npy" + +# ----------------------------------------------------------------------- +# 原子単位の RMSF (backbone のみ) +# ----------------------------------------------------------------------- +mdtbx rmsf \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + --selection "protein and backbone" \ + --resolution atom \ + -o cvs/rmsf_atom.npy + +echo "rmsf atom done -> cvs/rmsf_atom.npy" + +# ----------------------------------------------------------------------- +# Gromacs gmx rmsf を使う場合 (--gmx) +# ----------------------------------------------------------------------- +# mdtbx rmsf \ +# -p gmx.tpr \ +# -t ${TRAJECTORY} \ +# --selection "Backbone" \ +# --resolution residue \ +# --gmx \ +# -o cvs/rmsf_gmx.npy + +echo "All done." 
diff --git a/example/cv/xyz.sh b/example/cv/xyz.sh new file mode 100755 index 0000000..e192f1f --- /dev/null +++ b/example/cv/xyz.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# xyz.sh +# 特定の原子群の XYZ 座標時系列を抽出する +# +# スカラー CV では失われる空間情報が必要な場合や、 +# カスタム CV のための前処理として使用する +# +# 出力: .npy (shape: [n_frames, n_atoms, 3], 単位: nm) +# +# 使用例: +# bash xyz.sh + +set -e + +TOPOLOGY="gmx.gro" +TRAJECTORY="prd.xtc" + +mkdir -p cvs + +# ----------------------------------------------------------------------- +# リガンドの全原子座標 +# ----------------------------------------------------------------------- +mdtbx xyz \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -s "resname LIG" \ + -o cvs/lig_xyz.npy + +echo "xyz done -> cvs/lig_xyz.npy" + +# ----------------------------------------------------------------------- +# 活性部位 Cα 座標 (特定残基) +# ----------------------------------------------------------------------- +mdtbx xyz \ + -p ${TOPOLOGY} \ + -t ${TRAJECTORY} \ + -s "resid 50 to 70 and name CA" \ + -o cvs/active_site_ca_xyz.npy + +echo "xyz done -> cvs/active_site_ca_xyz.npy" + +echo "All done." 
diff --git a/example/gen_distres/run.sh b/example/gen_distres/run.sh new file mode 100755 index 0000000..851a9cc --- /dev/null +++ b/example/gen_distres/run.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# run.sh +# Gromacs 用距離拘束ファイル (distres.itp) を生成し、topology.top に組み込む +# +# 前提: amb2gro または centering_gro で gmx.gro / gmx.top が生成済みであること +# +# 出力: +# distres.itp - [ intermolecular_interactions ] を含む拘束定義 +# gmx.top - distres.itp の #include が末尾に追記される +# +# NOTE: +# - 距離の単位は nm (0.26 nm = 2.6 Å) +# - up2 > up1 > lo の順に設定する +# - 各ペアに対して MDtraj atom selection 言語で1原子を一意に指定する +# 例: "resname MGX and name O2 and resid 200" +# - [ intermolecular_interactions ] は異なる分子グループ間にのみ適用される +# (同一分子内拘束は [ bonds ] に記述する) +# +# 使用例: +# bash run.sh + +set -e + +GRO="gmx.gro" +TOP="gmx.top" + +# ----------------------------------------------------------------------- +# 基本ケース: 金属イオン(MGX)と活性部位残基の距離拘束 +# タンパク質-金属間の配位結合を維持するための拘束 +# ----------------------------------------------------------------------- + +# 拘束するペアを配列で管理すると可読性が上がる +# 各ペア: (lo, up1, up2, sel1, sel2) で記述 +# lo=0.0 は下限なし (距離は非負なのでデフォルト 0.0 でよい) + +# sel1 と sel2 にセットされる MDtraj atom selection 文字列 +# カンマ区切りで複数ペアを指定できる +MGX_O1="resname MGX and name O1" +MGX_O2="resname MGX and name O2" +GLY_N="resname GLY and name N and resid 10" +GLY_N2="resname GLY and name N and resid 20" +HID_ND1="resname HID and name ND1" +HID_N="resname HID and name N" +GLH_OE1="resname GLH and name OE1" +GLH_OE2="resname GLH and name OE2" + +mdtbx gen_distres \ + -g ${GRO} \ + -p ${TOP} \ + -lo 0.0 \ + -up1 0.26 0.26 0.28 0.30 0.30 \ + -up2 0.28 0.28 0.30 0.32 0.32 \ + -o distres \ + -s1 "${MGX_O2},${MGX_O2}, ${MGX_O1}, ${MGX_O2},${GLY_N2}" \ + -s2 "${GLY_N}, ${HID_ND1},${GLH_OE2},${HID_N}, ${GLH_OE1}" + +echo "distres.itp generated and appended to ${TOP}" + +# ----------------------------------------------------------------------- +# mdp でのアクティベーション方法 +# mdp ファイルに以下を追記して拘束を有効化する: +# define = -DDISTRES -DDISTRES_FC=1000 +# DISTRES_FC は力定数 (kJ/mol/nm^2) +# 
----------------------------------------------------------------------- + +# ----------------------------------------------------------------------- +# ユニフォームバウンドケース: 全ペアに同じ上下限を適用 +# -lo, -up1, -up2 に単一値を渡すと全ペアに適用される +# ----------------------------------------------------------------------- +# mdtbx gen_distres \ +# -g ${GRO} \ +# -p ${TOP} \ +# -lo 0.0 \ +# -up1 0.28 \ +# -up2 0.30 \ +# -o distres_uniform \ +# -s1 "${MGX_O2},${MGX_O2}" \ +# -s2 "${GLY_N}, ${HID_ND1}" + +echo "All done." diff --git a/example/mbar/calc_dg_points_mbar.py b/example/mbar/calc_dg_points_mbar.py index 048efee..900ebe7 100755 --- a/example/mbar/calc_dg_points_mbar.py +++ b/example/mbar/calc_dg_points_mbar.py @@ -1,131 +1,118 @@ -import numpy as np -import sys +import argparse + import matplotlib.pyplot as plt +import numpy as np -# ============================================================================= -# Constants -# ============================================================================= -# 前回のスクリプトと同じ温度設定 -temperature = 310.0 - -# 単位変換係数 -# Gas constant R in kcal/(mol*K) -R_kcal = 1.987204259e-3 -RT_kcal = R_kcal * temperature - -# ============================================================================= -# Arguments & Usage -# ============================================================================= -if len(sys.argv) < 3: - print("Usage: python calc_delta_g.py [input_file]") - print("Example: python calc_delta_g.py 1.2 1.7") - sys.exit(1) - -dist_A_target = float(sys.argv[1]) -dist_B_target = float(sys.argv[2]) -input_file = sys.argv[3] if len(sys.argv) > 3 else "pmf_output.dat" - -# ============================================================================= -# Load Data -# ============================================================================= -print(f"Loading {input_file} ...") -try: - # 読み込み (headerの#は自動的に無視される) - data = np.loadtxt(input_file) -except OSError: - print(f"Error: File {input_file} not found.") - sys.exit(1) - -dist = data[:, 0] # Distance 
(nm) -pmf_kt = data[:, 1] # PMF (kT) -err_kt = data[:, 2] # Error (kT) - -# ============================================================================= -# Find Nearest Points -# ============================================================================= -# 指定された距離に最も近いビンのインデックスを探す -idx_A = (np.abs(dist - dist_A_target)).argmin() -idx_B = (np.abs(dist - dist_B_target)).argmin() - -dist_A_actual = dist[idx_A] -pmf_A_kt = pmf_kt[idx_A] -err_A_kt = err_kt[idx_A] - -dist_B_actual = dist[idx_B] -pmf_B_kt = pmf_kt[idx_B] -err_B_kt = err_kt[idx_B] - -# ============================================================================= -# Calculate Delta G -# ============================================================================= -# Delta G (kT) = PMF(B) - PMF(A) -# つまり、AからBへ移行する際のエネルギー変化 -dG_kt = pmf_B_kt - pmf_A_kt - -# 誤差伝播 (二乗和の平方根) -error_dG_kt = np.sqrt(err_A_kt**2 + err_B_kt**2) - -# Convert to kcal/mol -pmf_kcal = pmf_kt * RT_kcal -dG_kcal = dG_kt * RT_kcal -error_dG_kcal = error_dG_kt * RT_kcal - -# ============================================================================= -# Output Results -# ============================================================================= -print("-" * 60) -print(f"Temperature: {temperature} K") -print(f"RT factor : {RT_kcal:.4f} kcal/mol") -print("-" * 60) -print( - f"{'Point':<10} | {'Target(nm)':<10} | {'Actual(nm)':<10} | {'PMF(kT)':<10} | {'PMF(kcal/mol)':<15}" -) -print("-" * 60) -print( - f"{'Start (A)':<10} | {dist_A_target:<10.4f} | {dist_A_actual:<10.4f} | {pmf_A_kt:<10.4f} | {pmf_A_kt * RT_kcal:<15.4f}" -) -print( - f"{'End (B)':<10} | {dist_B_target:<10.4f} | {dist_B_actual:<10.4f} | {pmf_B_kt:<10.4f} | {pmf_B_kt * RT_kcal:<15.4f}" -) -print("-" * 60) -print(f"Delta G (A -> B): {dG_kcal:.4f} +/- {error_dG_kcal:.4f} kcal/mol") -print("-" * 60) - -# ============================================================================= -# Plotting -# ============================================================================= 
-plt.figure(figsize=(8, 5)) -plt.errorbar( - dist, - pmf_kcal, - yerr=err_kt * RT_kcal, - fmt="-", - color="black", - ecolor="lightgray", - label="PMF Profile", -) - -# ポイントのハイライト -plt.scatter( - [dist_A_actual], - [pmf_A_kt * RT_kcal], - color="blue", - s=100, - zorder=5, - label="Start (A)", -) -plt.scatter( - [dist_B_actual], [pmf_B_kt * RT_kcal], color="red", s=100, zorder=5, label="End (B)" -) - -plt.title(f"PMF Profile with Selected Points\nDelta G = {dG_kcal:.2f} kcal/mol") -plt.xlabel("Distance (nm)") -plt.ylabel("PMF (kcal/mol)") -plt.grid(True, linestyle="--", alpha=0.6) -plt.legend() -plt.tight_layout() - -# 画像保存 -plot_filename = "delta_g_check.png" -plt.savefig(plot_filename, dpi=150) -print(f"Plot saved to: {plot_filename}") +R_KCAL = 1.987204259e-3 # Gas constant [kcal/(mol*K)] + + +def calc_delta_g(dist, pmf_kt, err_kt, dist_A_target, dist_B_target, temperature): + RT_kcal = R_KCAL * temperature + + idx_A = (np.abs(dist - dist_A_target)).argmin() + idx_B = (np.abs(dist - dist_B_target)).argmin() + + dist_A_actual = dist[idx_A] + pmf_A_kt = pmf_kt[idx_A] + err_A_kt = err_kt[idx_A] + + dist_B_actual = dist[idx_B] + pmf_B_kt = pmf_kt[idx_B] + err_B_kt = err_kt[idx_B] + + dG_kt = pmf_B_kt - pmf_A_kt + error_dG_kt = np.sqrt(err_A_kt**2 + err_B_kt**2) + dG_kcal = dG_kt * RT_kcal + error_dG_kcal = error_dG_kt * RT_kcal + + return ( + dist_A_actual, + pmf_A_kt, + dist_B_actual, + pmf_B_kt, + dG_kcal, + error_dG_kcal, + RT_kcal, + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Calculate delta G between two points on a PMF profile" + ) + parser.add_argument("dist_A", type=float, help="Start point distance (nm)") + parser.add_argument("dist_B", type=float, help="End point distance (nm)") + parser.add_argument( + "input_file", nargs="?", default="pmf_output.dat", help="PMF data file" + ) + parser.add_argument( + "--temperature", type=float, default=310.0, help="Temperature [K]" + ) + parser.add_argument( + "--output", 
default="delta_g_check.png", help="Output plot filename" + ) + args = parser.parse_args() + + try: + data = np.loadtxt(args.input_file) + except OSError: + print(f"Error: File {args.input_file} not found.") + raise SystemExit(1) + + dist = data[:, 0] + pmf_kt = data[:, 1] + err_kt = data[:, 2] + + dist_A, pmf_A_kt, dist_B, pmf_B_kt, dG_kcal, error_dG_kcal, RT_kcal = calc_delta_g( + dist, pmf_kt, err_kt, args.dist_A, args.dist_B, args.temperature + ) + pmf_kcal = pmf_kt * RT_kcal + + print(f"{'─' * 60}") + print(f"Temperature: {args.temperature} K | RT: {RT_kcal:.4f} kcal/mol") + print(f"{'─' * 60}") + print( + f"{'Point':<10} {'Target(nm)':<12} {'Actual(nm)':<12} " + f"{'PMF(kT)':<10} {'PMF(kcal/mol)':<15}" + ) + print(f"{'─' * 60}") + print( + f"{'Start (A)':<10} {args.dist_A:<12.4f} {dist_A:<12.4f} " + f"{pmf_A_kt:<10.4f} {pmf_A_kt * RT_kcal:<15.4f}" + ) + print( + f"{'End (B)':<10} {args.dist_B:<12.4f} {dist_B:<12.4f} " + f"{pmf_B_kt:<10.4f} {pmf_B_kt * RT_kcal:<15.4f}" + ) + print(f"{'─' * 60}") + print(f"Delta G (A -> B): {dG_kcal:.4f} +/- {error_dG_kcal:.4f} kcal/mol") + print(f"{'─' * 60}") + + fig, ax = plt.subplots(figsize=(8, 5)) + ax.errorbar( + dist, + pmf_kcal, + yerr=err_kt * RT_kcal, + fmt="-", + color="black", + ecolor="lightgray", + label="PMF Profile", + ) + ax.scatter( + [dist_A], [pmf_A_kt * RT_kcal], color="blue", s=100, zorder=5, label="Start (A)" + ) + ax.scatter( + [dist_B], [pmf_B_kt * RT_kcal], color="red", s=100, zorder=5, label="End (B)" + ) + ax.set_title(f"PMF Profile\nΔG = {dG_kcal:.2f} kcal/mol") + ax.set_xlabel("Distance (nm)") + ax.set_ylabel("PMF (kcal/mol)") + ax.grid(True, linestyle="--", alpha=0.6) + ax.legend() + fig.tight_layout() + fig.savefig(args.output, dpi=150) + print(f"Plot saved to: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/example/mbar/calc_mbar.py b/example/mbar/calc_mbar.py index b9fa580..22efe96 100755 --- a/example/mbar/calc_mbar.py +++ b/example/mbar/calc_mbar.py @@ -1,221 +1,179 @@ 
+import sys +from pathlib import Path + import numpy as np import pymbar from pymbar import timeseries -import sys -from pathlib import Path # ============================================================================= -# 1. Metadata Generation +# Configuration # ============================================================================= -kconst = 100 -subsample = True +KCONST = 100 # Umbrella spring constant [kJ/mol/unit^2] +SUBSAMPLE = True # Statistical subsampling to reduce autocorrelation +TEMPERATURE = 310.0 # [K] +TARGET_BIN_WIDTH = 0.03 +NBINS_MIN = 5 +NBINS_MAX = 50 +PADDING = 0.001 # Range padding to avoid empty edge bins +N_MAX_INIT = 200000 # Initial array size for trajectory data + + +def read_metadata(metadata_file: str) -> tuple[list, list, list]: + files, r0_k, K_k = [], [], [] + with open(metadata_file) as f: + for line in f: + tokens = line.split() + if len(tokens) >= 3: + files.append(tokens[0]) + r0_k.append(float(tokens[1])) + K_k.append(float(tokens[2])) + return files, r0_k, K_k -path_list = sorted(Path("./").glob("rep*"), key=lambda x: float(x.name.split("_")[-2])) -with open("metadata.dat", "w") as f: +def build_metadata(path_list: list[Path], kconst: float) -> str: + lines = [] for trial in path_list: try: - # target_distance = float(trial.name.split("_")[-2]) * 0.1 target_dihedral = float(trial.name.split("_")[-2]) - except Exception: + except (ValueError, IndexError): continue - - # us1_pullx.xvg から us5... 
まで探す - # ファイル名パターンがターゲットによって違う場合はここを調整してください - # for us_trial in range(1, 5+1): - # # for us_trial in [1]: - # filepath = f"{trial}/us{us_trial}_pullx.xvg" - # if Path(filepath).exists(): - # f.write(f"{filepath} {target_dihedral} {kconst}\n") - filepath = f"{trial}/us_dih1_pullx.xvg" if Path(filepath).exists(): - f.write(f"{filepath} {target_dihedral} {kconst}\n") + lines.append(f"{filepath} {target_dihedral} {kconst}") + content = "\n".join(lines) + with open("metadata.dat", "w") as f: + f.write(content) + return content -# ============================================================================= -# Constants -# ============================================================================= -kB = 8.314462618e-3 -temperature = 310.0 -beta = 1.0 / (kB * temperature) -# ============================================================================= -# Read Data -# ============================================================================= -files = [] -r0_k = [] -K_k = [] -try: - with open("metadata.dat", "r") as f: - for line in f: - tokens = line.split() - if len(tokens) >= 3: - files.append(tokens[0]) - r0_k.append(float(tokens[1])) - K_k.append(float(tokens[2])) -except FileNotFoundError: - print("Error: metadata.dat could not be created.") - sys.exit(1) +def load_trajectories(files: list, r0_k: list, K_k: list, beta: float, subsample: bool): + K = len(files) + r0_k = np.array(r0_k) + K_k = np.array(K_k) -K = len(files) -if K == 0: - print("Error: No data files found. 
Check your file paths and naming convention.") - sys.exit(1) + N_k = np.zeros(K, dtype=int) + r_kn = np.zeros([K, N_MAX_INIT]) + u_kn = np.zeros([K, N_MAX_INIT]) -r0_k = np.array(r0_k) -K_k = np.array(K_k) -N_max = 200000 + print(f"Reading {K} files...") + for k, filename in enumerate(files): + raw_data = [] + try: + with open(filename) as infile: + for line in infile: + if not line.startswith(("#", "@")): + parts = line.split() + if len(parts) >= 2: + val = float(parts[1]) + if val > 0.001: + raw_data.append(val) + except FileNotFoundError: + continue -N_k = np.zeros(K, dtype=int) -r_kn = np.zeros([K, N_max]) -u_kn = np.zeros([K, N_max]) + raw_data = np.array(raw_data) + if len(raw_data) == 0: + continue -print(f"Reading {K} files...") + if subsample: + g = timeseries.statistical_inefficiency(raw_data) + indices = timeseries.subsample_correlated_data(raw_data, g=g) + else: + indices = np.arange(len(raw_data)) -for k in range(K): - filename = files[k] - raw_data = [] - try: - with open(filename, "r") as infile: - for line in infile: - if not line.startswith(("#", "@")): - parts = line.split() - if len(parts) >= 2: - val = float(parts[1]) - # たまに極端な外れ値(0など)が入ることがあるのでフィルタリング - if val > 0.001: - raw_data.append(val) - raw_data = np.array(raw_data) - except FileNotFoundError: - continue - - if len(raw_data) == 0: - continue - - # --- 【修正1】サブサンプリングを無効化(全データ使用) --- - # データ数が少ないため、間引き処理を行わずにすべてのサンプルを使います。 - if subsample: - g = timeseries.statistical_inefficiency(raw_data) - indices = timeseries.subsample_correlated_data(raw_data, g=g) - - # # t0: 平衡化開始地点, g: 統計的非効率性, Neff: 実効サンプル数 - # t0, g, Neff = timeseries.detect_equilibration(raw_data) - # - # # 平衡化前のデータを捨てて、かつ間引いたインデックスを取得 - # data_equil = raw_data[t0:] - # indices_equil = timeseries.subsample_correlated_data(data_equil, g=g) - # - # # 元の配列(raw_data)に対するインデックスに変換 - # indices = indices_equil + t0 - # - # # 確認用ログ(どれくらい間引かれたか確認できます) - # print(f"Window {k}: Total={len(raw_data)}, EquilStart={t0}, g={g:.2f}, 
Kept={len(indices)}") - else: - indices = np.arange(len(raw_data)) - # ----------------------------------------------------- - - N_k[k] = len(indices) - - if N_k[k] > r_kn.shape[1]: - new_size = max(N_k[k], r_kn.shape[1] * 2) - r_kn = np.pad(r_kn, ((0, 0), (0, new_size - r_kn.shape[1])), "constant") - u_kn = np.pad(u_kn, ((0, 0), (0, new_size - u_kn.shape[1])), "constant") - - r_kn[k, 0 : N_k[k]] = raw_data[indices] - - if k % 10 == 0: - print(f"Processed window {k}/{K} (Samples: {N_k[k]})") - -# 配列を実サイズに切り詰め -N_max = np.max(N_k) -if N_max == 0: - print("Error: No valid data samples loaded.") - sys.exit(1) - -r_kn = r_kn[:, :N_max] -u_kn = u_kn[:, :N_max] + N_k[k] = len(indices) -# ============================================================================= -# Auto-detect Range & Prepare MBAR -# ============================================================================= -valid_mask = r_kn > 0.001 -valid_data = r_kn[valid_mask] -data_min = valid_data.min() -data_max = valid_data.max() - -# --- 【修正2】範囲の余白(パディング)を極小にする --- -# 以前は 0.05 でしたが、空ビンを作らないよう 0.001 (1e-3) 程度にします -padding = 0.001 -dist_min = data_min - padding -dist_max = data_max + padding - -# ビン数の決定 -range_width = dist_max - dist_min -# データ密度に応じてビン幅を調整(ここでは少し粗めの 0.03 nm 程度から試す) -target_bin_width = 0.03 -nbins = int(range_width / target_bin_width) -nbins = max(5, min(nbins, 50)) # ビン数が少なすぎず多すぎないように制限 - -print(f"Auto-detected range: {dist_min:.4f} nm - {dist_max:.4f} nm") -print(f"Set nbins to: {nbins}") - -print("Evaluating reduced potential energy matrix...") -u_kln = np.zeros([K, K, N_max]) -for k in range(K): - # k番目のシミュレーションのデータを取り出す - r = r_kn[k, : N_k[k]] - # 全ウィンドウ l でのエネルギーを計算 - # diff[l, n] = r[n] - r0_k[l] - diff = r[np.newaxis, :] - r0_k[:, np.newaxis] - u_kln[k, :, : N_k[k]] = u_kn[k, : N_k[k]] + beta * 0.5 * K_k[:, np.newaxis] * ( - diff**2 - ) + if N_k[k] > r_kn.shape[1]: + new_size = max(N_k[k], r_kn.shape[1] * 2) + r_kn = np.pad(r_kn, ((0, 0), (0, new_size - r_kn.shape[1])), "constant") + 
u_kn = np.pad(u_kn, ((0, 0), (0, new_size - u_kn.shape[1])), "constant") -# ============================================================================= -# Run MBAR -# ============================================================================= -print("Running MBAR (Robust mode)...") -# fes = pymbar.FES(u_kln, N_k, verbose=True, mbar_options={'solver_protocol': 'robust'}) -fes = pymbar.FES(u_kln, N_k, verbose=True) + r_kn[k, : N_k[k]] = raw_data[indices] + + if k % 10 == 0: + print(f" Window {k}/{K} (samples: {N_k[k]})") + + N_max = np.max(N_k) + if N_max == 0: + raise RuntimeError("No valid data samples loaded.") -bin_edges = np.linspace(dist_min, dist_max, nbins + 1) -bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + return r_kn[:, :N_max], u_kn[:, :N_max], N_k, r0_k, K_k -r_n = pymbar.utils.kn_to_n(r_kn, N_k=N_k) -u_n = pymbar.utils.kn_to_n(u_kn, N_k=N_k) -histogram_parameters = {"bin_edges": bin_edges} -fes.generate_fes( - u_n, r_n, fes_type="histogram", histogram_parameters=histogram_parameters -) -# fes.generate_fes(u_kn, r_n, fes_type="histogram", histogram_parameters=histogram_parameters, n_bootstraps=100) +def run_mbar(r_kn, u_kn, N_k, r0_k, K_k, beta): + valid_data = r_kn[r_kn > 0.001] + dist_min = valid_data.min() - PADDING + dist_max = valid_data.max() + PADDING + + nbins = int((dist_max - dist_min) / TARGET_BIN_WIDTH) + nbins = max(NBINS_MIN, min(nbins, NBINS_MAX)) + + print(f"Range: {dist_min:.4f} - {dist_max:.4f}, bins: {nbins}") + + N_max = r_kn.shape[1] + K = len(N_k) + u_kln = np.zeros([K, K, N_max]) + for k in range(K): + r = r_kn[k, : N_k[k]] + diff = r[np.newaxis, :] - r0_k[:, np.newaxis] + u_kln[k, :, : N_k[k]] = u_kn[k, : N_k[k]] + beta * 0.5 * K_k[:, np.newaxis] * ( + diff**2 + ) + + print("Running MBAR...") + fes = pymbar.FES(u_kln, N_k, verbose=True) + + bin_edges = np.linspace(dist_min, dist_max, nbins + 1) + bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) + + r_n = pymbar.utils.kn_to_n(r_kn, N_k=N_k) + u_n = 
pymbar.utils.kn_to_n(u_kn, N_k=N_k) + fes.generate_fes( + u_n, r_n, fes_type="histogram", histogram_parameters={"bin_edges": bin_edges} + ) -print("Computing FES...") -# 念のためエラー計算なしでトライ -try: results = fes.get_fes( bin_centers, reference_point="from-lowest", uncertainty_method="analytical" ) - # results = fes.get_fes(bin_centers, reference_point="from-lowest", uncertainty_method="bootstrap") - # results = fes.get_fes(bin_centers, reference_point="from-lowest", uncertainty_method="bootstrap", n_bootstraps=100, bootstrap_solver_protocol="robust") -except Exception as e: - print(f"Analytical error calculation failed: {e}") - print("Retrying without uncertainty calculation...") - # results = fes.get_fes(bin_centers, reference_point="from-lowest", uncertainty_method=None) - exit(1) - -f_i = results["f_i"] -# 誤差が計算できなかった場合はゼロ埋め -df_i = results.get("df_i", np.zeros_like(f_i)) - -print("\nFree Energy Profile (Histogram) [unit: kT]") -print(f"{'Dist(nm)':>10s} {'PMF(kT)':>10s} {'Error':>10s}") -out_data = [] -for i in range(nbins): - # nan や inf を除外して出力 - if np.isfinite(f_i[i]): - print(f"{bin_centers[i]:10.4f} {f_i[i]:10.4f} {df_i[i]:10.4f}") - out_data.append([bin_centers[i], f_i[i], df_i[i]]) - -np.savetxt("pmf_output.dat", out_data, header="Dihedral(degree) PMF(kT) Error(kT)") -print("\nDone! Saved to pmf_output.dat") + f_i = results["f_i"] + df_i = results.get("df_i", np.zeros_like(f_i)) + + return bin_centers, f_i, df_i + + +def main(): + kB = 8.314462618e-3 # kJ/(mol·K) + beta = 1.0 / (kB * TEMPERATURE) + + # Build metadata from directory listing + path_list = sorted( + Path("./").glob("rep*"), key=lambda x: float(x.name.split("_")[-2]) + ) + if path_list: + build_metadata(path_list, KCONST) + + try: + files, r0_k, K_k = read_metadata("metadata.dat") + except FileNotFoundError: + print("Error: metadata.dat not found.") + sys.exit(1) + + if not files: + print("Error: No data files found. 
Check your naming convention.") + sys.exit(1) + + r_kn, u_kn, N_k, r0_k, K_k = load_trajectories(files, r0_k, K_k, beta, SUBSAMPLE) + bin_centers, f_i, df_i = run_mbar(r_kn, u_kn, N_k, r0_k, K_k, beta) + + print(f"\n{'Coord':>10s} {'PMF(kT)':>10s} {'Error':>10s}") + out_data = [] + for i in range(len(bin_centers)): + if np.isfinite(f_i[i]): + print(f"{bin_centers[i]:10.4f} {f_i[i]:10.4f} {df_i[i]:10.4f}") + out_data.append([bin_centers[i], f_i[i], df_i[i]]) + + np.savetxt("pmf_output.dat", out_data, header="Coord PMF(kT) Error(kT)") + print("\nDone! Saved to pmf_output.dat") + + +if __name__ == "__main__": + main() diff --git a/example/mbar/plot_mbar.py b/example/mbar/plot_mbar.py index 7a3a034..62473a2 100755 --- a/example/mbar/plot_mbar.py +++ b/example/mbar/plot_mbar.py @@ -1,97 +1,63 @@ -import numpy as np -import matplotlib.pyplot as plt -import sys - -# ========================================== -# 設定 -# ========================================== -input_file = sys.argv[1] -output_image = "pmf_profile.png" - -# 温度設定 (計算時と同じ温度を指定してください) -temperature = 310.0 # Kelvin - -# 単位変換係数の計算 (kT -> kcal/mol) -# ガス定数 R = 0.0019872 kcal/(mol·K) -gas_constant = 0.0019872041 -kT_to_kcal = gas_constant * temperature - -print(f"Temperature: {temperature} K") -print(f"Conversion factor (1 kT): {kT_to_kcal:.4f} kcal/mol") - -# ========================================== -# データの読み込み -# ========================================== -try: - data = np.loadtxt(input_file) - r = data[:, 0] # 1列目: 距離 (nm) - pmf_kT = data[:, 1] # 2列目: PMF (kT) - error_kT = data[:, 2] # 3列目: 誤差 (kT) -except FileNotFoundError: - print(f"Error: {input_file} が見つかりません。") - exit() - -# 単位変換を実行 -pmf_kcal = pmf_kT * kT_to_kcal -error_kcal = error_kT * kT_to_kcal - -# ========================================== -# プロットの作成 -# ========================================== -fig, ax = plt.subplots(figsize=(8, 6)) - -# 1. 
誤差範囲を塗りつぶし (Shaded Error Bar) -# kcal/mol なので少し線や色を濃いめに見やすく設定 -ax.fill_between( - r, - pmf_kcal - error_kcal, - pmf_kcal + error_kcal, - color="#1f77b4", - alpha=0.3, - linewidth=0, - label="Standard Error", -) +import argparse -# 2. PMFのメインライン -ax.plot(r, pmf_kcal, color="#1f77b4", linewidth=2.5, label="PMF") - -# ========================================== -# 装飾 -# ========================================== -ax.set_title( - f"Potential of Mean Force (T={int(temperature)}K)", fontsize=16, fontweight="bold" -) -ax.set_xlabel("Dihedral (nm)", fontsize=14) -ax.set_ylabel("Free Energy (kcal/mol)", fontsize=14) - -plt.ylim(0, 9) - -# # ゼロライン(基準線)を引く -# ax.axhline(0, color='gray', linestyle='--', linewidth=1, alpha=0.7) - -# グリッド線 -ax.grid(True, linestyle=":", alpha=0.6) - -# 軸の文字サイズと目盛りの向き -ax.tick_params(axis="both", which="major", labelsize=12, direction="in") - -# 凡例 -ax.legend(fontsize=12, loc="best", frameon=True, framealpha=0.9) - -# 余白の調整 -plt.tight_layout() - -# ========================================== -# 保存 -# ========================================== -plt.savefig(output_image, dpi=300) -print(f"グラフを保存しました: {output_image}") +import matplotlib.pyplot as plt +import numpy as np -ax.set_title("") -ax.set_xlabel("") -ax.set_ylabel("") -ax.tick_params(axis="both", labelbottom=False, labelleft=False) -legend = ax.get_legend() -if legend: - legend.remove() -plt.savefig(output_image.replace(".png", "_no_title.png"), dpi=300, transparent=True) +GAS_CONSTANT = 0.0019872041 # kcal/(mol·K) + + +def main(): + parser = argparse.ArgumentParser(description="Plot PMF profile from MBAR output") + parser.add_argument("input_file", help="PMF data file (e.g. 
pmf_output.dat)") + parser.add_argument( + "--output", default="pmf_profile.png", help="Output image filename" + ) + parser.add_argument( + "--temperature", type=float, default=310.0, help="Temperature [K]" + ) + parser.add_argument("--ymax", type=float, default=None, help="Y-axis upper limit") + args = parser.parse_args() + + kT_to_kcal = GAS_CONSTANT * args.temperature + print(f"Temperature: {args.temperature} K | 1 kT = {kT_to_kcal:.4f} kcal/mol") + + try: + data = np.loadtxt(args.input_file) + except FileNotFoundError: + print(f"Error: {args.input_file} not found.") + raise SystemExit(1) + + r = data[:, 0] + pmf_kcal = data[:, 1] * kT_to_kcal + error_kcal = data[:, 2] * kT_to_kcal + + fig, ax = plt.subplots(figsize=(8, 6)) + ax.fill_between( + r, + pmf_kcal - error_kcal, + pmf_kcal + error_kcal, + color="#1f77b4", + alpha=0.3, + linewidth=0, + label="Standard Error", + ) + ax.plot(r, pmf_kcal, color="#1f77b4", linewidth=2.5, label="PMF") + ax.set_title( + f"Potential of Mean Force (T={int(args.temperature)}K)", + fontsize=16, + fontweight="bold", + ) + ax.set_xlabel("Reaction Coordinate", fontsize=14) + ax.set_ylabel("Free Energy (kcal/mol)", fontsize=14) + if args.ymax is not None: + ax.set_ylim(top=args.ymax) + ax.grid(True, linestyle=":", alpha=0.6) + ax.tick_params(axis="both", which="major", labelsize=12, direction="in") + ax.legend(fontsize=12, loc="best", frameon=True, framealpha=0.9) + fig.tight_layout() + fig.savefig(args.output, dpi=300) + print(f"Saved: {args.output}") + + +if __name__ == "__main__": + main() diff --git a/example/pacs/cv_reshape.py b/example/pacs/cv_reshape.py index 19d770b..04415b7 100755 --- a/example/pacs/cv_reshape.py +++ b/example/pacs/cv_reshape.py @@ -5,33 +5,19 @@ n_frame = 100 -for npy in Path("cvs/comdist/").glob("*.npy"): - a = np.load(npy) - - print(a.shape) - - total_frame = a.shape[0] - total_frame = total_frame - 1 - - a = a[1:].reshape(total_frame // n_frame, n_frame, 1) - - print(a.shape) - - prefix = npy.stem - 
np.save(f"{prefix}_reshaped.npy", a) - - -for npy in Path("cvs/comvec/").glob("*.npy"): - a = np.load(npy) - - print(a.shape) - - total_frame = a.shape[0] - total_frame = total_frame - 1 - - a = a[1:].reshape(total_frame // n_frame, n_frame, 3) - - print(a.shape) - - prefix = npy.stem - np.save(f"{prefix}_reshaped.npy", a) +# (directory, last_dim) pairs: comdist is scalar (1), comvec is 3D vector (3) +targets = [ + ("cvs/comdist/", 1), + ("cvs/comvec/", 3), +] + +for cv_dir, ndim in targets: + for npy in Path(cv_dir).glob("*.npy"): + a = np.load(npy) + print(f"{npy}: {a.shape}", end=" -> ") + + total_frame = a.shape[0] - 1 # drop first frame + a = a[1:].reshape(total_frame // n_frame, n_frame, ndim) + + print(a.shape) + np.save(f"{npy.stem}_reshaped.npy", a) diff --git a/example/place_solvent/run.sh b/example/place_solvent/run.sh new file mode 100644 index 0000000..fa121ab --- /dev/null +++ b/example/place_solvent/run.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# run.sh +# 3D-RISM を使って結晶水・保存水サイトを推定し、PDB に配置する +# +# 前提: leap (build_solution) で生成した parm7/rst7 がある +# +# 使用例: +# bash run.sh + +set -e + +PRMTOP="leap.parm7" +COORD="leap.rst7" +OUTPUT="placed_water.pdb" + +# ----------------------------------------------------------------------- +# 基本ケース: 1D-RISM から xvv を自動生成してそのまま 3D-RISM を実行 +# ----------------------------------------------------------------------- +mdtbx place_solvent \ + -p ${PRMTOP} \ + -x ${COORD} \ + -o ${OUTPUT} \ + --solvent-model SPC \ + --temperature 300.0 \ + --closure kh \ + --grdspc 0.5 \ + --buffer 14.0 \ + --solvcut 14.0 \ + --threshold 1.5 \ + --exclusion-radius 2.6 + +echo "Basic run done -> ${OUTPUT}" + +# ----------------------------------------------------------------------- +# xvv 再利用ケース: 同一溶媒モデル・温度の計算を複数構造に適用する場合 +# 1D-RISM は一度だけ実行して xvv を保存しておくと時間を節約できる +# ----------------------------------------------------------------------- +# (初回に --keepfiles で xvv を取得しておく場合の例) +# mdtbx place_solvent \ +# -p ${PRMTOP} -x ${COORD} \ +# --keepfiles \ +# 
--solvent-model SPC --temperature 300.0 +# # 生成された xvv を保存 +# cp /tmp/rism3d_*/SPC_300.00.xvv ./SPC_300.xvv + +XVV="SPC_300.xvv" +if [ -f "${XVV}" ]; then + mdtbx place_solvent \ + -p ${PRMTOP} \ + -x ${COORD} \ + -o placed_water_reuse.pdb \ + --xvv ${XVV} \ + --temperature 300.0 \ + --closure kh \ + --threshold 1.5 \ + --exclusion-radius 2.6 + echo "xvv reuse run done -> placed_water_reuse.pdb" +fi + +# ----------------------------------------------------------------------- +# 上位 N サイトのみ配置するケース +# タンパク質活性部位周辺の主要な保存水だけを選びたい場合に有効 +# ----------------------------------------------------------------------- +mdtbx place_solvent \ + -p ${PRMTOP} \ + -x ${COORD} \ + -o placed_water_top20.pdb \ + --solvent-model SPC \ + --temperature 300.0 \ + --closure kh \ + --threshold 1.5 \ + --exclusion-radius 2.6 \ + --max-sites 20 + +echo "Top-20 sites run done -> placed_water_top20.pdb" + +# ----------------------------------------------------------------------- +# 高精度ケース: グリッドを細かくし、閾値を下げて弱い水和サイトも検出 +# 計算コストは grdspc の 3 乗に比例するので注意 +# ----------------------------------------------------------------------- +mdtbx place_solvent \ + -p ${PRMTOP} \ + -x ${COORD} \ + -o placed_water_hires.pdb \ + --solvent-model SPC \ + --temperature 300.0 \ + --closure kh \ + --grdspc 0.3 \ + --buffer 14.0 \ + --solvcut 14.0 \ + --tolerance 1e-6 \ + --threshold 1.0 \ + --exclusion-radius 2.6 + +echo "High-resolution run done -> placed_water_hires.pdb" + +echo "All done." 
diff --git a/example/remd/gen_temperatures.sh b/example/remd/gen_temperatures.sh new file mode 100755 index 0000000..42799e6 --- /dev/null +++ b/example/remd/gen_temperatures.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# gen_temperatures.sh +# REMD シミュレーション用の温度リストを生成する +# +# アルゴリズム: de Pablo らの方法 (virtualchemistry.org/remd-temperature-generator) +# 系の自由度・拘束条件・溶媒モデルから交換確率が目標値になる温度を逐次計算する +# +# 前提: gmx.gro / gmx.top から原子数・水分子数を確認しておく +# 水分子数: grep -c "SOL" gmx.gro (3原子/分子なので /3 する) +# タンパク原子数: grep -c "Protein" gmx.ndx 等で確認 +# +# 使用例: +# bash gen_temperatures.sh + +set -e + +# ----------------------------------------------------------------------- +# 水溶液系タンパク質 (ff14SB + TIP3P/SPC 相当) の標準設定 +# --pc 1 : タンパク質中の水素結合のみ拘束 (typical LINCS/SHAKE 設定) +# --wc 3 : 水は剛体モデル (TIP3P/SPC/SPC-E) +# --hff 0 : 全水素原子を含む力場 (ff14SB) +# ----------------------------------------------------------------------- +NW=10000 # 水分子数 +NP=3000 # タンパク質原子数 +TLOW=300.0 # 最低温度 [K] +THIGH=400.0 # 最高温度 [K] +PDES=0.25 # 目標交換確率 (0.2-0.3 が一般的) + +echo "=== Standard protein-water REMD ===" +mdtbx gen_temperatures \ + --pdes ${PDES} \ + --tlow ${TLOW} \ + --thigh ${THIGH} \ + --nw ${NW} \ + --np ${NP} \ + --pc 1 \ + --wc 3 \ + --hff 0 + +# ----------------------------------------------------------------------- +# 小分子・ペプチド系: 原子数が少なく交換確率が高くなりやすい +# レプリカ数が少なくなるので温度範囲を絞るか pdes を下げて調整する +# ----------------------------------------------------------------------- +echo "" +echo "=== Small peptide REMD ===" +mdtbx gen_temperatures \ + --pdes 0.20 \ + --tlow 280.0 \ + --thigh 380.0 \ + --nw 3000 \ + --np 300 \ + --pc 1 \ + --wc 3 \ + --hff 0 + +# ----------------------------------------------------------------------- +# 全原子フレキシブル (拘束なし) 設定 +# --pc 0 : タンパク質拘束なし +# --wc 0 : 水も完全フレキシブル (SPC/Ef 等) +# 自由度が増えるためレプリカ数が増える傾向にある +# ----------------------------------------------------------------------- +echo "" +echo "=== Fully flexible (no constraints) ===" +mdtbx gen_temperatures \ + --pdes 0.25 \ + --tlow 300.0 \ + --thigh 400.0 \ + 
--nw 10000 \ + --np 3000 \ + --pc 0 \ + --wc 0 \ + --hff 0 + +echo "All done." diff --git a/example/remd/setup_reus_dist.sh b/example/remd/setup_reus_dist.sh index 38a69fe..4457153 100755 --- a/example/remd/setup_reus_dist.sh +++ b/example/remd/setup_reus_dist.sh @@ -6,6 +6,7 @@ SUBMIT_SCRIPT="$TOOLS/mdtbx/example/mdrun/remd_slurm.sh" TOPOLOGY="gmx.top" INDEX="index.ndx" ITP="*.itp" +MAXWARN=10 source /home/apps/Modules/init/bash module purge @@ -51,12 +52,12 @@ do sed -i -e "s/TARGET_DISTANCE/${TARGET_DISTANCE}/g" rep${rep}/reus.mdp cp $SUBMIT_SCRIPT rep${rep}/ gmx_mpi grompp \ - -f reus.mdp \ - -c gmx.gro \ - -n index.ndx \ - -p gmx.top \ + -f rep${rep}/reus.mdp \ + -c rep${rep}/gmx.gro \ + -n rep${rep}/index.ndx \ + -p rep${rep}/gmx.top \ -maxwarn ${MAXWARN} \ - -o reus.tpr + -o rep${rep}/reus.tpr done rm -f target_structures_distances.txt diff --git a/example/wham/plot_bspmf.py b/example/wham/plot_bspmf.py index 8f76b95..26afbdd 100755 --- a/example/wham/plot_bspmf.py +++ b/example/wham/plot_bspmf.py @@ -45,13 +45,6 @@ def plot_bootstrap_pmf(result_file="bsResult.xvg", output_file="pmf_bootstrap.pn # 保存 fig.tight_layout() fig.savefig(output_file, dpi=300) - - # ax.set_title("") - # ax.set_xlabel("") - # ax.set_ylabel("") - # ax.tick_params(axis='both', labelbottom=False, labelleft=False) - # plt.savefig(output_file.replace(".png", "_no_title.png"), dpi=300, transparent=True) - plt.close(fig) print(f"Saved plot to {output_file}") diff --git a/example/wham/plot_hist.py b/example/wham/plot_hist.py index a89077a..b06cfa6 100755 --- a/example/wham/plot_hist.py +++ b/example/wham/plot_hist.py @@ -37,14 +37,6 @@ def plot_histograms(input_file="hist.xvg", output_file="hist_plot.png"): # 保存 fig.tight_layout() fig.savefig(output_file, dpi=300) - - # notitle版 - # ax.set_title("") - # ax.set_xlabel("") - # ax.set_ylabel("") - # ax.tick_params(axis='both', labelbottom=False, labelleft=False) - # plt.savefig(output_file.replace(".png", "_no_title.png"), dpi=300, 
transparent=True) - plt.close(fig) diff --git a/example/wham/plot_pmf.py b/example/wham/plot_pmf.py index 63eed2f..309e12a 100755 --- a/example/wham/plot_pmf.py +++ b/example/wham/plot_pmf.py @@ -28,14 +28,7 @@ def plot_pmf(input_file="profile.xvg", output_file="pmf_plot.png"): # 保存 fig.tight_layout() fig.savefig(output_file, dpi=300) - - # ax.set_title("") - # ax.set_xlabel("") - # ax.set_ylabel("") - # ax.tick_params(axis='both', labelbottom=False, labelleft=False) - # plt.savefig(output_file.replace(".png", "_no_title.png"), dpi=300, transparent=True) - - plt.close(fig) # メモリ解放のためclose + plt.close(fig) # 実行 diff --git a/example/wham/run_wham.sh b/example/wham/run_wham.sh index 89204a4..000113a 100755 --- a/example/wham/run_wham.sh +++ b/example/wham/run_wham.sh @@ -6,9 +6,10 @@ ls -1 rep*/*_pullf.xvg > tmp_pullf.dat ls -1 rep*/*_pullx.xvg > tmp_pullx.dat # for angle +# Use -ix for pull coordinate or -if for pull force files mdtbx cmd gmx wham \ -it tmp_tpr.dat \ - -ix tmp_pullx.dat \ # or -if tmp_pullf.dat + -ix tmp_pullx.dat \ -o profile.xvg \ -hist hist.xvg \ -unit kCal \ diff --git a/pixi.lock b/pixi.lock index f9241fa..fd64c45 100644 --- a/pixi.lock +++ b/pixi.lock @@ -26,6 +26,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/c-blosc2-2.15.2-h3122c55_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2025.10.5-hbd8a1cb_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.4-h3394656_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.3.2-py310h3788b33_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/cuda-nvrtc-12.9.86-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/cuda-version-12.9-h4f385c5_3.conda @@ -33,6 +34,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/cyrus-sasl-2.1.28-hd9c7081_0.conda - conda: 
https://conda.anaconda.org/conda-forge/linux-64/dbus-1.16.2-h3c4dab8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/deeptime-0.4.4-py310h5eaa309_3.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/fftw-3.3.10-nompi_hf1063bd_110.conda - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 @@ -60,6 +62,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/hdf4-4.2.15-h2a13503_7.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/hdf5-1.14.4-nompi_h2d575fe_105.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/icu-75.1-he02047a_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/joblib-1.5.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.3-hb9d3cd8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.9-py310haaf941d_1.conda @@ -175,16 +178,19 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/perl-5.32.1-7_hd590300_perl5.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-11.3.0-py310h6557065_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ply-3.11-pyhd8ed1ab_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pmw-2.0.1-py310hff52083_1008.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pulseaudio-client-17.0-h9a8bead_2.conda - 
conda: https://conda.anaconda.org/conda-forge/noarch/py-cpuinfo-9.0.0-pyhd8ed1ab_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pymol-open-source-3.1.0-py310h0298fdb_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.5-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.11-py310h046fae5_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyqt5-sip-12.17.0-py310hea6c23e_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pytables-3.10.1-py310h431dcdc_4.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-9.0.2-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.10.0-h543edf9_3_cpython.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0.post0-pyhe01879c_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2025.2-pyhd8ed1ab_0.conda @@ -282,7 +288,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/d4/3c/eef454cd7c3880c2d55b50e18a9c7a213bf91ded79efcfb573d8d6dd8a47/duckdb-1.4.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/3c/1db1b0f878319bb227f35a0fca7cad64e1f528b518bcab1a708da305c86d/faicons-0.2.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c8/23/02012e9c7e584e6f85e1e7078beff3dc56aaad2e51b0a33bbcaa1dc2aa6e/fastexcel-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -397,7 +402,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/a1/6b/83661fa77dcefa195ad5f8cd9af3d1a7450fd57cc883ad04d65446ac2029/pydantic-2.12.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d3/81/8cece29a6ef1b3a92f956ea6da6250d5b2d2e7e4d513dd3b4f0c7a83dfea/pydantic_core-2.41.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/00/78/9cbcc1c073b9d4918e925af1a059762265dc65004e020511b2a06fbfd020/pydssp-0.9.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/fa/6c/def488007282c3d9ece0ca68fd6c4c7f42c500c0125cba3b8767731bfe5e/pyiceberg-0.10.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/81/d9/adf833614ea03f92eccd45274dd31eef2e6d06a301216f21dc6cfdc7c17f/pymbar-4.0.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e4/06/43084e6cbd4b3bc0e80f6be743b2e79fbc6eed8de9ad8c629939fa55d972/pymdown_extensions-10.16.1-py3-none-any.whl @@ -468,10 +472,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/c-blosc2-2.15.2-h62acda9_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2025.10.5-hbd8a1cb_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/cairo-1.18.4-h950ec3b_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda - conda: 
https://conda.anaconda.org/conda-forge/osx-64/contourpy-1.3.2-py310hf166250_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/cyrus-sasl-2.1.28-h610c526_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/deeptime-0.4.4-py310h626da49_3.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/fftw-3.3.10-nompi_h292e606_110.conda - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 @@ -494,6 +500,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/hdf4-4.2.15-h8138101_7.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/hdf5-1.14.4-nompi_h1607680_105.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/icu-75.1-h120a0e1_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/joblib-1.5.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/khronos-opencl-icd-loader-2024.10.24-h6e16a3a_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.9-py310hfcdb090_1.conda @@ -579,15 +586,18 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/perl-5.32.1-7_h10d778d_perl5.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pillow-11.3.0-py310h566a92c_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pixman-0.46.4-ha059160_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ply-3.11-pyhd8ed1ab_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pmw-2.0.1-py310h2ec42d9_1008.conda - conda: 
https://conda.anaconda.org/conda-forge/osx-64/pthread-stubs-0.4-h00291cd_1002.conda - conda: https://conda.anaconda.org/conda-forge/noarch/py-cpuinfo-9.0.0-pyhd8ed1ab_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pymol-open-source-3.1.0-py310h6bc8293_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.5-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pyqt-5.15.11-py310h6c6f83e_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pyqt5-sip-12.17.0-py310h40a7462_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/pytables-3.10.1-py310hb5a30d5_4.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-9.0.2-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/python-3.10.0-h38b4d05_3_cpython.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0.post0-pyhe01879c_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2025.2-pyhd8ed1ab_0.conda @@ -667,7 +677,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/17/ea/fb0fda8886d1928f1b2a53a1163ef94f6f4b41f6d8b29eee457acfc2fa67/duckdb-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/3c/1db1b0f878319bb227f35a0fca7cad64e1f528b518bcab1a708da305c86d/faicons-0.2.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/cc/44/2dc31ec48d8f63f1d93e11ef19636a442c39775d49f1472f4123a6b38c34/fastexcel-0.16.0-cp39-abi3-macosx_10_12_x86_64.whl @@ -767,7 +776,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/a1/6b/83661fa77dcefa195ad5f8cd9af3d1a7450fd57cc883ad04d65446ac2029/pydantic-2.12.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/a7/3d/9b8ca77b0f76fcdbf8bc6b72474e264283f461284ca84ac3fde570c6c49a/pydantic_core-2.41.4-cp310-cp310-macosx_10_12_x86_64.whl - pypi: https://files.pythonhosted.org/packages/00/78/9cbcc1c073b9d4918e925af1a059762265dc65004e020511b2a06fbfd020/pydssp-0.9.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d4/82/5dbcf36c13ddf528a6c4ba7f75ace2766859870e7e166096c73c8e63c457/pyiceberg-0.10.0-cp310-cp310-macosx_10_9_x86_64.whl - pypi: https://files.pythonhosted.org/packages/81/d9/adf833614ea03f92eccd45274dd31eef2e6d06a301216f21dc6cfdc7c17f/pymbar-4.0.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e4/06/43084e6cbd4b3bc0e80f6be743b2e79fbc6eed8de9ad8c629939fa55d972/pymdown_extensions-10.16.1-py3-none-any.whl @@ -837,10 +845,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/c-blosc2-2.15.2-h9cbb436_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2025.10.5-hbd8a1cb_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/cairo-1.18.4-h6a3b0d2_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/contourpy-1.3.2-py310h7f4e7e6_0.conda - conda: 
https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/cyrus-sasl-2.1.28-ha1cbb27_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/deeptime-0.4.4-py310h7854e75_3.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/fftw-3.3.10-nompi_h6637ab6_110.conda - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/font-ttf-inconsolata-3.000-h77eed37_0.tar.bz2 @@ -863,6 +873,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/hdf4-4.2.15-h2ee6834_7.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/hdf5-1.14.4-nompi_ha698983_105.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/icu-75.1-hfee45f7_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/joblib-1.5.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/khronos-opencl-icd-loader-2024.10.24-h5505292_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/kiwisolver-1.4.9-py310h92dc006_1.conda @@ -948,15 +959,18 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/perl-5.32.1-7_h4614cfb_perl5.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-11.3.0-py310h5de80a5_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pixman-0.46.4-h81086ad_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ply-3.11-pyhd8ed1ab_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pmw-2.0.1-py310hbe9552e_1008.conda - conda: 
https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda - conda: https://conda.anaconda.org/conda-forge/noarch/py-cpuinfo-9.0.0-pyhd8ed1ab_1.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pymol-open-source-3.1.0-py310hadfe0e0_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.5-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyqt-5.15.11-py310hc3ce690_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyqt5-sip-12.17.0-py310hc7786af_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pytables-3.10.1-py310h1535f74_4.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/pytest-9.0.2-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.10.0-h43b31ca_3_cpython.tar.bz2 - conda: https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0.post0-pyhe01879c_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/python-tzdata-2025.2-pyhd8ed1ab_0.conda @@ -1036,7 +1050,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/b4/5f/052e6436a71f461e61cd3a982954c029145a84b58cefa1dfb3eb2d96e4fc/duckdb-1.4.1-cp310-cp310-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl - pypi: 
https://files.pythonhosted.org/packages/65/3c/1db1b0f878319bb227f35a0fca7cad64e1f528b518bcab1a708da305c86d/faicons-0.2.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e2/d8/ef4489cd00fe9fe52bef176ed32a8bb5837dd97518bb950bbd68f546ed1c/fastexcel-0.16.0-cp39-abi3-macosx_11_0_arm64.whl @@ -1136,7 +1149,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/a1/6b/83661fa77dcefa195ad5f8cd9af3d1a7450fd57cc883ad04d65446ac2029/pydantic-2.12.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/59/92/b7b0fe6ed4781642232755cb7e56a86e2041e1292f16d9ae410a0ccee5ac/pydantic_core-2.41.4-cp310-cp310-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/00/78/9cbcc1c073b9d4918e925af1a059762265dc65004e020511b2a06fbfd020/pydssp-0.9.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ba/a0/ca556da105ce64269e46977204777f6d5e1d8595f711f6d0edb3bbc58eff/pyiceberg-0.10.0-cp310-cp310-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/81/d9/adf833614ea03f92eccd45274dd31eef2e6d06a301216f21dc6cfdc7c17f/pymbar-4.0.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e4/06/43084e6cbd4b3bc0e80f6be743b2e79fbc6eed8de9ad8c629939fa55d972/pymdown_extensions-10.16.1-py3-none-any.whl @@ -2126,6 +2138,17 @@ packages: version: 3.1.1 sha256: c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e requires_python: '>=3.8' +- conda: https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + sha256: ab29d57dc70786c1269633ba3dff20288b81664d3ff8d21af995742e2bb03287 + md5: 962b9857ee8e7018c22f2776ffa0b2d7 + depends: + - python >=3.9 + license: BSD-3-Clause + license_family: BSD + purls: + - pkg:pypi/colorama?source=hash-mapping + size: 27011 + timestamp: 1733218222191 - pypi: 
https://files.pythonhosted.org/packages/6d/c1/e419ef3723a074172b68aaa89c9f3de486ed4c2399e2dbd8113a4fdcaf9e/colorlog-6.10.1-py3-none-any.whl name: colorlog version: 6.10.1 @@ -2509,14 +2532,17 @@ packages: version: 2.0.0 sha256: 7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa requires_python: '>=3.8' -- pypi: https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl - name: exceptiongroup - version: 1.3.0 - sha256: 4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10 - requires_dist: - - typing-extensions>=4.6.0 ; python_full_version < '3.13' - - pytest>=6 ; extra == 'test' - requires_python: '>=3.7' +- conda: https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.3.1-pyhd8ed1ab_0.conda + sha256: ee6cf346d017d954255bbcbdb424cddea4d14e4ed7e9813e429db1d795d01144 + md5: 8e662bd460bda79b1ea39194e3c4c9ab + depends: + - python >=3.10 + - typing_extensions >=4.6.0 + license: MIT and PSF-2.0 + purls: + - pkg:pypi/exceptiongroup?source=hash-mapping + size: 21333 + timestamp: 1763918099466 - pypi: https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl name: executing version: 2.2.1 @@ -3772,6 +3798,17 @@ packages: - pytest-enabler>=2.2 ; extra == 'enabler' - pytest-mypy ; extra == 'type' requires_python: '>=3.9' +- conda: https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.3.0-pyhd8ed1ab_0.conda + sha256: e1a9e3b1c8fe62dc3932a616c284b5d8cbe3124bbfbedcf4ce5c828cb166ee19 + md5: 9614359868482abba1bd15ce465e3c42 + depends: + - python >=3.10 + license: MIT + license_family: MIT + purls: + - pkg:pypi/iniconfig?source=compressed-mapping + size: 13387 + timestamp: 1760831448842 - pypi: https://files.pythonhosted.org/packages/b8/f7/761037905ffdec673533bfa43af8d4c31c859c778dfc3bbb71899875ec18/ipykernel-7.0.1-py3-none-any.whl name: ipykernel version: 7.0.1 @@ -7388,7 
+7425,7 @@ packages: - pypi: ./ name: mdtbx version: 0.1.0 - sha256: 6df5ea4f27001fabe2cdb24929248af40dc95dd8b6f84ba121ceecd7d66cfe35 + sha256: 4e16ef83ce31b9ab596262c28ca133ef3501934566e752a3ef5549d353dbec7d requires_python: ==3.10 - conda: https://conda.anaconda.org/conda-forge/linux-64/mdtraj-1.9.9-py310h523e8d7_1.conda sha256: f8aeffbd29bf9b7eb89020b7b5cbbea7ba961b18f3c1faec55076b2550c50581 @@ -9251,6 +9288,18 @@ packages: - pytest>=8.4.2 ; extra == 'test' - mypy>=1.18.2 ; extra == 'type' requires_python: '>=3.10' +- conda: https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhf9edf01_1.conda + sha256: e14aafa63efa0528ca99ba568eaf506eb55a0371d12e6250aaaa61718d2eb62e + md5: d7585b6550ad04c8c5e21097ada2888e + depends: + - python >=3.9 + - python + license: MIT + license_family: MIT + purls: + - pkg:pypi/pluggy?source=compressed-mapping + size: 25877 + timestamp: 1764896838868 - conda: https://conda.anaconda.org/conda-forge/noarch/ply-3.11-pyhd8ed1ab_3.conda sha256: bae453e5cecf19cab23c2e8929c6e30f4866d996a8058be16c797ed4b935461f md5: fd5062942bfa1b0bd5e0d2a4397b099e @@ -9660,13 +9709,17 @@ packages: - torch - einops - tqdm -- pypi: https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl - name: pygments - version: 2.19.2 - sha256: 86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b - requires_dist: - - colorama>=0.4.6 ; extra == 'windows-terminal' - requires_python: '>=3.8' +- conda: https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda + sha256: 5577623b9f6685ece2697c6eb7511b4c9ac5fb607c9babc2646c811b428fd46a + md5: 6b6ece66ebcae2d5f326c77ef2c5a066 + depends: + - python >=3.9 + license: BSD-2-Clause + license_family: BSD + purls: + - pkg:pypi/pygments?source=hash-mapping + size: 889287 + timestamp: 1750615908735 - pypi: 
https://files.pythonhosted.org/packages/ba/a0/ca556da105ce64269e46977204777f6d5e1d8595f711f6d0edb3bbc58eff/pyiceberg-0.10.0-cp310-cp310-macosx_11_0_arm64.whl name: pyiceberg version: 0.10.0 @@ -10119,6 +10172,27 @@ packages: - pkg:pypi/tables?source=hash-mapping size: 1556023 timestamp: 1733266235822 +- conda: https://conda.anaconda.org/conda-forge/noarch/pytest-9.0.2-pyhcf101f3_0.conda + sha256: 9e749fb465a8bedf0184d8b8996992a38de351f7c64e967031944978de03a520 + md5: 2b694bad8a50dc2f712f5368de866480 + depends: + - pygments >=2.7.2 + - python >=3.10 + - iniconfig >=1.0.1 + - packaging >=22 + - pluggy >=1.5,<2 + - tomli >=1 + - colorama >=0.4 + - exceptiongroup >=1 + - python + constrains: + - pytest-faulthandler >=2 + license: MIT + license_family: MIT + purls: + - pkg:pypi/pytest?source=hash-mapping + size: 299581 + timestamp: 1765062031645 - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.10.0-h543edf9_3_cpython.tar.bz2 build_number: 3 sha256: 0661f43d7bc446c22bfcf1730ad5c7a2ac695c696ba4609bac896cefff879431 diff --git a/pymol-plugins/pymol_plugins/__init__.py b/pymol-plugins/pymol_plugins/__init__.py index 62049bd..12f6f6d 100644 --- a/pymol-plugins/pymol_plugins/__init__.py +++ b/pymol-plugins/pymol_plugins/__init__.py @@ -4,6 +4,7 @@ from .visualizer import * # NoQA from .arrow import * # NoQA from .alias import * # NoQA +from .ai import * # NoQA # similar to alignto diff --git a/pymol-plugins/pymol_plugins/ai.py b/pymol-plugins/pymol_plugins/ai.py new file mode 100644 index 0000000..6bc0086 --- /dev/null +++ b/pymol-plugins/pymol_plugins/ai.py @@ -0,0 +1,731 @@ +"""AI assistant command for PyMOL. 
+ +Usage: + claude + codex + +Examples: + claude resid 10を赤色にして + codex タンパク質を透明度0.5で表示して +""" + +from __future__ import annotations + +import base64 +import json +import re +import subprocess +import tempfile +import threading +import traceback +from itertools import count +from pathlib import Path + +from pymol import cmd + +SUPPORTED_AI_TYPES = {"claude", "codex"} +AI_TIMEOUT_SEC = 180 +AI_MAX_ATTEMPTS = 5 +AI_JOB_COUNTER = count(1) +AI_JOBS: dict[int, dict[str, object]] = {} +AI_JOBS_LOCK = threading.Lock() +PYMOL_SPECIAL_COMMAND_PREFIXES = ( + "@", + "/", +) +PYMOL_SPECIAL_COMMANDS = { + "python", + "python end", + "embed", + "skip", + "util.cbag", + "util.cbaw", + "util.cbao", +} + + +class _AIResponseFeedbackError(RuntimeError): + """Retryable error that should be fed back to the AI.""" + + def __init__(self, message: str, response: str) -> None: + super().__init__(message) + self.response = response + + +def _capture_scene_png(image_path: Path) -> None: + """Capture the current viewport using the PyMOL `png` command.""" + viewport = tuple(cmd.get_viewport()) + try: + cmd.refresh() + cmd.png(str(image_path), ray=0, quiet=1) + finally: + if len(viewport) == 2: + cmd.viewport(int(viewport[0]), int(viewport[1])) + cmd.refresh() + + if not image_path.exists(): + raise RuntimeError(f"failed to save screenshot to {image_path}") + + +def _get_scene_context() -> str: + """Collect compact PyMOL scene metadata for the AI prompt.""" + objects = cmd.get_object_list("all") + names = cmd.get_names("all") + enabled_objects = cmd.get_names("objects", enabled_only=1) + + lines = [ + f"Loaded objects: {objects}", + f"Enabled objects: {enabled_objects}", + f"All named selections/objects: {names}", + ] + + for obj in objects[:3]: + residues: list[tuple[str, str, str]] = [] + cmd.iterate_state( + 1, + f"({obj}) and name CA", + "residues.append((resi, resn, chain))", + space={"residues": residues}, + ) + if not residues: + continue + + summary = ", ".join( + f"{chain or 
'-'}:{resn}{resi}" for resi, resn, chain in residues[:10] + ) + if len(residues) > 10: + summary += f" ... ({len(residues)} residues total)" + lines.append(f"{obj} residues: {summary}") + + return "\n".join(lines) + + +def _build_prompt(instruction: str, scene_context: str, image_path: Path) -> str: + """Build the prompt sent to the local AI CLI.""" + return f"""You are a PyMOL expert assistant. + +The user is controlling a live PyMOL session. A PNG screenshot of the current viewport is attached. +The screenshot file name is `{image_path.name}`. + +Current PyMOL scene metadata: +{scene_context} + +User request: +{instruction} + +Return only executable PyMOL commands or Python code for PyMOL. + +Rules: +- No explanation, no prose, no markdown outside code fences. +- Prefer PyMOL commands when simple enough. +- If you return Python, use a single ```python fenced block and call `cmd.*`. +- If you return PyMOL commands, use a single ```pymol fenced block. +- Keep the output minimal and directly executable in the current session. +- Do not use shell commands. +""" + + +def _build_feedback_prompt( + instruction: str, + scene_context: str, + image_path: Path, + previous_response: str, + error_message: str, + attempt: int, +) -> str: + """Build a retry prompt that includes execution feedback.""" + base_prompt = _build_prompt(instruction, scene_context, image_path) + return f"""{base_prompt} + +The previous response did not succeed in the live PyMOL session. +This is retry attempt {attempt} of {AI_MAX_ATTEMPTS}. + +Previous response: +```text +{previous_response} +``` + +Feedback from the failed attempt: +{error_message} + +Return a corrected full replacement response that completes the original request +in the current session. + +Additional rules: +- Do not repeat the same failing code unchanged. +- Assume the previous attempt may have partially modified the scene. +- Return only executable code, exactly as in the original rules. 
+""" + + +def _normalize_bool_arg(value: object, default: bool = True) -> bool: + if value is None: + return default + + normalized = str(value).strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + return default + + +def _extract_code(response: str) -> list[tuple[str, str]]: + """Extract executable code blocks from the AI response.""" + blocks: list[tuple[str, str]] = [] + pattern = re.compile(r"```(pymol|python)\s*\n(.*?)```", re.DOTALL | re.IGNORECASE) + + for match in pattern.finditer(response): + lang = match.group(1).lower() + code = match.group(2).strip() + if code: + blocks.append((lang, code)) + + if blocks: + return blocks + + stripped = response.strip() + if stripped and _looks_like_raw_pymol_commands(stripped): + return [("pymol", stripped)] + return [] + + +def _run_subprocess( + args: list[str], stdin: str | None = None +) -> subprocess.CompletedProcess: + return subprocess.run( + args, + input=stdin, + capture_output=True, + text=True, + timeout=AI_TIMEOUT_SEC, + check=False, + ) + + +def _collect_text_from_message_content(content: object) -> str: + if isinstance(content, str): + return content + + if not isinstance(content, list): + return "" + + texts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") == "text" and isinstance(block.get("text"), str): + texts.append(block["text"]) + return "".join(texts) + + +def _parse_claude_stream_json(stdout: str) -> str: + """Convert Claude stream-json output into plain assistant text.""" + final_texts: list[str] = [] + streamed_text_parts: list[str] = [] + + for line in stdout.splitlines(): + line = line.strip() + if not line: + continue + + try: + event = json.loads(line) + except json.JSONDecodeError: + final_texts.append(line) + continue + + event_type = event.get("type") + + if event_type == "result": + result_text = event.get("result") + if 
isinstance(result_text, str) and result_text.strip(): + final_texts.append(result_text) + + message = event.get("message") + if isinstance(message, dict): + text = _collect_text_from_message_content(message.get("content")) + if text: + final_texts.append(text) + continue + + if event_type == "assistant": + message = event.get("message") + if isinstance(message, dict): + text = _collect_text_from_message_content(message.get("content")) + if text: + final_texts.append(text) + continue + + if event_type == "stream_event": + raw_event = event.get("event") + if not isinstance(raw_event, dict): + continue + + if raw_event.get("type") != "content_block_delta": + continue + + delta = raw_event.get("delta") + if not isinstance(delta, dict): + continue + + if delta.get("type") == "text_delta" and isinstance(delta.get("text"), str): + streamed_text_parts.append(delta["text"]) + + if final_texts: + return "\n".join(text for text in final_texts if text.strip()).strip() + return "".join(streamed_text_parts).strip() + + +def _has_executable_content(lang: str, code: str) -> bool: + if lang == "python": + if not code.strip(): + return False + try: + compile(code, "", "exec") + except SyntaxError: + return False + return True + + return any( + line.strip() and not line.strip().startswith("#") for line in code.splitlines() + ) + + +def _looks_like_raw_pymol_commands(text: str) -> bool: + lines = [line.strip() for line in text.splitlines() if line.strip()] + if not lines: + return False + + allowed_prefixes = ( + "@", + "/", + "run ", + "python", + "python end", + "embed", + "skip", + ) + common_commands = { + "align", + "bg_color", + "cartoon", + "center", + "color", + "delete", + "disable", + "dist", + "distance", + "extract", + "fetch", + "hide", + "label", + "load", + "orient", + "png", + "ray", + "remove", + "rebuild", + "refresh", + "reinitialize", + "select", + "set", + "show", + "spectrum", + "super", + "turn", + "util.cbag", + "util.cbaw", + "util.cbao", + "zoom", + } + + for 
line in lines: + if line.startswith("#"): + continue + if line.startswith(allowed_prefixes): + continue + + first = line.split(maxsplit=1)[0].rstrip(",").lower() + if first in common_commands: + continue + return False + + return True + + +def _run_claude(prompt: str, image_path: Path) -> str: + """Run Claude Code in stream-json mode to attach an image.""" + image_b64 = base64.b64encode(image_path.read_bytes()).decode("ascii") + payload = { + "type": "user", + "message": { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_b64, + }, + }, + ], + }, + } + + result = _run_subprocess( + [ + "claude", + "--print", + "--verbose", + "--input-format", + "stream-json", + "--no-session-persistence", + "--output-format", + "stream-json", + ], + stdin=json.dumps(payload) + "\n", + ) + if result.returncode != 0: + stderr = result.stderr.strip() or result.stdout.strip() + raise RuntimeError( + f"claude command failed (exit {result.returncode}):\n{stderr}" + ) + return _parse_claude_stream_json(result.stdout) + + +def _run_codex(prompt: str, image_path: Path) -> str: + """Run Codex CLI with an attached screenshot and capture the last message.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_path = Path(tmpdir) / "codex-last-message.txt" + result = _run_subprocess( + [ + "codex", + "exec", + "--skip-git-repo-check", + "-c", + "effort=high", + "--image", + str(image_path), + "--output-last-message", + str(output_path), + prompt, + ] + ) + if result.returncode != 0: + stderr = result.stderr.strip() or result.stdout.strip() + raise RuntimeError( + f"codex command failed (exit {result.returncode}):\n{stderr}" + ) + + if output_path.exists(): + return output_path.read_text().strip() + + return result.stdout + + +def _request_ai_response(ai_type: str, prompt: str, image_path: Path) -> str: + if ai_type == "claude": + return _run_claude(prompt, image_path) + 
return _run_codex(prompt, image_path) + + +def _capture_ai_context() -> tuple[Path, str]: + image_dir = Path(tempfile.mkdtemp(prefix="pymol-ai-")) + image_path = image_dir / "pymol_scene.png" + _capture_scene_png(image_path) + return image_path, _get_scene_context() + + +def _cleanup_ai_context(image_path: Path) -> None: + try: + image_path.unlink(missing_ok=True) + image_path.parent.rmdir() + except OSError: + pass + + +def _register_ai_job(job_id: int, ai_type: str, instruction: str) -> None: + with AI_JOBS_LOCK: + AI_JOBS[job_id] = { + "id": job_id, + "type": ai_type, + "instruction": instruction, + "status": "running", + "attempts": 0, + "max_attempts": AI_MAX_ATTEMPTS, + } + + +def _update_ai_job(job_id: int, **updates: object) -> None: + with AI_JOBS_LOCK: + job = AI_JOBS.get(job_id) + if job is None: + return + job.update(updates) + + +def _snapshot_ai_jobs() -> list[dict[str, object]]: + with AI_JOBS_LOCK: + return [dict(job) for job in AI_JOBS.values()] + + +def _execute_blocks(blocks: list[tuple[str, str]]) -> None: + """Execute extracted code blocks inside the active PyMOL session.""" + for lang, code in blocks: + if not _has_executable_content(lang, code): + continue + + print(f" [ai] Executing ({lang}):\n{code}\n") + if lang == "python": + exec(code, {"cmd": cmd, "__builtins__": __builtins__}) # noqa: S102 + continue + + for line in code.splitlines(): + line = line.strip() + if line and not line.startswith("#"): + _validate_pymol_command(line) + cmd.do(line) + + +def _validate_pymol_command(line: str) -> None: + stripped = line.strip() + if not stripped or stripped.startswith("#"): + return + + if stripped.startswith(PYMOL_SPECIAL_COMMAND_PREFIXES): + return + + normalized = stripped.lower() + if normalized in PYMOL_SPECIAL_COMMANDS: + return + + command = stripped.split(maxsplit=1)[0].rstrip(",").lower() + if command in PYMOL_SPECIAL_COMMANDS: + return + if command in cmd.keyword: + return + + raise RuntimeError( + f"Unknown or unsupported PyMOL 
command: {command!r} in line: {line}" + ) + + +def _run_ai_job( + job_id: int, + ai_type: str, + instruction: str, +) -> None: + previous_response = "" + feedback_error = "" + + for attempt in range(1, AI_MAX_ATTEMPTS + 1): + image_path: Path | None = None + try: + image_path, scene_context = _capture_ai_context() + prompt = ( + _build_prompt(instruction, scene_context, image_path) + if attempt == 1 + else _build_feedback_prompt( + instruction=instruction, + scene_context=scene_context, + image_path=image_path, + previous_response=previous_response, + error_message=feedback_error, + attempt=attempt, + ) + ) + _update_ai_job( + job_id, + status="running" if attempt == 1 else "retrying", + attempts=attempt, + ) + print(f" [ai:{job_id}] Attempt {attempt}/{AI_MAX_ATTEMPTS}.") + + try: + response = _request_ai_response(ai_type, prompt, image_path).strip() + except FileNotFoundError: + _update_ai_job( + job_id, + status="error", + error=f"'{ai_type}' command not found", + attempts=attempt, + ) + print( + f" [ai:{job_id}] Error: '{ai_type}' command not found. " + "Make sure it is installed and available in PATH." + ) + return + except subprocess.TimeoutExpired: + _update_ai_job( + job_id, + status="error", + error=f"timeout ({AI_TIMEOUT_SEC}s)", + attempts=attempt, + ) + print( + f" [ai:{job_id}] Error: AI command timed out ({AI_TIMEOUT_SEC}s)." + ) + return + except RuntimeError as exc: + _update_ai_job(job_id, status="error", error=str(exc), attempts=attempt) + print(f" [ai:{job_id}] Error: {exc}") + return + + _update_ai_job(job_id, response=response) + print(f" [ai:{job_id}] Response received ({len(response)} chars).") + + blocks = [ + (lang, code) + for lang, code in _extract_code(response) + if _has_executable_content(lang, code) + ] + if not blocks: + raise _AIResponseFeedbackError( + "No executable code found in the AI response. 
" + "Return directly executable PyMOL commands or Python code.", + response, + ) + + _update_ai_job(job_id, status="executing", blocks=len(blocks)) + print(f" [ai:{job_id}] Executing {len(blocks)} block(s).") + try: + _execute_blocks(blocks) + except Exception as exc: # noqa: BLE001 + error = f"{type(exc).__name__}: {exc}" + raise _AIResponseFeedbackError( + "Applying the previous response failed.\n" + f"{error}\n" + f"{traceback.format_exc().rstrip()}", + response, + ) from exc + + _update_ai_job(job_id, status="done", attempts=attempt, error=None) + print(f" [ai:{job_id}] Done on attempt {attempt}.") + return + except _AIResponseFeedbackError as exc: + previous_response = exc.response + feedback_error = str(exc) + if attempt >= AI_MAX_ATTEMPTS: + _update_ai_job( + job_id, + status="error", + error=feedback_error, + response=previous_response, + attempts=attempt, + ) + print( + f" [ai:{job_id}] Error after {attempt} attempts: {feedback_error}" + ) + return + + _update_ai_job( + job_id, + status="retrying", + error=feedback_error, + response=previous_response, + attempts=attempt, + ) + print(f" [ai:{job_id}] Attempt {attempt} failed. Sending feedback.") + print(f" [ai:{job_id}] Feedback: {feedback_error}") + continue + except Exception as exc: # noqa: BLE001 + error = f"{type(exc).__name__}: {exc}" + _update_ai_job(job_id, status="error", error=error, attempts=attempt) + print(f" [ai:{job_id}] Error while applying response: {error}") + print(traceback.format_exc().rstrip()) + return + finally: + if image_path is not None: + _cleanup_ai_context(image_path) + + +def _submit_ai_request( + instruction: str, + type: str = "claude", # noqa: A002 + async_: str = "1", +) -> None: + """Send a natural-language request to Claude/Codex and run the result.""" + ai_type = type.strip().lower() + if ai_type not in SUPPORTED_AI_TYPES: + supported = ", ".join(sorted(SUPPORTED_AI_TYPES)) + print(f" [ai] Unknown type '{type}'. 
Use one of: {supported}") + return + + if not instruction or not instruction.strip(): + print(" [ai] Instruction is empty.") + return + + use_async = _normalize_bool_arg(async_, default=True) + job_id = next(AI_JOB_COUNTER) + _register_ai_job(job_id, ai_type, instruction.strip()) + + print( + f" [ai:{job_id}] Using {ai_type}. " + f"{'Queued asynchronously' if use_async else 'Running synchronously'}: {instruction!r}" + ) + + if use_async: + worker = threading.Thread( + target=_run_ai_job, + args=(job_id, ai_type, instruction.strip()), + daemon=True, + name=f"pymol-ai-{job_id}", + ) + _update_ai_job(job_id, thread=worker) + worker.start() + print( + f" [ai:{job_id}] Background job started. Use `ai_status` to check progress." + ) + return + + _run_ai_job(job_id, ai_type, instruction.strip()) + + +def claude_cmd(instruction: str, async_: str = "1") -> None: + """Run the local Claude CLI for a PyMOL request.""" + _submit_ai_request(instruction=instruction, type="claude", async_=async_) + + +def codex_cmd(instruction: str, async_: str = "1") -> None: + """Run the local Codex CLI for a PyMOL request.""" + _submit_ai_request(instruction=instruction, type="codex", async_=async_) + + +def ai_status(job_id: str = "all") -> None: + """Show status of AI jobs.""" + jobs = _snapshot_ai_jobs() + if not jobs: + print(" [ai] No jobs.") + return + + if job_id.strip().lower() != "all": + try: + wanted = int(job_id) + except ValueError: + print(" [ai] job_id must be an integer or 'all'.") + return + jobs = [job for job in jobs if job.get("id") == wanted] + if not jobs: + print(f" [ai] No such job: {wanted}") + return + + for job in sorted(jobs, key=lambda item: int(item["id"])): + status = job.get("status", "unknown") + ai_type = job.get("type", "?") + instruction = job.get("instruction", "") + attempts = job.get("attempts", 0) + max_attempts = job.get("max_attempts", AI_MAX_ATTEMPTS) + print( + f" [ai:{job['id']}] {status} " + f"(attempt {attempts}/{max_attempts}, {ai_type}) 
{instruction}" + ) + if status == "error" and job.get("error"): + print(f" [ai:{job['id']}] error: {job['error']}") + + +cmd.extend("claude", claude_cmd) +cmd.extend("codex", codex_cmd) +cmd.extend("ai_status", ai_status) diff --git a/pyproject.toml b/pyproject.toml index 8fa0cd1..3e789f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ rules = "- prefer polars over pandas\n- make charts using altair" [tool.pixi.dependencies] pymol-open-source = "==3.1.0" +pytest = ">=8" mdtraj = "==1.9.9" # https://github.com/ParmEd/ParmEd/pull/1387 # mdtraj = "==1.11.0" deeptime = "==0.4.4" @@ -79,4 +80,6 @@ jupyter_remote = "jupyter lab --no-browser --port=8889" # ssh -N -L localhost:88 ruff-format = "ruff format --force-exclude" ruff-lint = "ruff check --fix --exit-non-zero-on-fix --force-exclude" r = { cmd = "echo OK", depends-on = ["ruff-format", "ruff-lint"] } +test = "python -m pytest tests/ -v" +test-fast = "python -m pytest tests/ -v -x" pymolrc = "echo \"#!/usr/bin/env python3\nimport sys; import os; sys.path.append(os.path.expanduser('\"$(pwd)/pymol-plugins/\"')); import pymol_plugins; from pymol_plugins import *;\" > $HOME/.pymolrc" diff --git a/src/utils/contactmap.py b/src/analysis/__init__.py similarity index 100% rename from src/utils/contactmap.py rename to src/analysis/__init__.py diff --git a/src/utils/distmat.py b/src/analysis/contactmap.py similarity index 100% rename from src/utils/distmat.py rename to src/analysis/contactmap.py diff --git a/src/utils/place_solvent.py b/src/analysis/distmat.py similarity index 100% rename from src/utils/place_solvent.py rename to src/analysis/distmat.py diff --git a/src/utils/extract_ave_str.py b/src/analysis/extract_ave_str.py similarity index 88% rename from src/utils/extract_ave_str.py rename to src/analysis/extract_ave_str.py index 2301356..20f9b86 100644 --- a/src/utils/extract_ave_str.py +++ b/src/analysis/extract_ave_str.py @@ -49,6 +49,8 @@ def add_subcmd(subparsers): help="Index file (.ndx)", ) + 
parser.set_defaults(func=run) + def run(args): if args.gmx: @@ -63,8 +65,8 @@ def run(args): import mdtraj as md trj = md.load(args.trajectory, top=args.topology) - ave_crds = trj.xyz.mean(axis=0) + ave_xyz = trj.xyz.mean(axis=0, keepdims=True) # (1, n_atoms, 3) + avg_trj = md.Trajectory(ave_xyz, trj.topology) atom_indices = trj.top.select(args.selection) - final_trj = ave_crds[atom_indices] - final_trj.save_pdb(args.output) + avg_trj.atom_slice(atom_indices).save_pdb(args.output) LOGGER.info("Done") diff --git a/src/utils/extract_str.py b/src/analysis/extract_str.py similarity index 98% rename from src/utils/extract_str.py rename to src/analysis/extract_str.py index d98c88b..8a9c888 100644 --- a/src/utils/extract_str.py +++ b/src/analysis/extract_str.py @@ -55,6 +55,8 @@ def add_subcmd(subparsers): help="Index file (.ndx)", ) + parser.set_defaults(func=run) + def run(args): if args.gmx: diff --git a/src/build/__init__.py b/src/build/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/add_ndx.py b/src/build/add_ndx.py similarity index 98% rename from src/utils/add_ndx.py rename to src/build/add_ndx.py index 5c1ef11..c23fb10 100644 --- a/src/utils/add_ndx.py +++ b/src/build/add_ndx.py @@ -38,6 +38,8 @@ def add_subcmd(subparsers): "-o", "--output", default="index.ndx", type=str, help="Output index file" ) + parser.set_defaults(func=run) + def make_default_index(args): cmd = f"echo q | gmx make_ndx -f {args.gro} -o {args.output}" diff --git a/src/utils/addace.py b/src/build/addace.py similarity index 97% rename from src/utils/addace.py rename to src/build/addace.py index 4729fa9..32f755b 100644 --- a/src/utils/addace.py +++ b/src/build/addace.py @@ -25,6 +25,8 @@ def add_subcmd(subparsers): "-o", "--output_prefix", default="out_ace", type=str, help="Output file prefix" ) + parser.set_defaults(func=run) + def run(args): cmd.load(args.structure, "target") diff --git a/src/build/addh.py b/src/build/addh.py new file mode 100644 index 
0000000..4e2f723 --- /dev/null +++ b/src/build/addh.py @@ -0,0 +1,58 @@ +import argparse +import subprocess + +from ..logger import generate_logger + +LOGGER = generate_logger(__name__) + + +def add_subcmd(subparsers): + """ + mdtbx addh -s input.pdb -o input_h --method reduce + """ + parser = subparsers.add_parser( + "addh", + help="Add hydrogen atoms to structure", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "-s", + "--structure", + required=True, + type=str, + help="Input structure file (PDB etc.)", + ) + + parser.add_argument( + "-o", + "--output_prefix", + default="out_h", + type=str, + help="Output file prefix", + ) + + parser.add_argument( + "--method", + default="reduce", + choices=["reduce", "pymol"], + help="Method for hydrogen addition", + ) + + parser.set_defaults(func=run) + + +def run(args): + output_pdb = f"{args.output_prefix}.pdb" + + if args.method == "reduce": + cmd = f"reduce -build {args.structure} > {output_pdb}" + subprocess.run(cmd, shell=True, check=True) + else: + from pymol import cmd as pymol_cmd + + pymol_cmd.load(args.structure, "target") + pymol_cmd.h_add("target") + pymol_cmd.save(output_pdb, "target") + + LOGGER.info(f"{output_pdb} generated") diff --git a/src/utils/addnme.py b/src/build/addnme.py similarity index 98% rename from src/utils/addnme.py rename to src/build/addnme.py index 3bdcabc..109edcb 100644 --- a/src/utils/addnme.py +++ b/src/build/addnme.py @@ -25,6 +25,8 @@ def add_subcmd(subparsers): "-o", "--output_prefix", default="out_nme", type=str, help="Output file prefix" ) + parser.set_defaults(func=run) + def run(args): cmd.load(args.structure, "target") diff --git a/src/utils/amb2gro.py b/src/build/amb2gro.py similarity index 98% rename from src/utils/amb2gro.py rename to src/build/amb2gro.py index d71a192..e7fd20c 100644 --- a/src/utils/amb2gro.py +++ b/src/build/amb2gro.py @@ -38,6 +38,8 @@ def add_subcmd(subparsers): "--no-editconf", action="store_true", help="Do not run 
gmx editconf" ) + parser.set_defaults(func=run) + def run(args): if args.type == "parmed": diff --git a/src/utils/build_solution.py b/src/build/build_solution.py similarity index 99% rename from src/utils/build_solution.py rename to src/build/build_solution.py index c2aa42f..777caed 100644 --- a/src/utils/build_solution.py +++ b/src/build/build_solution.py @@ -82,6 +82,8 @@ def add_subcmd(subparsers): "--keepfiles", action="store_true", help="Keep intermediate files" ) + parser.set_defaults(func=run) + def run(args): # tleap diff --git a/src/build/build_vacuum.py b/src/build/build_vacuum.py new file mode 100644 index 0000000..7f4763c --- /dev/null +++ b/src/build/build_vacuum.py @@ -0,0 +1,104 @@ +import argparse +import subprocess + +from ..config import SYSTEM_NAME # NOQA +from ..logger import generate_logger + +LOGGER = generate_logger(__name__) + +TLEAP_TEMPLATE = """\ +source leaprc.protein.ff14SB +source leaprc.water.tip3p +{addprecmd} +{ligand_params} +{SYSTEM_NAME} = loadpdb {input} +{addpostcmd} +saveamberparm {SYSTEM_NAME} {outdir}/leap.parm7 {outdir}/leap.rst7 +savepdb {SYSTEM_NAME} {outdir}/leap.pdb +quit +""" + + +def add_subcmd(subparsers): + """ + mdtbx build_vacuum -i input.pdb -o ./ + """ + parser = subparsers.add_parser( + "build_vacuum", + help="Build vacuum system (no water, no ions, no box)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "-i", + "--input", + required=True, + type=str, + help="Input PDB file", + ) + + parser.add_argument( + "-o", + "--outdir", + default="./", + type=str, + help="Output directory", + ) + + parser.add_argument( + "--ligparam", + type=str, + help="Ligand parameter in FRCMOD:LIB format", + ) + + parser.add_argument( + "--addprecmd", + type=str, + help="Additional tleap commands before loadpdb (e.g. source leaprc.GLYCAM_06j-1)", + ) + + parser.add_argument( + "--addpostcmd", + type=str, + help="Additional tleap commands after loadpdb (e.g. 
SS-bond settings)", + ) + + parser.add_argument( + "--keepfiles", + action="store_true", + help="Keep intermediate files (tleap.in, leap.log)", + ) + + parser.set_defaults(func=run) + + +def run(args): + addprecmd = args.addprecmd if args.addprecmd else "" + addpostcmd = args.addpostcmd if args.addpostcmd else "" + + ligand_params = "" + if args.ligparam: + frcmod, lib = args.ligparam.split(":") + ligand_params = f"loadamberparams {frcmod}\nloadoff {lib}" + + tleap_input = TLEAP_TEMPLATE.format( + addprecmd=addprecmd, + ligand_params=ligand_params, + SYSTEM_NAME=SYSTEM_NAME, + input=args.input, + addpostcmd=addpostcmd, + outdir=args.outdir, + ) + + with open("tleap.in", "w") as f: + f.write(tleap_input) + + subprocess.run("tleap -f tleap.in", shell=True, check=True) + + LOGGER.info( + f"{args.outdir}/leap.parm7 {args.outdir}/leap.rst7 {args.outdir}/leap.pdb generated" + ) + + if not args.keepfiles: + subprocess.run("rm -f leap.log tleap.in", shell=True, check=True) + LOGGER.info("leap.log tleap.in removed") diff --git a/src/utils/calc_ion_conc.py b/src/build/calc_ion_conc.py similarity index 76% rename from src/utils/calc_ion_conc.py rename to src/build/calc_ion_conc.py index 7542581..bfabc84 100644 --- a/src/utils/calc_ion_conc.py +++ b/src/build/calc_ion_conc.py @@ -57,6 +57,14 @@ def add_subcmd(subparsers): # cubic: calculate from system volume(assume cubic system) # water: calculate from water volume(recomended if lipid system) # optimize: consider charge of system + parser.add_argument( + "--net_charge", + default=0, + type=int, + help="Net charge of the solute (used with --method optimize)", + ) + + parser.set_defaults(func=run) def get_boxsize_from_pdb(args) -> tuple[float, float, float]: # angstrom^3 @@ -112,7 +120,24 @@ def run(args): # volume = num_water * TIP3P_VOLUME * 1000 LOGGER.info(f"Volume of water molecules: {volume}") ionnum = calc_ion_conc_from_volume(volume, args.concentration) - else: - raise NotImplementedError + else: # optimize: neutralize 
system charge then add salt + boxsize = get_boxsize_from_pdb(args) + volume = boxsize[0] * boxsize[1] * boxsize[2] + LOGGER.info(f"Volume of system: {volume}") + net_charge = getattr(args, "net_charge", 0) + salt_num = calc_ion_conc_from_volume(volume, args.concentration) + # Add counter-ions to neutralize, then add salt ions + if net_charge > 0: + cation_num = salt_num + anion_num = salt_num + net_charge + elif net_charge < 0: + cation_num = salt_num + abs(net_charge) + anion_num = salt_num + else: + cation_num = salt_num + anion_num = salt_num + LOGGER.info(f"Net charge: {net_charge}") + LOGGER.info(f"Cation num: {cation_num}, Anion num: {anion_num}") + ionnum = cation_num # report cation count as primary value LOGGER.info(f"Number of ions that should be added: # {ionnum}") print(f"ionnum: {ionnum}") diff --git a/src/utils/centering_gro.py b/src/build/centering_gro.py similarity index 98% rename from src/utils/centering_gro.py rename to src/build/centering_gro.py index 50f32ab..d68bffd 100644 --- a/src/utils/centering_gro.py +++ b/src/build/centering_gro.py @@ -52,6 +52,8 @@ "--no-editconf", action="store_true", help="Do not run gmx editconf" ) + parser.set_defaults(func=run) + def run(args): dummy_mdp_path = Path(__file__).parent / "dummy.mdp" diff --git a/src/utils/find_bond.py b/src/build/find_bond.py similarity index 73% rename from src/utils/find_bond.py rename to src/build/find_bond.py index b7ceea4..69f61a6 100644 --- a/src/utils/find_bond.py +++ b/src/build/find_bond.py @@ -56,6 +56,15 @@ help="Output file name", ) + parser.add_argument( + "-op", + "--output-pdb", + type=str, + help="Output PDB file with CYS renamed to CYX for SS-bonded residues", + ) + + parser.set_defaults(func=run) + def run(args): cmd.load(args.structure, "target") @@ -86,3 +95,18 @@ else: LOGGER.info(f"{len(bonds)} bonds found") print(bonds_str) + + if args.output_pdb is not None: + if len(bonds) == 0: + 
LOGGER.info("No SS-bond found; saving original structure to output PDB") + cmd.save(args.output_pdb, "target") + else: + bonded_resi = set() + for res1, res2 in bonds: + bonded_resi.add(res1) + bonded_resi.add(res2) + for resi in bonded_resi: + cmd.alter(f"target and resn CYS and resi {resi}", "resn='CYX'") + LOGGER.info(f"CYS resi {resi} renamed to CYX") + cmd.save(args.output_pdb, "target") + LOGGER.info(f"{args.output_pdb} generated") diff --git a/src/utils/gen_am1bcc.py b/src/build/gen_am1bcc.py similarity index 98% rename from src/utils/gen_am1bcc.py rename to src/build/gen_am1bcc.py index a4b0ae4..b3ea3ec 100644 --- a/src/utils/gen_am1bcc.py +++ b/src/build/gen_am1bcc.py @@ -40,6 +40,8 @@ parser.add_argument("-c", "--charge", default=0, type=int, help="Charge") + parser.set_defaults(func=run) + def run(args): filetype = Path(args.structure).suffix[1:] diff --git a/src/utils/gen_distres.py b/src/build/gen_distres.py similarity index 99% rename from src/utils/gen_distres.py rename to src/build/gen_distres.py index e1a384d..4e3c7d6 100644 --- a/src/utils/gen_distres.py +++ b/src/build/gen_distres.py @@ -77,6 +77,8 @@ help="Upper bound2 [nm] Use multiple value if multiple selection. 
Single value will be applied to all selection if single value is given", ) + parser.set_defaults(func=run) + def run(args): # generate posres.itp diff --git a/src/utils/gen_modres_am1bcc.py b/src/build/gen_modres_am1bcc.py similarity index 99% rename from src/utils/gen_modres_am1bcc.py rename to src/build/gen_modres_am1bcc.py index 7f4b95a..e1619f0 100644 --- a/src/utils/gen_modres_am1bcc.py +++ b/src/build/gen_modres_am1bcc.py @@ -58,6 +58,8 @@ def add_subcmd(subparsers): "--posttailtype", default="N", type=str, help="Atom name of Posttail" ) + parser.set_defaults(func=run) + def run(args): # ref: https://ambermd.org/tutorials/basic/tutorial5/index.php diff --git a/src/utils/gen_modres_resp.py b/src/build/gen_modres_resp.py similarity index 99% rename from src/utils/gen_modres_resp.py rename to src/build/gen_modres_resp.py index 55c2051..edee9da 100644 --- a/src/utils/gen_modres_resp.py +++ b/src/build/gen_modres_resp.py @@ -79,6 +79,8 @@ def add_subcmd(subparsers): "--threads", default=16, type=int, help="Number of threads for Gaussian" ) + parser.set_defaults(func=run) + def run(args): # ref: https://qiita.com/tacoma/items/02474d9aaa99b903e4ee diff --git a/src/utils/gen_posres.py b/src/build/gen_posres.py similarity index 96% rename from src/utils/gen_posres.py rename to src/build/gen_posres.py index e6aca00..c8337ac 100644 --- a/src/utils/gen_posres.py +++ b/src/build/gen_posres.py @@ -1,7 +1,7 @@ import argparse -from .atom_selection_parser import AtomSelector -from .parse_top import GromacsTopologyParser +from ..utils.atom_selection_parser import AtomSelector +from ..utils.parse_top import GromacsTopologyParser from ..config import * # NOQA from ..logger import generate_logger @@ -40,6 +40,8 @@ def add_subcmd(subparsers): help="Output file prefix (This also will be constant name)", ) + parser.set_defaults(func=run) + def run(args): selector = AtomSelector(args.selection) diff --git a/src/utils/gen_resp.py b/src/build/gen_resp.py similarity index 97% rename 
from src/utils/gen_resp.py rename to src/build/gen_resp.py index 720dc1d..eeda9b3 100644 --- a/src/utils/gen_resp.py +++ b/src/build/gen_resp.py @@ -52,6 +52,8 @@ def add_subcmd(subparsers): "--no-opt", action="store_true", help="Do not optimize structure" ) + parser.set_defaults(func=run) + def run(args): filetype = Path(args.structure).suffix[1:] @@ -92,7 +94,7 @@ def run(args): LOGGER.info("structure_optimization.log generated") # single point - cmd = f"obabel -i {GAUSSIAN_CMD} structure_optimization.log -o gjf > single_point_calculation.gjf" # NOQA + cmd = f"obabel -i gout structure_optimization.log -o gjf > single_point_calculation.gjf" # NOQA subprocess.run(cmd, shell=True, check=True) else: diff --git a/src/utils/gen_temperatures.py b/src/build/gen_temperatures.py similarity index 99% rename from src/utils/gen_temperatures.py rename to src/build/gen_temperatures.py index 314a1b9..e2621ce 100644 --- a/src/utils/gen_temperatures.py +++ b/src/build/gen_temperatures.py @@ -80,6 +80,8 @@ def add_subcmd(subparsers): help="Simulation type: 0=NPT, 1=NVT (Note: only NPT is fully supported)", ) + parser.set_defaults(func=run) + def calc_mu(nw, np_val, temp, fener): return (A0 + A1 * temp) * nw + (B0 + B1 * temp) * np_val - temp * fener diff --git a/src/utils/modeling_cf.py b/src/build/modeling_cf.py similarity index 97% rename from src/utils/modeling_cf.py rename to src/build/modeling_cf.py index 117ff2b..371b642 100644 --- a/src/utils/modeling_cf.py +++ b/src/build/modeling_cf.py @@ -24,6 +24,8 @@ def add_subcmd(subparsers): "-s", "--sequence", required=True, type=str, help="amino acid sequence" ) + parser.set_defaults(func=run) + def run(args): # colabfold command check diff --git a/src/utils/mutate.py b/src/build/mutate.py similarity index 100% rename from src/utils/mutate.py rename to src/build/mutate.py diff --git a/src/utils/mv_crds_mol2.py b/src/build/mv_crds_mol2.py similarity index 98% rename from src/utils/mv_crds_mol2.py rename to src/build/mv_crds_mol2.py 
index b1f3909..62de1cf 100644 --- a/src/utils/mv_crds_mol2.py +++ b/src/build/mv_crds_mol2.py @@ -40,6 +40,8 @@ def add_subcmd(subparsers): help="Output mol2 file", ) + parser.set_defaults(func=run) + def run(args): atomname2crds = {} diff --git a/src/build/place_solvent.py b/src/build/place_solvent.py new file mode 100644 index 0000000..5fbbfac --- /dev/null +++ b/src/build/place_solvent.py @@ -0,0 +1,601 @@ +# References: +# - Tutorial 34: Solvation with 3D-RISM +# https://ambermd.org/tutorials/advanced/tutorial34/index.html +# - Tutorial 40: 1D-RISM and 3D-RISM +# https://ambermd.org/tutorials/advanced/tutorial40/index.php + +import argparse +import os +import shutil +import subprocess +import tempfile +from pathlib import Path + +import numpy as np + +from ..logger import generate_logger + +LOGGER = generate_logger(__name__) + +# sander 用 3D-RISM 入力テンプレート +SANDER_RISM_INPUT_TEMPLATE = """\ +&cntrl + ntx=1, nstlim=0, irism=1, +/ +&rism + closure='{closure}', + grdspc={grdspc},{grdspc},{grdspc}, + tolerance={tolerance}, + buffer={buffer}, + solvcut={solvcut}, + mdiis_del=0.7, + mdiis_nvec=5, + maxstep=10000, + npropagate=5, + verbose=2, + apply_rism_force=0, + volfmt='dx', + ntwrism=1, +/ +""" + + +def add_subcmd(subparsers): + """ + mdtbx place_solvent -p leap.parm7 -x leap.rst7 -o solvent_placed.pdb + """ + parser = subparsers.add_parser( + "place_solvent", + help="Place solvent molecules using 3D-RISM", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "-p", + "--prmtop", + required=True, + type=str, + help="AMBER topology file (.parm7)", + ) + + parser.add_argument( + "-x", + "--coord", + required=True, + type=str, + help="AMBER coordinate file (.rst7)", + ) + + parser.add_argument( + "-o", + "--output", + default="solvent_placed.pdb", + type=str, + help="Output PDB file", + ) + + parser.add_argument( + "--xvv", + default=None, + type=str, + help="Pre-computed solvent susceptibility file (.xvv). 
" + "If not provided, rism1d is run to generate one.", + ) + + parser.add_argument( + "--solvent", + default="water", + choices=["water"], + help="Solvent type", + ) + + parser.add_argument( + "--solvent-model", + default="SPC", + type=str, + help="Solvent model for 1D-RISM (.mdl file stem in $AMBERHOME/dat/rism1d/mdl/)", + ) + + parser.add_argument( + "--temperature", + default=300.0, + type=float, + help="Temperature [K]", + ) + + parser.add_argument( + "--closure", + default="kh", + choices=["kh", "hnc", "pse2", "pse3"], + help="Closure approximation", + ) + + parser.add_argument( + "--grdspc", + default=0.5, + type=float, + help="Grid spacing [Å]", + ) + + parser.add_argument( + "--tolerance", + default=1e-5, + type=float, + help="Convergence tolerance for 3D-RISM", + ) + + parser.add_argument( + "--buffer", + default=14.0, + type=float, + help="Buffer distance around solute [Å]", + ) + + parser.add_argument( + "--solvcut", + default=14.0, + type=float, + help="Solvent cutoff distance [Å]", + ) + + parser.add_argument( + "--threshold", + default=1.5, + type=float, + help="Density threshold for peak extraction (g(r) value)", + ) + + parser.add_argument( + "--exclusion-radius", + default=2.6, + type=float, + help="Minimum distance between placed solvent sites [Å]. " + "Approximate diameter of a water molecule.", + ) + + parser.add_argument( + "--max-sites", + default=None, + type=int, + help="Maximum number of solvent sites to place. 
" + "If not set, all sites above threshold are placed.", + ) + + parser.add_argument( + "--use-sander", + action="store_true", + help="Use sander interface instead of rism3d.snglpnt", + ) + + parser.add_argument( + "--keepfiles", + action="store_true", + help="Keep intermediate files", + ) + + parser.set_defaults(func=run) + + +# --------------------------------------------------------------------------- +# 1D-RISM: xvv ファイル生成 +# --------------------------------------------------------------------------- + + +def _run_rism1d(solvent_model, temperature, workdir): + """1D-RISM を実行して xvv ファイルを生成する。 + + $AMBERHOME/dat/rism1d/mdl/.mdl が必要。 + Returns: + xvv_path: 生成された xvv ファイルのパス + """ + amberhome = os.environ.get("AMBERHOME", "") + mdl_path = os.path.join(amberhome, "dat", "rism1d", "mdl", f"{solvent_model}.mdl") + if not os.path.exists(mdl_path): + raise FileNotFoundError( + f"Solvent model file not found: {mdl_path}. " + f"Check $AMBERHOME and --solvent-model." + ) + + xvv_stem = f"{solvent_model}_{temperature:.2f}" + xvv_path = os.path.join(workdir, f"{xvv_stem}.xvv") + + # rism1d 入力ファイル + inp_content = ( + f"&PARAMETERS\n" + f" THEORY='DRISM', CLOSURE='KH',\n" + f" NR=16384, DR=0.025,\n" + f" OUTLST='xvv',\n" + f" NIS=20, DESSION=0.5, MDIIS_NVEC=20, MDIIS_DEL=0.3,\n" + f" TOLERANCE=1.0e-12,\n" + f" SMEAR=1, APTS=0.2,\n" + f" TEMPER={temperature},\n" + f"/\n" + f" {solvent_model}\n" + ) + inp_path = os.path.join(workdir, f"{xvv_stem}.inp") + with open(inp_path, "w") as f: + f.write(inp_content) + + rism1d_cmd = f"rism1d {xvv_stem} > {xvv_stem}.out 2>&1" + LOGGER.info(f"Running 1D-RISM to generate xvv file ({solvent_model}) ...") + subprocess.run(rism1d_cmd, shell=True, check=True, cwd=workdir) + + if not os.path.exists(xvv_path): + raise RuntimeError( + f"1D-RISM did not produce {xvv_path}. Check the output for errors." 
+ ) + + LOGGER.info(f"Generated xvv file: {xvv_path}") + return xvv_path + + +# --------------------------------------------------------------------------- +# 3D-RISM 実行 +# --------------------------------------------------------------------------- + + +def _run_rism3d_snglpnt(prmtop, coord, xvv_path, args, workdir): + """rism3d.snglpnt コマンドラインインターフェースで 3D-RISM を実行する。""" + + prmtop_abs = os.path.abspath(prmtop) + coord_abs = os.path.abspath(coord) + xvv_abs = os.path.abspath(xvv_path) + prmtop_stem = Path(prmtop).stem + + guv_prefix = os.path.join(workdir, prmtop_stem) + + rism_cmd = [ + "rism3d.snglpnt", + "--prmtop", + prmtop_abs, + "--rst", + coord_abs, + "--xvv", + xvv_abs, + "--closure", + args.closure, + "--grdspc", + f"{args.grdspc},{args.grdspc},{args.grdspc}", + "--buffer", + str(args.buffer), + "--solvcut", + str(args.solvcut), + "--tolerance", + str(args.tolerance), + "--verbose", + "2", + "--ntwrism", + "1", + "--guv", + guv_prefix, + ] + + LOGGER.info("Running rism3d.snglpnt ...") + LOGGER.info(f" Command: {' '.join(rism_cmd)}") + subprocess.run(rism_cmd, check=True, cwd=workdir) + + +def _run_sander_rism(prmtop, coord, args, workdir): + """sander インターフェースで 3D-RISM を実行する。""" + + prmtop_abs = os.path.abspath(prmtop) + coord_abs = os.path.abspath(coord) + + mdin_content = SANDER_RISM_INPUT_TEMPLATE.format( + closure=args.closure, + grdspc=args.grdspc, + tolerance=args.tolerance, + buffer=args.buffer, + solvcut=args.solvcut, + ) + mdin_path = os.path.join(workdir, "mdin.rism") + with open(mdin_path, "w") as f: + f.write(mdin_content) + + sander_cmd = [ + "sander", + "-O", + "-i", + mdin_path, + "-o", + os.path.join(workdir, "mdout"), + "-p", + prmtop_abs, + "-c", + coord_abs, + ] + + LOGGER.info("Running sander with 3D-RISM ...") + subprocess.run(sander_cmd, check=True, cwd=workdir) + + +# --------------------------------------------------------------------------- +# DX ファイル読み込み +# 
--------------------------------------------------------------------------- + + +def _parse_dx(dx_path): + """OpenDX 形式の密度グリッドを ndarray として読み込む。 + + Returns: + data: (nx, ny, nz) ndarray — g(r) 分布関数値 + origin: (3,) ndarray — グリッド原点座標 [Å] + delta: (3, 3) ndarray — 各軸のグリッド刻み幅ベクトル (行ごと) + """ + origin = None + delta = [] + counts = None + data_values = [] + + with open(dx_path) as f: + for line in f: + line = line.strip() + if line.startswith("#") or not line: + continue + if line.startswith("object 1"): + parts = line.split() + counts = tuple(int(x) for x in parts[-3:]) + elif line.startswith("origin"): + parts = line.split() + origin = np.array([float(x) for x in parts[1:4]]) + elif line.startswith("delta"): + parts = line.split() + delta.append([float(x) for x in parts[1:4]]) + elif line.startswith("object") or line.startswith("attribute"): + continue + elif line.startswith("component"): + continue + else: + # データ行 + data_values.extend(float(v) for v in line.split()) + + if counts is None: + raise ValueError(f"Could not parse grid dimensions from {dx_path}") + if origin is None: + raise ValueError(f"Could not parse origin from {dx_path}") + if len(delta) != 3: + raise ValueError(f"Expected 3 delta vectors, got {len(delta)}") + + data = np.array(data_values).reshape(counts) + delta = np.array(delta) + + return data, origin, delta + + +def _grid_to_cartesian(indices, origin, delta): + """グリッドインデックス (N, 3) を直交座標 (N, 3) に変換する。 + + OpenDX の delta 行列は行ごとに各軸方向の刻み幅ベクトルを持つ。 + 座標 = origin + i * delta[0] + j * delta[1] + k * delta[2] + = origin + indices @ delta + """ + return origin + indices @ delta + + +# --------------------------------------------------------------------------- +# Placevent 風グリーディピーク抽出 +# --------------------------------------------------------------------------- + + +def _extract_peaks_greedy( + data, origin, delta, threshold, exclusion_radius, max_sites=None +): + """Placevent アルゴリズムに基づくグリーディな溶媒サイト抽出。 + + 1. g(r) > threshold のグリッド点をすべて候補とする + 2. 
g(r) 値の降順にソートする + 3. 最も高い g(r) の点を溶媒サイトとして採用し、 + exclusion_radius 以内の他の候補をすべて除外する + 4. 残りの候補で最も高い g(r) の点を次のサイトとする + 5. max_sites に達するか候補がなくなるまで繰り返す + + Returns: + coords: (M, 3) ndarray — 溶媒サイト座標 [Å] + gvalues: (M,) ndarray — 各サイトの g(r) 値 + """ + # 閾値を超えるグリッド点のインデックスと値を取得 + indices = np.argwhere(data > threshold) + if len(indices) == 0: + LOGGER.warning( + f"No grid points exceed threshold g(r) > {threshold}. " + "Try lowering --threshold." + ) + return np.empty((0, 3)), np.empty(0) + + values = data[indices[:, 0], indices[:, 1], indices[:, 2]] + + # g(r) 降順でソート + order = np.argsort(-values) + indices = indices[order] + values = values[order] + + # 座標に変換 + all_coords = _grid_to_cartesian(indices.astype(float), origin, delta) + + # グリーディ選択 + placed_coords = [] + placed_gvalues = [] + used = np.zeros(len(all_coords), dtype=bool) + excl_sq = exclusion_radius**2 + + for i in range(len(all_coords)): + if used[i]: + continue + + coord_i = all_coords[i] + placed_coords.append(coord_i) + placed_gvalues.append(values[i]) + + if max_sites is not None and len(placed_coords) >= max_sites: + break + + # この点から exclusion_radius 以内の候補を除外 + remaining = np.where(~used)[0] + remaining = remaining[remaining > i] + if len(remaining) > 0: + diff = all_coords[remaining] - coord_i + dist_sq = np.sum(diff**2, axis=1) + too_close = remaining[dist_sq < excl_sq] + used[too_close] = True + + coords = np.array(placed_coords) + gvalues = np.array(placed_gvalues) + + LOGGER.info( + f"Extracted {len(coords)} solvent sites " + f"(threshold={threshold}, exclusion_radius={exclusion_radius} Å)" + ) + return coords, gvalues + + +# --------------------------------------------------------------------------- +# PDB 出力 +# --------------------------------------------------------------------------- + + +def _write_pdb(coords, gvalues, solvent, output_path): + """溶媒サイト座標を PDB 形式で書き出す。 + + occupancy に g(r) の初期値を、B-factor に配置順の g(r) を記録する。 + 原子番号・残基番号が PDB フォーマットの上限を超える場合は + モジュロで折り返す。 + """ + atom_name = "O" 
if solvent == "water" else "X" + resname = "WAT" if solvent == "water" else "SOL" + + with open(output_path, "w") as f: + f.write( + "REMARK Generated by place_solvent (3D-RISM)\n" + f"REMARK {len(coords)} solvent sites placed\n" + ) + for i, (coord, gval) in enumerate(zip(coords, gvalues), start=1): + x, y, z = coord + serial = i % 100000 # ATOM serial は 5 桁まで + resseq = i % 10000 # 残基番号は 4 桁まで + f.write( + f"HETATM{serial:5d} {atom_name:<3s} {resname} A" + f"{resseq:4d} " + f"{x:8.3f}{y:8.3f}{z:8.3f}" + f"{1.00:6.2f}{gval:6.2f}\n" + ) + f.write("END\n") + + LOGGER.info(f"{len(coords)} solvent sites written to {output_path}") + + +# --------------------------------------------------------------------------- +# DX ファイル検索 +# --------------------------------------------------------------------------- + + +def _find_oxygen_dx(workdir, prmtop_stem, closure): + """3D-RISM が出力した酸素密度の .dx ファイルを見つける。 + + rism3d.snglpnt --guv prefix の場合: + prefix.O.1.dx (guv 出力) + sander の場合: + ..O.0.dx (guv 出力, 0-indexed) + + いずれも酸素サイト "O" を含むファイルを探す。 + """ + workpath = Path(workdir) + + # guv 出力パターン(rism3d.snglpnt) + # prefix.O.1.dx が典型的 + candidates = sorted(workpath.glob(f"{prmtop_stem}*O*.dx")) + + if not candidates: + # sander 出力パターン + candidates = sorted(workpath.glob(f"*{closure}*O*.dx")) + + if not candidates: + # フォールバック: 任意の酸素を含む dx + candidates = sorted(workpath.glob("*O*.dx")) + + if not candidates: + # 最終手段: 全 dx ファイル + candidates = sorted(workpath.glob("*.dx")) + + if not candidates: + raise FileNotFoundError( + f"No .dx output files found in {workdir}. " + "3D-RISM calculation may have failed." 
+ ) + + # 水の酸素に最も適合するファイルを選択 + # "guv" や "g" を含むもの (分布関数) を優先し、 + # "cuv" (直接相関) や "huv" (間接相関) を避ける + for cand in candidates: + name = cand.name.lower() + if "cuv" in name or "huv" in name or "uuv" in name: + continue + return cand + + # すべて除外された場合は最初のものを使う + return candidates[0] + + +# --------------------------------------------------------------------------- +# メイン +# --------------------------------------------------------------------------- + + +def run(args): + prmtop_stem = Path(args.prmtop).stem + + # 作業ディレクトリ + workdir = tempfile.mkdtemp(prefix="rism3d_") + LOGGER.info(f"Working directory: {workdir}") + + try: + # ---- 1. xvv ファイル準備 ---- + if args.xvv is not None: + xvv_path = os.path.abspath(args.xvv) + if not os.path.exists(xvv_path): + LOGGER.error(f"xvv file not found: {xvv_path}") + return + LOGGER.info(f"Using provided xvv file: {xvv_path}") + else: + xvv_path = _run_rism1d(args.solvent_model, args.temperature, workdir) + + # ---- 2. 3D-RISM 実行 ---- + if args.use_sander: + _run_sander_rism(args.prmtop, args.coord, args, workdir) + else: + _run_rism3d_snglpnt(args.prmtop, args.coord, xvv_path, args, workdir) + + # ---- 3. 酸素密度 DX ファイル読み込み ---- + dx_path = _find_oxygen_dx(workdir, prmtop_stem, args.closure) + LOGGER.info(f"Reading density from {dx_path}") + + data, origin, delta = _parse_dx(str(dx_path)) + LOGGER.info( + f"Grid dimensions: {data.shape}, " + f"origin: {origin}, " + f"spacing: {np.diag(delta)}" + ) + + # ---- 4. ピーク抽出 (Placevent 風グリーディ) ---- + coords, gvalues = _extract_peaks_greedy( + data, + origin, + delta, + threshold=args.threshold, + exclusion_radius=args.exclusion_radius, + max_sites=args.max_sites, + ) + + if len(coords) == 0: + LOGGER.error("No solvent sites found.") + return + + # ---- 5. 
PDB 出力 ---- + output_path = os.path.abspath(args.output) + _write_pdb(coords, gvalues, args.solvent, output_path) + + finally: + if not args.keepfiles: + shutil.rmtree(workdir, ignore_errors=True) + LOGGER.info("Intermediate files removed") + else: + LOGGER.info(f"Intermediate files kept in {workdir}") diff --git a/src/cli.py b/src/cli.py index 41dd182..345be9f 100644 --- a/src/cli.py +++ b/src/cli.py @@ -10,36 +10,41 @@ from .utils import mod_mdp from .utils import shell_hook from .utils import cmd -from .utils import mv_crds_mol2 -from .utils import print_perf from .utils import show_mdtraj from .utils import show_npy +from .utils import partial_tempering # build utils -from .utils import addace -from .utils import addnme -from .utils import add_ndx -from .utils import calc_ion_conc -from .utils import centering_gro -from .utils import find_bond -from .utils import gen_am1bcc -from .utils import gen_resp -from .utils import gen_modres_am1bcc -from .utils import gen_modres_resp -from .utils import gen_posres -from .utils import gen_distres -from .utils import modeling_cf -from .utils import amb2gro -from .utils import build_solution -from .utils import partial_tempering -from .utils import gen_temperatures +from .build import addace +from .build import addh +from .build import addnme +from .build import add_ndx +from .build import mv_crds_mol2 +from .build import calc_ion_conc +from .build import centering_gro +from .build import find_bond +from .build import gen_am1bcc +from .build import gen_resp +from .build import gen_modres_am1bcc +from .build import gen_modres_resp +from .build import gen_posres +from .build import gen_distres +from .build import modeling_cf +from .build import amb2gro +from .build import build_solution +from .build import build_vacuum +from .build import place_solvent +from .build import gen_temperatures + +# trajectory utils +from .trajectory import trjcat +from .trajectory import fit +from .trajectory import pacs_trjcat +from .trajectory 
import print_perf # analysis utils -from .utils import trjcat -from .utils import fit -from .utils import pacs_trjcat -from .utils import extract_ave_str -from .utils import extract_str +from .analysis import extract_ave_str +from .analysis import extract_str from .cv import comdist from .cv import comvec @@ -76,6 +81,7 @@ def cli() -> None: subparsers = parser.add_subparsers() addace.add_subcmd(subparsers) + addh.add_subcmd(subparsers) addnme.add_subcmd(subparsers) add_ndx.add_subcmd(subparsers) mv_crds_mol2.add_subcmd(subparsers) @@ -108,6 +114,8 @@ def cli() -> None: # build_membrane.add_subcmd(subparsers) build_solution.add_subcmd(subparsers) + build_vacuum.add_subcmd(subparsers) + place_solvent.add_subcmd(subparsers) comdist.add_subcmd(subparsers) comvec.add_subcmd(subparsers) @@ -120,132 +128,10 @@ def cli() -> None: args = parser.parse_args() - if len(sys.argv) == 1: + if not hasattr(args, "func"): LOGGER.error(f"use {sys.argv[0]} --help") sys.exit(1) - LOGGER.info(f"{sys.argv[1]} called") - - if sys.argv[1] == "rmfile": - rmfile.run(args) - - elif sys.argv[1] == "addace": - addace.run(args) - - elif sys.argv[1] == "addnme": - addnme.run(args) - - elif sys.argv[1] == "mv_crds_mol2": - mv_crds_mol2.run(args) - - elif sys.argv[1] == "trjcat": - trjcat.run(args) - - elif sys.argv[1] == "fit": - fit.run(args) - - elif sys.argv[1] == "pacs_trjcat": - pacs_trjcat.run(args) - - elif sys.argv[1] == "convert": - convert.run(args) - - elif sys.argv[1] == "centering_gro": - centering_gro.run(args) - - elif sys.argv[1] == "amb2gro": - amb2gro.run(args) - - elif sys.argv[1] == "find_bond": - find_bond.run(args) - - elif sys.argv[1] == "mod_mdp": - mod_mdp.run(args) - - elif sys.argv[1] == "gen_am1bcc": - gen_am1bcc.run(args) - - elif sys.argv[1] == "gen_resp": - gen_resp.run(args) - - elif sys.argv[1] == "gen_modres_am1bcc": - gen_modres_am1bcc.run(args) - - elif sys.argv[1] == "gen_modres_resp": - gen_modres_resp.run(args) - - elif sys.argv[1] == "gen_posres": - 
gen_posres.run(args) - - elif sys.argv[1] == "gen_distres": - gen_distres.run(args) - - elif sys.argv[1] == "modeling_cf": - modeling_cf.run(args) - - elif sys.argv[1] == "add_ndx": - add_ndx.run(args) - - elif sys.argv[1] == "extract_ave_str": - extract_ave_str.run(args) - - elif sys.argv[1] == "extract_str": - extract_str.run(args) - - elif sys.argv[1] == "show_mdtraj": - show_mdtraj.run(args) - - elif sys.argv[1] == "show_npy": - show_npy.run(args) - - elif sys.argv[1] == "print_perf": - print_perf.run(args) - - elif sys.argv[1] == "shell_hook": - shell_hook.run(args) - - elif sys.argv[1] == "cmd": - cmd.run(args) - - elif sys.argv[1] == "partial_tempering": - partial_tempering.run(args) - - elif sys.argv[1] == "gen_temperatures": - gen_temperatures.run(args) - - elif sys.argv[1] == "calc_ion_conc": - calc_ion_conc.run(args) - - # elif sys.argv[1] == "build_membrane": - # build_membrane.run(args) - - elif sys.argv[1] == "build_solution": - build_solution.run(args) - - elif sys.argv[1] == "comdist": - comdist.run(args) - - elif sys.argv[1] == "comvec": - comvec.run(args) - - elif sys.argv[1] == "mindist": - mindist.run(args) - - elif sys.argv[1] == "rmsd": - rmsd.run(args) - - elif sys.argv[1] == "rmsf": - rmsf.run(args) - - elif sys.argv[1] == "pca": - pca.run(args) - - elif sys.argv[1] == "xyz": - xyz.run(args) - - elif sys.argv[1] == "densmap": - densmap.run(args) - - else: - print(f"Unknown command: {sys.argv[1]}") + LOGGER.info(f"{sys.argv[1]} called") + args.func(args) LOGGER.info(f"{sys.argv[1]} finished") diff --git a/src/cv/comdist.py b/src/cv/comdist.py index 569d899..398f295 100644 --- a/src/cv/comdist.py +++ b/src/cv/comdist.py @@ -57,6 +57,8 @@ def add_subcmd(subparsers): help="Index file (.ndx)", ) + parser.set_defaults(func=run) + def run(args): if args.gmx: diff --git a/src/cv/comvec.py b/src/cv/comvec.py index 9494c42..75c6921 100644 --- a/src/cv/comvec.py +++ b/src/cv/comvec.py @@ -54,6 +54,8 @@ def add_subcmd(subparsers): help="Index file 
(.ndx)", ) + parser.set_defaults(func=run) + def run(args): if args.gmx: diff --git a/src/cv/densmap.py b/src/cv/densmap.py index fe01c90..941a316 100644 --- a/src/cv/densmap.py +++ b/src/cv/densmap.py @@ -7,14 +7,16 @@ LOGGER = generate_logger(__name__) +_AXIS_MAP = {"xy": (0, 1), "xz": (0, 2), "yz": (1, 2)} + def add_subcmd(subparsers): """ - mdtbx densmap --topology structure.pdb --trajectory trajectory.xtc --selection "resid 1 to 10" -o desmap.npy + mdtbx densmap --topology structure.pdb --trajectory trajectory.xtc --selection "resid 1 to 10" -o densmap.npy """ parser = subparsers.add_parser( "densmap", - help="Extract densmap", + help="Extract 2D density map", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) @@ -36,9 +38,18 @@ def add_subcmd(subparsers): help="Selection (MDtraj Atom selection language)", ) parser.add_argument( - "-o", "--output", type=str, default="comdist.npy", help="Output file (.npy)" + "-o", "--output", type=str, default="densmap.npy", help="Output file (.npy)" + ) + parser.add_argument( + "--bins", type=int, default=100, help="Number of bins along each axis" + ) + parser.add_argument( + "--axis", + type=str, + default="xy", + choices=["xy", "xz", "yz"], + help="Projection plane (MDtraj path only)", ) - parser.add_argument( "--gmx", action="store_true", help="Use Gromacs instead of MDtraj" ) @@ -49,6 +60,8 @@ def add_subcmd(subparsers): help="Index file (.ndx)", ) + parser.set_defaults(func=run) + def run(args): if args.gmx: @@ -64,8 +77,27 @@ def run(args): # dens = a[1:, 1] # c = ax.pcolormesh(X, Y, dens) else: - # trj = md.load(args.trajectory, top=args.topology) - raise NotImplementedError + import mdtraj as md + + trj = md.load(args.trajectory, top=args.topology) + atom_indices = trj.topology.select(args.selection) + if len(atom_indices) == 0: + LOGGER.error(f"No atoms selected by: {args.selection}") + return + + xyz = trj.xyz[:, atom_indices, :] # (n_frames, n_atoms, 3) [nm] + ax0, ax1 = _AXIS_MAP[args.axis] + pos0 = xyz[:, 
:, ax0].ravel() + pos1 = xyz[:, :, ax1].ravel() + + counts, edges0, edges1 = np.histogram2d(pos0, pos1, bins=args.bins) + # Save as dict-like structured array: [counts, edges0, edges1] + # Use object array to preserve shapes + densmap = np.empty(3, dtype=object) + densmap[0] = counts + densmap[1] = edges0 + densmap[2] = edges1 + LOGGER.info(f"Density map shape: {counts.shape}, axis: {args.axis}") np.save(args.output, densmap) LOGGER.info(f"Saved to {args.output}") diff --git a/src/cv/mindist.py b/src/cv/mindist.py index d6f0334..c43fb4d 100644 --- a/src/cv/mindist.py +++ b/src/cv/mindist.py @@ -47,6 +47,8 @@ def add_subcmd(subparsers): "-o", "--output", type=str, default="mindist.npy", help="Output file (.npy)" ) + parser.set_defaults(func=run) + def run(args): trj = md.load(args.trajectory, top=args.topology) diff --git a/src/cv/pca.py b/src/cv/pca.py index c4d4862..49c5db6 100644 --- a/src/cv/pca.py +++ b/src/cv/pca.py @@ -84,6 +84,8 @@ def add_subcmd(subparsers): "-o", "--output", type=str, default="pca.npy", help="Output file (.npy)" ) + parser.set_defaults(func=run) + def run(args): if args.gmx: @@ -127,8 +129,8 @@ def run(args): trj.superpose( ref, 0, - atom_indices=atom_indices_fit_ref, - ref_atom_indices=atom_indices_fit_trj, + atom_indices=atom_indices_fit_trj, + ref_atom_indices=atom_indices_fit_ref, ) atom_indices_cal = trj.top.select(args.selection_cal_trj) pca_model = PCA(n_components=args.n_components) diff --git a/src/cv/rmsd.py b/src/cv/rmsd.py index bd07abe..92843ea 100644 --- a/src/cv/rmsd.py +++ b/src/cv/rmsd.py @@ -67,6 +67,8 @@ def add_subcmd(subparsers): "-o", "--output", type=str, default="rmsd.npy", help="Output file (.npy)" ) + parser.set_defaults(func=run) + def run(args): trj = md.load(args.trajectory, top=args.topology) diff --git a/src/cv/rmsf.py b/src/cv/rmsf.py index 9e4d2cb..8d48c4b 100644 --- a/src/cv/rmsf.py +++ b/src/cv/rmsf.py @@ -51,6 +51,8 @@ def add_subcmd(subparsers): "-o", "--output", type=str, default="rmsf.npy", 
help="Output file (.npy)" ) + parser.set_defaults(func=run) + def run(args): if args.gmx: diff --git a/src/cv/xyz.py b/src/cv/xyz.py index ceac6a7..c6f8037 100644 --- a/src/cv/xyz.py +++ b/src/cv/xyz.py @@ -39,6 +39,8 @@ def add_subcmd(subparsers): "-o", "--output", type=str, default="comdist.npy", help="Output file (.npy)" ) + parser.set_defaults(func=run) + def run(args): trj = md.load(args.trajectory, top=args.topology) diff --git a/src/trajectory/__init__.py b/src/trajectory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/fit.py b/src/trajectory/fit.py similarity index 98% rename from src/utils/fit.py rename to src/trajectory/fit.py index 71bbd8a..af5ae4e 100644 --- a/src/utils/fit.py +++ b/src/trajectory/fit.py @@ -52,6 +52,8 @@ def add_subcmd(subparsers): "-idx", "--index", default="index.ndx", type=str, help="Index file" ) + parser.set_defaults(func=run) + def run(args): if args.gmx: diff --git a/src/utils/opt_perf.py b/src/trajectory/opt_perf.py similarity index 100% rename from src/utils/opt_perf.py rename to src/trajectory/opt_perf.py diff --git a/src/utils/pacs_trjcat.py b/src/trajectory/pacs_trjcat.py similarity index 99% rename from src/utils/pacs_trjcat.py rename to src/trajectory/pacs_trjcat.py index 5364704..513f156 100644 --- a/src/utils/pacs_trjcat.py +++ b/src/trajectory/pacs_trjcat.py @@ -85,6 +85,8 @@ def add_subcmd(subparsers): help="Keep cycle trajectory (e.g. 
trial001/cycle000/prd_all.xtc)", ) + parser.set_defaults(func=run) + def check_cycle(args): cmd = f"ls {args.trial_dir} | grep cycle | wc -l" diff --git a/src/utils/print_perf.py b/src/trajectory/print_perf.py similarity index 99% rename from src/utils/print_perf.py rename to src/trajectory/print_perf.py index 403f58b..70d802a 100644 --- a/src/utils/print_perf.py +++ b/src/trajectory/print_perf.py @@ -34,6 +34,8 @@ def add_subcmd(subparsers): type=str, ) + parser.set_defaults(func=run) + def parse_log_file(log_path): data = { diff --git a/src/utils/trjcat.py b/src/trajectory/trjcat.py similarity index 99% rename from src/utils/trjcat.py rename to src/trajectory/trjcat.py index 25c55c1..867272d 100644 --- a/src/utils/trjcat.py +++ b/src/trajectory/trjcat.py @@ -58,6 +58,8 @@ def add_subcmd(subparsers): "--no-resnr", action="store_true", help="Do not run gmx editconf -resnr 1" ) + parser.set_defaults(func=run) + def run(args): # ref: https://zenn.dev/kh01734/articles/012380a58949d1 diff --git a/src/utils/addh.py b/src/utils/addh.py deleted file mode 100644 index 5b924b0..0000000 --- a/src/utils/addh.py +++ /dev/null @@ -1 +0,0 @@ -# reduce or h_add in pymol diff --git a/src/utils/cmd.py b/src/utils/cmd.py index 4201cbd..18cbaab 100644 --- a/src/utils/cmd.py +++ b/src/utils/cmd.py @@ -25,6 +25,8 @@ def add_subcmd(subparsers): help="The command and its arguments to execute.", ) + parser.set_defaults(func=run) + def run(args): if not args.command: diff --git a/src/utils/convert.py b/src/utils/convert.py index 9659717..104e2ea 100644 --- a/src/utils/convert.py +++ b/src/utils/convert.py @@ -37,6 +37,8 @@ def add_subcmd(subparsers): choices=["pymol", "mdtraj"], ) + parser.set_defaults(func=run) + def run(args): if args.type == "pymol": diff --git a/src/utils/mod_mdp.py b/src/utils/mod_mdp.py index 63a202d..eb8aa78 100644 --- a/src/utils/mod_mdp.py +++ b/src/utils/mod_mdp.py @@ -35,6 +35,8 @@ def add_subcmd(subparsers): "-lj", "--ljust", type=int, default=23, help="Ljust 
for new variable line" ) + parser.set_defaults(func=run) + def run(args): for mdp in Path(args.path).glob("*.mdp"): diff --git a/src/utils/partial_tempering.py b/src/utils/partial_tempering.py index eacf5b2..13a4d70 100644 --- a/src/utils/partial_tempering.py +++ b/src/utils/partial_tempering.py @@ -32,6 +32,8 @@ def add_subcmd(subparsers): "-o", "--output", type=str, default="output.top", help="Output struxture file" ) + parser.set_defaults(func=run) + def run(args): selector = AtomSelector(args.selection) diff --git a/src/utils/rmfile.py b/src/utils/rmfile.py index e57c99f..fb1898b 100644 --- a/src/utils/rmfile.py +++ b/src/utils/rmfile.py @@ -15,6 +15,8 @@ def add_subcmd(subparsers): ) parser.add_argument("--path", type=str, help="Path to the directory", default=".") + parser.set_defaults(func=run) + def run(args): for suffix in ["#*#", "*cpt", "mdout.mdp"]: diff --git a/src/utils/shell_hook.py b/src/utils/shell_hook.py index a64af04..a83d35e 100644 --- a/src/utils/shell_hook.py +++ b/src/utils/shell_hook.py @@ -11,12 +11,14 @@ def add_subcmd(subparsers): """ mdtbx shell_hook """ - _parser = subparsers.add_parser( + parser = subparsers.add_parser( "shell_hook", help="Generate shell hook", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.set_defaults(func=run) + def run(args): hook = f""" diff --git a/src/utils/show_mdtraj.py b/src/utils/show_mdtraj.py index ab42532..b74e508 100644 --- a/src/utils/show_mdtraj.py +++ b/src/utils/show_mdtraj.py @@ -19,6 +19,8 @@ def add_subcmd(subparsers): parser.add_argument("topology", type=str, help="Topology file (.gro, .pdb)") + parser.set_defaults(func=run) + def run(args): top = md.load(args.topology) diff --git a/src/utils/show_npy.py b/src/utils/show_npy.py index bfc97d8..611e6d6 100644 --- a/src/utils/show_npy.py +++ b/src/utils/show_npy.py @@ -19,6 +19,8 @@ def add_subcmd(subparsers): parser.add_argument("npy", type=str, help="npy file") + parser.set_defaults(func=run) + def run(args): npy = 
np.load(args.npy) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1f29cf8 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,88 @@ +""" +pytest 共有 fixture と PyMOL モック設定 + +NOTE: sys.modules へのモック注入は他のどの src インポートよりも前に + 実行される必要があるため、このファイルの先頭で行う。 +""" + +import sys +from unittest.mock import MagicMock + +# config.py が `import pymol_plugins` を実行し、 +# find_bond.py / convert.py が `from pymol import cmd` を実行するため、 +# テスト環境では PyMOL GUI なしで動作できるようモックする。 +sys.modules["pymol_plugins"] = MagicMock() +sys.modules["pymol"] = MagicMock() +sys.modules["pymol.cmd"] = MagicMock() + +import pathlib # noqa: E402 + +import numpy as np # noqa: E402 +import pytest # noqa: E402 + +FIXTURES_DIR = pathlib.Path(__file__).parent / "fixtures" + + +@pytest.fixture(scope="session") +def fixtures_dir() -> pathlib.Path: + return FIXTURES_DIR + + +@pytest.fixture(scope="session") +def sample_mdp_path() -> pathlib.Path: + return FIXTURES_DIR / "sample.mdp" + + +@pytest.fixture(scope="session") +def sample_top_path() -> pathlib.Path: + return FIXTURES_DIR / "sample.top" + + +@pytest.fixture(scope="session") +def sample_pdb_path() -> pathlib.Path: + return FIXTURES_DIR / "sample.pdb" + + +@pytest.fixture(scope="session") +def trajectory_files(tmp_path_factory): + """ + mdtraj で合成した軌跡を一時ファイルとして用意する。 + ALA + GLY の 2 残基 9 原子、10 フレームの最小系。 + scope="session" で全テストセッション中に 1 回だけ生成する。 + """ + import mdtraj as md + + tmp = tmp_path_factory.mktemp("traj") + + top = md.Topology() + chain = top.add_chain() + + res1 = top.add_residue("ALA", chain) + top.add_atom("N", md.element.nitrogen, res1) + top.add_atom("CA", md.element.carbon, res1) + top.add_atom("CB", md.element.carbon, res1) + top.add_atom("C", md.element.carbon, res1) + top.add_atom("O", md.element.oxygen, res1) + + res2 = top.add_residue("GLY", chain) + top.add_atom("N", md.element.nitrogen, res2) + top.add_atom("CA", 
md.element.carbon, res2) + top.add_atom("C", md.element.carbon, res2) + top.add_atom("O", md.element.oxygen, res2) + + n_frames = 10 + n_atoms = top.n_atoms # 9 + + np.random.seed(42) + # 0〜2 nm の範囲でランダムな座標(各フレームで異なる位置) + xyz = np.random.rand(n_frames, n_atoms, 3) * 2.0 + + traj = md.Trajectory(xyz, top) + + pdb_path = str(tmp / "sample.pdb") + xtc_path = str(tmp / "sample.xtc") + + traj[0].save_pdb(pdb_path) + traj.save_xtc(xtc_path) + + return {"pdb": pdb_path, "xtc": xtc_path, "traj": traj} diff --git a/tests/fixtures/sample.mdp b/tests/fixtures/sample.mdp new file mode 100644 index 0000000..3c6d1c1 --- /dev/null +++ b/tests/fixtures/sample.mdp @@ -0,0 +1,9 @@ +; Minimal MDP file for testing +integrator = md +nsteps = 1000 +dt = 0.002 +nstout = 100 +; coulomb +coulombtype = PME +rcoulomb = 1.0 +rvdw = 1.0 diff --git a/tests/fixtures/sample.pdb b/tests/fixtures/sample.pdb new file mode 100644 index 0000000..b785fd0 --- /dev/null +++ b/tests/fixtures/sample.pdb @@ -0,0 +1,18 @@ +REMARK Minimal PDB for testing (box 50x50x50 Angstrom, 2 WAT molecules) +CRYST1 50.000 50.000 50.000 90.00 90.00 90.00 P 1 1 +ATOM 1 N ALA A 1 10.000 10.000 10.000 1.00 0.00 N +ATOM 2 CA ALA A 1 11.500 10.000 10.000 1.00 0.00 C +ATOM 3 CB ALA A 1 12.000 8.700 10.000 1.00 0.00 C +ATOM 4 C ALA A 1 12.500 10.900 10.000 1.00 0.00 C +ATOM 5 O ALA A 1 12.000 12.000 10.000 1.00 0.00 O +ATOM 6 N GLY A 2 13.800 10.500 10.000 1.00 0.00 N +ATOM 7 CA GLY A 2 14.900 11.400 10.000 1.00 0.00 C +ATOM 8 C GLY A 2 16.200 10.700 10.000 1.00 0.00 C +ATOM 9 O GLY A 2 16.300 9.500 10.000 1.00 0.00 O +HETATM 10 O WAT A 3 20.000 20.000 20.000 1.00 0.00 O +HETATM 11 H1 WAT A 3 20.800 20.600 20.000 1.00 0.00 H +HETATM 12 H2 WAT A 3 19.200 20.600 20.000 1.00 0.00 H +HETATM 13 O WAT A 4 25.000 25.000 25.000 1.00 0.00 O +HETATM 14 H1 WAT A 4 25.800 25.600 25.000 1.00 0.00 H +HETATM 15 H2 WAT A 4 24.200 25.600 25.000 1.00 0.00 H +END diff --git a/tests/fixtures/sample.top b/tests/fixtures/sample.top new file mode 
100644 index 0000000..2ee3c2f --- /dev/null +++ b/tests/fixtures/sample.top @@ -0,0 +1,57 @@ +; Minimal GROMACS topology for testing + +[ defaults ] +; nbfunc comb-rule gen-pairs fudgeLJ fudgeQQ +1 2 yes 0.5 0.8333 + +[ atomtypes ] +; name at.num mass charge ptype sigma epsilon +CT 6 12.011 0.000 A 3.39967e-01 4.57730e-01 +OW 8 15.999 0.000 A 3.15061e-01 6.36386e-01 +HW 1 1.008 0.000 A 0.00000e+00 0.00000e+00 + +[ moleculetype ] +; Name nrexcl +Protein 3 + +[ atoms ] +; nr type resnr residue atom cgnr charge mass + 1 CT 1 ALA N 1 -0.4157 14.007 + 2 CT 1 ALA CA 2 0.0337 12.011 + 3 CT 1 ALA CB 3 -0.1823 12.011 + 4 CT 1 ALA C 4 0.5973 12.011 + 5 CT 1 ALA O 5 -0.5679 15.999 + 6 CT 2 GLY N 6 -0.4157 14.007 + 7 CT 2 GLY CA 7 0.0337 12.011 + 8 CT 2 GLY C 8 0.5973 12.011 + 9 CT 2 GLY O 9 -0.5679 15.999 + +[ bonds ] +; ai aj funct + 1 2 1 + 2 3 1 + 2 4 1 + 4 5 1 + 4 6 1 + 6 7 1 + 7 8 1 + 8 9 1 + +[ moleculetype ] +; Name nrexcl +SOL 1 + +[ atoms ] +; nr type resnr residue atom cgnr charge mass + 1 OW 1 SOL OW 1 -0.834 15.999 + 2 HW 1 SOL HW1 2 0.417 1.008 + 3 HW 1 SOL HW2 3 0.417 1.008 + +[ system ] +; Name +Test System + +[ molecules ] +; Compound #mols +Protein 1 +SOL 3 diff --git a/tests/test_analysis/__init__.py b/tests/test_analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_analysis/test_extract_ave_str.py b/tests/test_analysis/test_extract_ave_str.py new file mode 100644 index 0000000..5e60d3c --- /dev/null +++ b/tests/test_analysis/test_extract_ave_str.py @@ -0,0 +1,82 @@ +""" +analysis/extract_ave_str のユニットテスト + +MDtraj を使った平均構造抽出をテストする(gmx=False パス)。 +""" + +import types + +import numpy as np + + +class TestExtractAveStrRun: + def _make_args(self, traj_files, output, selection="all"): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection=selection, + output=str(output), + gmx=False, + index=None, + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + 
"""run() が PDB ファイルを生成すること""" + from src.analysis.extract_ave_str import run + + out = tmp_path / "ave.pdb" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_has_single_frame(self, trajectory_files, tmp_path): + """平均構造は 1 フレームのみであること""" + import mdtraj as md + + from src.analysis.extract_ave_str import run + + out = tmp_path / "ave_single.pdb" + run(self._make_args(trajectory_files, out)) + + loaded = md.load_pdb(str(out)) + assert loaded.n_frames == 1 + + def test_output_has_correct_atom_count_all(self, trajectory_files, tmp_path): + """全原子選択時の原子数が元の軌跡と一致すること""" + import mdtraj as md + + from src.analysis.extract_ave_str import run + + out = tmp_path / "ave_all.pdb" + run(self._make_args(trajectory_files, out, selection="all")) + + loaded = md.load_pdb(str(out)) + assert loaded.n_atoms == trajectory_files["traj"].n_atoms + + def test_output_has_correct_atom_count_subset(self, trajectory_files, tmp_path): + """部分選択時の原子数が選択した原子数と一致すること""" + import mdtraj as md + + from src.analysis.extract_ave_str import run + + selection = "resid 0" + out = tmp_path / "ave_subset.pdb" + run(self._make_args(trajectory_files, out, selection=selection)) + + loaded = md.load_pdb(str(out)) + n_selected = len(trajectory_files["traj"].topology.select(selection)) + assert loaded.n_atoms == n_selected + + def test_average_xyz_is_within_trajectory_range(self, trajectory_files, tmp_path): + """平均座標が元の軌跡の各軸の min/max の範囲内であること""" + import mdtraj as md + + from src.analysis.extract_ave_str import run + + out = tmp_path / "ave_range.pdb" + run(self._make_args(trajectory_files, out)) + + ave = md.load_pdb(str(out)) + traj = trajectory_files["traj"] + + assert np.all(ave.xyz >= traj.xyz.min(axis=0) - 1e-6) + assert np.all(ave.xyz <= traj.xyz.max(axis=0) + 1e-6) diff --git a/tests/test_analysis/test_extract_str.py b/tests/test_analysis/test_extract_str.py new file mode 100644 index 0000000..e2bc57f --- /dev/null +++ b/tests/test_analysis/test_extract_str.py @@ 
-0,0 +1,70 @@ +""" +analysis/extract_str のユニットテスト + +MDtraj を使った特定フレームの構造抽出をテストする(gmx=False パス)。 +""" + +import types + + +class TestExtractStrRun: + def _make_args(self, traj_files, output, time=1, selection="all"): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection=selection, + time=time, + output=str(output), + gmx=False, + index=None, + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """run() が PDB ファイルを生成すること""" + from src.analysis.extract_str import run + + out = tmp_path / "frame1.pdb" + run(self._make_args(trajectory_files, out, time=1)) + assert out.exists() + + def test_output_is_valid_pdb(self, trajectory_files, tmp_path): + """生成された PDB が MDtraj で読み込めること""" + import mdtraj as md + + from src.analysis.extract_str import run + + out = tmp_path / "frame1.pdb" + run(self._make_args(trajectory_files, out, time=1)) + + loaded = md.load_pdb(str(out)) + assert loaded.n_frames == 1 + + def test_output_atom_count(self, trajectory_files, tmp_path): + """抽出した構造の原子数が元の軌跡と一致すること""" + import mdtraj as md + + from src.analysis.extract_str import run + + out = tmp_path / "frame1.pdb" + run(self._make_args(trajectory_files, out, time=1)) + + loaded = md.load_pdb(str(out)) + assert loaded.n_atoms == trajectory_files["traj"].n_atoms + + def test_extract_different_frames(self, trajectory_files, tmp_path): + """異なるフレームで異なる座標が得られること""" + import mdtraj as md + + from src.analysis.extract_str import run + + out1 = tmp_path / "frame1.pdb" + out2 = tmp_path / "frame2.pdb" + run(self._make_args(trajectory_files, out1, time=1)) + run(self._make_args(trajectory_files, out2, time=2)) + + t1 = md.load_pdb(str(out1)) + t2 = md.load_pdb(str(out2)) + # ランダム座標なので 2 フレームの座標は異なるはず + import numpy as np + + assert not np.allclose(t1.xyz, t2.xyz) diff --git a/tests/test_build/__init__.py b/tests/test_build/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_build/test_add_ndx.py 
b/tests/test_build/test_add_ndx.py new file mode 100644 index 0000000..f884478 --- /dev/null +++ b/tests/test_build/test_add_ndx.py @@ -0,0 +1,75 @@ +""" +build/add_ndx のユニットテスト + +gmx make_ndx に依存しない純粋な count_index_group() ヘルパーをテストする。 +""" + +import types +from pathlib import Path + + +class TestCountIndexGroup: + def _make_args(self, index_path): + return types.SimpleNamespace(index=str(index_path)) + + def _write_ndx(self, path: Path, n_groups: int) -> Path: + """n_groups 個のインデックスグループを持つ .ndx ファイルを作成する""" + lines = [] + for i in range(n_groups): + lines.append(f"[ Group{i} ]\n") + lines.append(f"{i + 1}\n") + path.write_text("".join(lines)) + return path + + def test_counts_single_group(self, tmp_path): + """グループが 1 つの .ndx を正しくカウントすること""" + from src.build.add_ndx import count_index_group + + ndx = self._write_ndx(tmp_path / "single.ndx", n_groups=1) + assert count_index_group(self._make_args(ndx)) == 1 + + def test_counts_multiple_groups(self, tmp_path): + """複数グループの .ndx を正しくカウントすること""" + from src.build.add_ndx import count_index_group + + ndx = self._write_ndx(tmp_path / "multi.ndx", n_groups=5) + assert count_index_group(self._make_args(ndx)) == 5 + + def test_ignores_non_bracket_lines(self, tmp_path): + """[ で始まらない行はカウントしないこと""" + from src.build.add_ndx import count_index_group + + ndx = tmp_path / "mixed.ndx" + ndx.write_text( + "[ System ]\n" + "1 2 3 4 5\n" + "[ Protein ]\n" + "1 2 3\n" + "Some comment line\n" + "[ Water ]\n" + "4 5\n" + ) + assert count_index_group(self._make_args(ndx)) == 3 + + def test_empty_file_returns_zero(self, tmp_path): + """空ファイルは 0 を返すこと""" + from src.build.add_ndx import count_index_group + + ndx = tmp_path / "empty.ndx" + ndx.write_text("") + assert count_index_group(self._make_args(ndx)) == 0 + + def test_real_gromacs_style_ndx(self, tmp_path): + """GROMACS 形式の .ndx を正しくカウントすること""" + from src.build.add_ndx import count_index_group + + ndx = tmp_path / "gromacs.ndx" + ndx.write_text( + "[ System ]\n" + " 1 2 3 4 5 6 7 8 9\n" 
+ "[ Protein ]\n" + " 1 2 3 4 5\n" + "[ non-Protein ]\n" + " 6 7 8 9\n" + ) + assert count_index_group(self._make_args(ndx)) == 3 diff --git a/tests/test_build/test_calc_ion_conc.py b/tests/test_build/test_calc_ion_conc.py new file mode 100644 index 0000000..408911a --- /dev/null +++ b/tests/test_build/test_calc_ion_conc.py @@ -0,0 +1,95 @@ +""" +calc_ion_conc のユニットテスト + +純粋計算関数と PDB パーサーをテストする。 +""" + +import types + +import pytest + +from src.build.calc_ion_conc import ( + calc_ion_conc_from_volume, + get_boxsize_from_pdb, + get_water_number_from_pdb, +) + +# AVOGADRO_CONST = 6.022 (config.py より) +AVOGADRO = 6.022 + + +class TestCalcIonConcFromVolume: + def test_known_value(self): + """ + volume=1e6 A^3, concentration=0.15 M のとき + ionnum = 1e6 * 0.15 * 6.022 // 10000 = 90 + """ + result = calc_ion_conc_from_volume(1e6, 0.15) + assert result == 90 + + def test_zero_concentration(self): + result = calc_ion_conc_from_volume(1e6, 0.0) + assert result == 0 + + def test_returns_int(self): + result = calc_ion_conc_from_volume(1e6, 0.15) + assert isinstance(result, int) + + def test_proportional_to_volume(self): + """体積が 2 倍になるとイオン数もほぼ 2 倍になること""" + n1 = calc_ion_conc_from_volume(1e6, 0.15) + n2 = calc_ion_conc_from_volume(2e6, 0.15) + assert abs(n2 - 2 * n1) <= 1 # 整数切り捨てによる誤差を許容 + + def test_proportional_to_concentration(self): + """濃度が 2 倍になるとイオン数もほぼ 2 倍になること""" + n1 = calc_ion_conc_from_volume(1e6, 0.10) + n2 = calc_ion_conc_from_volume(1e6, 0.20) + assert abs(n2 - 2 * n1) <= 1 + + +class TestGetBoxsizeFromPdb: + def test_reads_cryst_line(self, sample_pdb_path): + """CRYST1 行からボックスサイズが読み取れること""" + args = types.SimpleNamespace(pdb=str(sample_pdb_path)) + x, y, z = get_boxsize_from_pdb(args) + assert x == pytest.approx(50.0) + assert y == pytest.approx(50.0) + assert z == pytest.approx(50.0) + + def test_raises_if_no_cryst(self, tmp_path): + """CRYST1 行がない PDB は例外を送出すること""" + pdb = tmp_path / "no_cryst.pdb" + pdb.write_text("ATOM 1 CA ALA A 1 0.000 0.000 0.000\n") + 
args = types.SimpleNamespace(pdb=str(pdb)) + with pytest.raises(Exception, match="CRYST"): + get_boxsize_from_pdb(args) + + +class TestGetWaterNumberFromPdb: + def test_counts_wat_oxygens(self, sample_pdb_path): + """WAT の酸素原子数が正しくカウントされること(fixture は 2 分子)""" + args = types.SimpleNamespace(pdb=str(sample_pdb_path), water_name="WAT") + count = get_water_number_from_pdb(args) + assert count == 2 + + def test_counts_with_custom_water_name(self, tmp_path): + """water_name を変えたときに正しくカウントされること""" + pdb = tmp_path / "wat.pdb" + # "WAT" は "O" を含まないため、O 原子行のみマッチする + pdb.write_text( + "HETATM 1 O WAT A 1 0.000 0.000 0.000\n" + "HETATM 2 H1 WAT A 1 0.800 0.600 0.000\n" + "HETATM 3 H2 WAT A 1 -0.800 0.600 0.000\n" + ) + args = types.SimpleNamespace(pdb=str(pdb), water_name="WAT") + count = get_water_number_from_pdb(args) + assert count == 1 + + def test_no_water_returns_zero(self, tmp_path): + """水分子がない場合は 0 を返すこと""" + pdb = tmp_path / "protein_only.pdb" + pdb.write_text("ATOM 1 CA ALA A 1 0.000 0.000 0.000\n") + args = types.SimpleNamespace(pdb=str(pdb), water_name="WAT") + count = get_water_number_from_pdb(args) + assert count == 0 diff --git a/tests/test_build/test_gen_posres.py b/tests/test_build/test_gen_posres.py new file mode 100644 index 0000000..3c9c764 --- /dev/null +++ b/tests/test_build/test_gen_posres.py @@ -0,0 +1,100 @@ +""" +build/gen_posres のユニットテスト + +sample.top を使って位置拘束 (.itp) の生成と topology への挿入をテストする。 +""" + +import shutil +import types + + +class TestGenPosresRun: + def _make_args(self, top_path, output_prefix, selection="name CA"): + return types.SimpleNamespace( + topology=str(top_path), + selection=selection, + output_prefix=str(output_prefix), + ) + + def test_itp_file_created_for_protein(self, sample_top_path, tmp_path): + """Protein モジュールの posres .itp ファイルが生成されること""" + from src.build.gen_posres import run + + # topology をコピー(in-place で書き換えられるため) + top_copy = tmp_path / "test.top" + shutil.copy(str(sample_top_path), str(top_copy)) + + prefix = 
tmp_path / "posres" + run(self._make_args(top_copy, prefix, selection="name CA")) + + itp = tmp_path / "posres_Protein.itp" + assert itp.exists() + + def test_itp_contains_ifdef_block(self, sample_top_path, tmp_path): + """生成された .itp が #ifdef / #endif ブロックを持つこと""" + from src.build.gen_posres import run + + top_copy = tmp_path / "test.top" + shutil.copy(str(sample_top_path), str(top_copy)) + + prefix = tmp_path / "posres" + run(self._make_args(top_copy, prefix, selection="name CA")) + + itp = tmp_path / "posres_Protein.itp" + content = itp.read_text() + # output_prefix がフルパスになるため #ifdef の名前はパス依存 + # 形式的な #ifdef / #endif の存在を確認する + assert "#ifdef" in content + assert "#endif" in content + assert "[ position_restraints ]" in content + + def test_itp_contains_selected_atoms(self, sample_top_path, tmp_path): + """CA 原子のインデックスが .itp に含まれること(Protein に CA は 2 個)""" + from src.build.gen_posres import run + + top_copy = tmp_path / "test.top" + shutil.copy(str(sample_top_path), str(top_copy)) + + prefix = tmp_path / "posres" + run(self._make_args(top_copy, prefix, selection="name CA")) + + itp = tmp_path / "posres_Protein.itp" + content = itp.read_text() + + # ALA-CA (index 2) と GLY-CA (index 7) が含まれること + lines = [ + line + for line in content.splitlines() + if line.strip() + and not line.startswith(";") + and not line.startswith("#") + and line.strip()[0].isdigit() + ] + assert len(lines) == 2 + + def test_topology_updated_with_include(self, sample_top_path, tmp_path): + """実行後の topology ファイルに #include 行が追加されること""" + from src.build.gen_posres import run + + top_copy = tmp_path / "test.top" + shutil.copy(str(sample_top_path), str(top_copy)) + + prefix = tmp_path / "posres" + run(self._make_args(top_copy, prefix, selection="name CA")) + + updated = top_copy.read_text() + assert "#include" in updated + assert "posres_Protein.itp" in updated + + def test_no_sol_itp_when_sol_not_selected(self, sample_top_path, tmp_path): + """SOL に CA 原子がないため posres_SOL.itp は生成されないこと""" + from 
src.build.gen_posres import run + + top_copy = tmp_path / "test.top" + shutil.copy(str(sample_top_path), str(top_copy)) + + prefix = tmp_path / "posres" + run(self._make_args(top_copy, prefix, selection="name CA")) + + sol_itp = tmp_path / "posres_SOL.itp" + assert not sol_itp.exists() diff --git a/tests/test_build/test_gen_temperatures.py b/tests/test_build/test_gen_temperatures.py new file mode 100644 index 0000000..7a1aded --- /dev/null +++ b/tests/test_build/test_gen_temperatures.py @@ -0,0 +1,94 @@ +""" +gen_temperatures のユニットテスト + +純粋計算関数(calc_mu)とバリデーションロジック(run)をテストする。 +""" + +import types + +import pytest + +from src.build.gen_temperatures import A0, A1, B0, B1, calc_mu, run + + +class TestCalcMu: + def test_known_value(self): + """calc_mu の数値が手計算と一致すること""" + nw, np_val, temp, fener = 100, 50, 300.0, 0.0 + expected = (A0 + A1 * temp) * nw + (B0 + B1 * temp) * np_val - temp * fener + assert calc_mu(nw, np_val, temp, fener) == pytest.approx(expected) + + def test_zero_system(self): + """nw=0, np=0 のとき 0 になること""" + result = calc_mu(0, 0, 300.0, 0.0) + assert result == 0.0 + + def test_temperature_effect(self): + """温度が高いほど calc_mu の値が変化すること(単調ではないが変わること)""" + mu1 = calc_mu(1000, 500, 300.0, 0.0) + mu2 = calc_mu(1000, 500, 400.0, 0.0) + assert mu1 != mu2 + + +class TestGenTemperaturesRun: + def _base_args(self, **kwargs): + defaults = dict( + pdes=0.2, + tlow=300.0, + thigh=350.0, + nw=1000, + np=500, + tol=1e-3, + pc=0, + wc=3, + hff=0, + vs=0, + alg=0, + ) + defaults.update(kwargs) + return types.SimpleNamespace(**defaults) + + def test_generates_temperature_ladder(self, capsys): + """正常な入力で温度ラダーが出力されること""" + args = self._base_args() + run(args) + captured = capsys.readouterr() + assert "Temperature" in captured.out + assert "300.00" in captured.out + + def test_first_temp_equals_tlow(self, capsys): + """最初の温度が tlow であること""" + args = self._base_args(tlow=310.0, thigh=360.0) + run(args) + captured = capsys.readouterr() + assert "310.00" in captured.out + + 
def test_invalid_pdes_raises(self): + """0〜1 の範囲外の pdes は ValueError を送出すること""" + args = self._base_args(pdes=1.5) + with pytest.raises(ValueError, match="Pdes"): + run(args) + + def test_thigh_must_be_greater_than_tlow(self): + """thigh <= tlow のとき ValueError を送出すること""" + args = self._base_args(tlow=350.0, thigh=300.0) + with pytest.raises(ValueError): + run(args) + + def test_tlow_must_be_positive(self): + """tlow が 0 以下のとき ValueError を送出すること""" + args = self._base_args(tlow=0.0) + with pytest.raises(ValueError): + run(args) + + def test_zero_protein_atoms_raises(self): + """np=0 のとき ValueError を送出すること""" + args = self._base_args(np=0) + with pytest.raises(ValueError, match="protein atoms"): + run(args) + + def test_nvt_not_supported(self): + """alg=1(NVT)は未対応なので ValueError を送出すること""" + args = self._base_args(alg=1) + with pytest.raises(ValueError, match="constant volume"): + run(args) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..5f554bb --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,87 @@ +""" +CLI のサブコマンド登録テスト + +全サブコマンドが argparse に正常に登録されることを確認する。 +外部ツール(PyMOL, Gromacs 等)の呼び出しは発生しない。 +""" + +import pytest + + +def test_cli_importable(): + """cli モジュールが import エラーなくロードできること""" + from src.cli import cli + + assert callable(cli) + + +def test_all_subcommands_registered(): + """ + 全サブコマンドが argparse に登録されており、 + --help が正常に動作すること(SystemExit(0) で終了) + """ + from src.cli import cli + + with pytest.raises(SystemExit) as exc_info: + # --help は SystemExit(0) を発生させる + import sys + + sys.argv = ["mdtbx", "--help"] + cli() + assert exc_info.value.code == 0 + + +@pytest.mark.parametrize( + "subcmd", + [ + "addace", + "addnme", + "add_ndx", + "mv_crds_mol2", + "gen_am1bcc", + "gen_resp", + "gen_modres_am1bcc", + "gen_modres_resp", + "gen_posres", + "gen_distres", + "modeling_cf", + "find_bond", + "mod_mdp", + "convert", + "calc_ion_conc", + "centering_gro", + "amb2gro", + "trjcat", + "fit", + "pacs_trjcat", + "rmfile", + 
"extract_ave_str", + "extract_str", + "show_mdtraj", + "show_npy", + "print_perf", + "shell_hook", + "partial_tempering", + "gen_temperatures", + "cmd", + "build_solution", + "comdist", + "comvec", + "mindist", + "rmsd", + "rmsf", + "xyz", + "pca", + "densmap", + ], +) +def test_subcommand_help(subcmd): + """各サブコマンドの --help が正常終了すること""" + import sys + + from src.cli import cli + + sys.argv = ["mdtbx", subcmd, "--help"] + with pytest.raises(SystemExit) as exc_info: + cli() + assert exc_info.value.code == 0 diff --git a/tests/test_cv/__init__.py b/tests/test_cv/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_cv/test_comdist.py b/tests/test_cv/test_comdist.py new file mode 100644 index 0000000..312a4bc --- /dev/null +++ b/tests/test_cv/test_comdist.py @@ -0,0 +1,62 @@ +""" +cv/comdist のユニットテスト + +合成軌跡を使って重心間距離(COM distance)の計算を検証する。 +""" + +import types + +import numpy as np + + +class TestComdistRun: + def _make_args(self, traj_files, output, sel1="resid 0", sel2="resid 1"): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection1=sel1, + selection2=sel2, + output=str(output), + gmx=False, + index=None, + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """.npy ファイルが生成されること""" + from src.cv.comdist import run + + out = tmp_path / "comdist.npy" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_shape(self, trajectory_files, tmp_path): + """出力配列の長さがフレーム数と一致すること""" + from src.cv.comdist import run + + out = tmp_path / "comdist.npy" + run(self._make_args(trajectory_files, out)) + + dist = np.load(str(out)) + n_frames = trajectory_files["traj"].n_frames + assert dist.shape == (n_frames,) + + def test_output_nonnegative(self, trajectory_files, tmp_path): + """距離は常に非負であること""" + from src.cv.comdist import run + + out = tmp_path / "comdist.npy" + run(self._make_args(trajectory_files, out)) + + dist = np.load(str(out)) + assert 
np.all(dist >= 0) + + def test_same_selection_gives_zero(self, trajectory_files, tmp_path): + """同じ原子群の重心間距離は 0 になること""" + from src.cv.comdist import run + + out = tmp_path / "comdist_zero.npy" + args = self._make_args(trajectory_files, out, sel1="resid 0", sel2="resid 0") + run(args) + + dist = np.load(str(out)) + assert np.allclose(dist, 0.0, atol=1e-6) diff --git a/tests/test_cv/test_comvec.py b/tests/test_cv/test_comvec.py new file mode 100644 index 0000000..e057d0b --- /dev/null +++ b/tests/test_cv/test_comvec.py @@ -0,0 +1,64 @@ +""" +cv/comvec のユニットテスト + +重心ベクトル(COM vector)の計算を合成軌跡で検証する。 +""" + +import types + +import numpy as np + + +class TestComvecRun: + def _make_args(self, traj_files, output, sel1="resid 0", sel2="resid 1"): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection1=sel1, + selection2=sel2, + output=str(output), + gmx=False, + index=None, + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """.npy ファイルが生成されること""" + from src.cv.comvec import run + + out = tmp_path / "comvec.npy" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_shape(self, trajectory_files, tmp_path): + """出力配列の形状が (n_frames, 3) であること""" + from src.cv.comvec import run + + out = tmp_path / "comvec_shape.npy" + run(self._make_args(trajectory_files, out)) + vec = np.load(str(out)) + + n_frames = trajectory_files["traj"].n_frames + assert vec.shape == (n_frames, 3) + + def test_same_selection_gives_zero_vector(self, trajectory_files, tmp_path): + """同じ原子群のベクトルは零ベクトルになること""" + from src.cv.comvec import run + + out = tmp_path / "comvec_zero.npy" + args = self._make_args(trajectory_files, out, sel1="resid 0", sel2="resid 0") + run(args) + vec = np.load(str(out)) + assert np.allclose(vec, 0.0, atol=1e-6) + + def test_antisymmetry(self, trajectory_files, tmp_path): + """sel1/sel2 を入れ替えると符号が反転すること""" + from src.cv.comvec import run + + out1 = tmp_path / "comvec_ab.npy" + 
out2 = tmp_path / "comvec_ba.npy" + run(self._make_args(trajectory_files, out1, sel1="resid 0", sel2="resid 1")) + run(self._make_args(trajectory_files, out2, sel1="resid 1", sel2="resid 0")) + + vec_ab = np.load(str(out1)) + vec_ba = np.load(str(out2)) + assert np.allclose(vec_ab, -vec_ba, atol=1e-6) diff --git a/tests/test_cv/test_densmap.py b/tests/test_cv/test_densmap.py new file mode 100644 index 0000000..71472e6 --- /dev/null +++ b/tests/test_cv/test_densmap.py @@ -0,0 +1,82 @@ +""" +cv/densmap のユニットテスト + +MDtraj を使った 2D 密度マップ計算をテストする(gmx=False パス)。 +""" + +import types + +import numpy as np +import pytest + + +class TestDensmapRun: + def _make_args(self, traj_files, output, selection="all", axis="xy", bins=10): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection=selection, + output=str(output), + bins=bins, + axis=axis, + gmx=False, + index=None, + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """.npy ファイルが生成されること""" + from src.cv.densmap import run + + out = tmp_path / "densmap.npy" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_is_nonempty(self, trajectory_files, tmp_path): + """出力ファイルが空でないこと""" + from src.cv.densmap import run + + out = tmp_path / "densmap_size.npy" + run(self._make_args(trajectory_files, out)) + assert out.stat().st_size > 0 + + @pytest.mark.parametrize("axis", ["xy", "xz", "yz"]) + def test_different_axes(self, trajectory_files, tmp_path, axis): + """xy / xz / yz の各投影面でファイルが生成されること""" + from src.cv.densmap import run + + out = tmp_path / f"densmap_{axis}.npy" + run(self._make_args(trajectory_files, out, axis=axis)) + assert out.exists() + + def test_empty_selection_does_not_create_file(self, trajectory_files, tmp_path): + """原子が選択されない場合はファイルを生成しないこと""" + from src.cv.densmap import run + + out = tmp_path / "densmap_empty.npy" + run(self._make_args(trajectory_files, out, selection="name XXXX")) + assert not out.exists() 
+ + def test_histogram_shape(self, trajectory_files, tmp_path): + """np.histogram2d の結果と整合すること(ゴールデンパス検証)""" + import mdtraj as md + from src.cv.densmap import _AXIS_MAP, run + + bins = 8 + axis = "xy" + out = tmp_path / "densmap_golden.npy" + run(self._make_args(trajectory_files, out, bins=bins, axis=axis)) + + # 内部と同じ計算で counts を再現し、合計を比較する + traj = md.load(trajectory_files["xtc"], top=trajectory_files["pdb"]) + atom_indices = traj.topology.select("all") + xyz = traj.xyz[:, atom_indices, :] + ax0, ax1 = _AXIS_MAP[axis] + pos0 = xyz[:, :, ax0].ravel() + pos1 = xyz[:, :, ax1].ravel() + counts_ref, _, _ = np.histogram2d(pos0, pos1, bins=bins) + + # ファイルが存在し、総カウントが一致すること + assert out.exists() + total_ref = int(counts_ref.sum()) + expected = traj.n_frames * traj.n_atoms + assert total_ref == expected diff --git a/tests/test_cv/test_mindist.py b/tests/test_cv/test_mindist.py new file mode 100644 index 0000000..4d0487b --- /dev/null +++ b/tests/test_cv/test_mindist.py @@ -0,0 +1,61 @@ +""" +cv/mindist のユニットテスト + +原子間最短距離の計算を合成軌跡で検証する。 +""" + +import types + +import numpy as np + + +class TestMindistRun: + def _make_args(self, traj_files, output, sel1="resid 0", sel2="resid 1"): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection1=sel1, + selection2=sel2, + output=str(output), + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """.npy ファイルが生成されること""" + from src.cv.mindist import run + + out = tmp_path / "mindist.npy" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_shape(self, trajectory_files, tmp_path): + """出力配列の長さがフレーム数と一致すること""" + from src.cv.mindist import run + + out = tmp_path / "mindist_shape.npy" + run(self._make_args(trajectory_files, out)) + dist = np.load(str(out)) + + n_frames = trajectory_files["traj"].n_frames + assert dist.shape == (n_frames,) + + def test_output_nonnegative(self, trajectory_files, tmp_path): + """距離は常に非負であること""" + from 
src.cv.mindist import run + + out = tmp_path / "mindist_nn.npy" + run(self._make_args(trajectory_files, out)) + dist = np.load(str(out)) + assert np.all(dist >= 0) + + def test_mindist_consistent_across_runs(self, trajectory_files, tmp_path): + """同じ入力で 2 回実行しても同一の結果が得られること(決定論的)""" + from src.cv.mindist import run + + out1 = tmp_path / "mindist_run1.npy" + out2 = tmp_path / "mindist_run2.npy" + run(self._make_args(trajectory_files, out1)) + run(self._make_args(trajectory_files, out2)) + + dist1 = np.load(str(out1)) + dist2 = np.load(str(out2)) + assert np.allclose(dist1, dist2) diff --git a/tests/test_cv/test_pca.py b/tests/test_cv/test_pca.py new file mode 100644 index 0000000..01c4122 --- /dev/null +++ b/tests/test_cv/test_pca.py @@ -0,0 +1,80 @@ +""" +cv/pca のユニットテスト + +MDtraj + scikit-learn を使った PCA をテストする(gmx=False パス)。 +""" + +import types + +import numpy as np + + +class TestPcaRun: + def _make_args(self, traj_files, output, n_components=2): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + reference=traj_files["pdb"], + selection_cal_trj="all", + selection_cal_ref="all", + selection_fit_trj="all", + selection_fit_ref="all", + output=str(output), + gmx=False, + n_components=n_components, + index=None, + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """.npy ファイルが生成されること""" + from src.cv.pca import run + + out = tmp_path / "pca.npy" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_shape(self, trajectory_files, tmp_path): + """出力形状が (n_frames, n_components) であること""" + from src.cv.pca import run + + n_components = 3 + out = tmp_path / "pca_shape.npy" + run(self._make_args(trajectory_files, out, n_components=n_components)) + pc = np.load(str(out)) + + n_frames = trajectory_files["traj"].n_frames + assert pc.shape == (n_frames, n_components) + + def test_first_pc_has_max_variance(self, trajectory_files, tmp_path): + """PC1 の分散が PC2 以上であること(PCA の定義)""" + 
from src.cv.pca import run + + out = tmp_path / "pca_var.npy" + run(self._make_args(trajectory_files, out, n_components=2)) + pc = np.load(str(out)) + + var1 = np.var(pc[:, 0]) + var2 = np.var(pc[:, 1]) + assert var1 >= var2 + + def test_subset_selection(self, trajectory_files, tmp_path): + """部分選択でも正常に計算できること""" + from src.cv.pca import run + + out = tmp_path / "pca_subset.npy" + args = types.SimpleNamespace( + topology=trajectory_files["pdb"], + trajectory=trajectory_files["xtc"], + reference=trajectory_files["pdb"], + selection_cal_trj="resid 0", + selection_cal_ref="resid 0", + selection_fit_trj="resid 0", + selection_fit_ref="resid 0", + output=str(out), + gmx=False, + n_components=2, + index=None, + ) + run(args) + pc = np.load(str(out)) + assert pc.shape[0] == trajectory_files["traj"].n_frames diff --git a/tests/test_cv/test_rmsd.py b/tests/test_cv/test_rmsd.py new file mode 100644 index 0000000..0e2e381 --- /dev/null +++ b/tests/test_cv/test_rmsd.py @@ -0,0 +1,83 @@ +""" +cv/rmsd のユニットテスト + +合成軌跡データを使って RMSD 計算の正確性を検証する。 +""" + +import types + +import numpy as np +import pytest + + +class TestRmsdRun: + def _make_args(self, traj_files, output, selection="all"): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + reference=traj_files["pdb"], + selection_fit_trj=selection, + selection_fit_ref=selection, + selection_cal_trj=selection, + selection_cal_ref=selection, + output=str(output), + ) + + def test_rmsd_output_file_created(self, trajectory_files, tmp_path): + """run() が .npy ファイルを生成すること""" + from src.cv.rmsd import run + + out = tmp_path / "rmsd.npy" + args = self._make_args(trajectory_files, out) + run(args) + assert out.exists() + + def test_rmsd_shape(self, trajectory_files, tmp_path): + """RMSD 配列の長さが軌跡のフレーム数と一致すること""" + from src.cv.rmsd import run + + out = tmp_path / "rmsd.npy" + args = self._make_args(trajectory_files, out) + run(args) + + rmsd = np.load(str(out)) + n_frames = 
trajectory_files["traj"].n_frames + assert rmsd.shape == (n_frames,) + + def test_rmsd_nonnegative(self, trajectory_files, tmp_path): + """RMSD は常に非負であること""" + from src.cv.rmsd import run + + out = tmp_path / "rmsd.npy" + args = self._make_args(trajectory_files, out) + run(args) + + rmsd = np.load(str(out)) + assert np.all(rmsd >= 0) + + def test_rmsd_first_frame_near_zero(self, trajectory_files, tmp_path): + """ + reference は frame 0 の PDB であるため、 + frame 0 の RMSD は fitting 後にほぼ 0 になること。 + """ + from src.cv.rmsd import run + + out = tmp_path / "rmsd.npy" + args = self._make_args(trajectory_files, out) + run(args) + + rmsd = np.load(str(out)) + # XTC は float32 のため PDB (float64) との往復で精度損失がある + assert rmsd[0] == pytest.approx(0.0, abs=1e-3) + + def test_rmsd_other_frames_nonzero(self, trajectory_files, tmp_path): + """ランダムな座標を持つ他のフレームの RMSD は 0 より大きいこと""" + from src.cv.rmsd import run + + out = tmp_path / "rmsd.npy" + args = self._make_args(trajectory_files, out) + run(args) + + rmsd = np.load(str(out)) + # frame 0 以外のどこかが 0 より大きいことを確認 + assert np.any(rmsd[1:] > 0) diff --git a/tests/test_cv/test_rmsf.py b/tests/test_cv/test_rmsf.py new file mode 100644 index 0000000..5927b4b --- /dev/null +++ b/tests/test_cv/test_rmsf.py @@ -0,0 +1,96 @@ +""" +cv/rmsf のユニットテスト + +MDtraj を使った RMSF 計算をテストする(gmx=False パス)。 +""" + +import types + +import numpy as np + + +class TestRmsfRun: + def _make_args(self, traj_files, output, selection="all"): + return types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection=selection, + output=str(output), + gmx=False, + resolution="atom", + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """.npy ファイルが生成されること""" + from src.cv.rmsf import run + + out = tmp_path / "rmsf.npy" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_shape_all_atoms(self, trajectory_files, tmp_path): + """全原子選択時の出力長が原子数と一致すること""" + from src.cv.rmsf import run + + out = 
tmp_path / "rmsf_all.npy" + run(self._make_args(trajectory_files, out)) + rmsf = np.load(str(out)) + + n_atoms = trajectory_files["traj"].n_atoms + assert rmsf.shape == (n_atoms,) + + def test_output_shape_subset(self, trajectory_files, tmp_path): + """部分選択時の出力長が選択原子数と一致すること""" + from src.cv.rmsf import run + + out = tmp_path / "rmsf_subset.npy" + run(self._make_args(trajectory_files, out, selection="resid 0")) + rmsf = np.load(str(out)) + + n_selected = len(trajectory_files["traj"].topology.select("resid 0")) + assert rmsf.shape == (n_selected,) + + def test_output_nonnegative(self, trajectory_files, tmp_path): + """RMSF は常に非負であること""" + from src.cv.rmsf import run + + out = tmp_path / "rmsf_nn.npy" + run(self._make_args(trajectory_files, out)) + rmsf = np.load(str(out)) + assert np.all(rmsf >= 0) + + def test_static_trajectory_gives_zero_rmsf(self, tmp_path_factory): + """全フレームが同一座標の軌跡では RMSF が 0 になること""" + import mdtraj as md + + from src.cv.rmsf import run + + tmp = tmp_path_factory.mktemp("static") + + # 静止軌跡を作成 + top = md.Topology() + chain = top.add_chain() + res = top.add_residue("ALA", chain) + top.add_atom("CA", md.element.carbon, res) + + xyz = np.zeros((5, 1, 3)) + traj = md.Trajectory(xyz, top) + + pdb_path = str(tmp / "static.pdb") + xtc_path = str(tmp / "static.xtc") + traj[0].save_pdb(pdb_path) + traj.save_xtc(xtc_path) + + out = tmp / "rmsf_static.npy" + run( + types.SimpleNamespace( + topology=pdb_path, + trajectory=xtc_path, + selection="all", + output=str(out), + gmx=False, + resolution="atom", + ) + ) + rmsf = np.load(str(out)) + assert np.allclose(rmsf, 0.0, atol=1e-6) diff --git a/tests/test_cv/test_xyz.py b/tests/test_cv/test_xyz.py new file mode 100644 index 0000000..f9b42bc --- /dev/null +++ b/tests/test_cv/test_xyz.py @@ -0,0 +1,60 @@ +""" +cv/xyz のユニットテスト + +MDtraj を使った XYZ 座標抽出をテストする(gmx=False パス)。 +""" + +import types + +import numpy as np + + +class TestXyzRun: + def _make_args(self, traj_files, output, selection="all"): + return 
types.SimpleNamespace( + topology=traj_files["pdb"], + trajectory=traj_files["xtc"], + selection=selection, + output=str(output), + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """.npy ファイルが生成されること""" + from src.cv.xyz import run + + out = tmp_path / "xyz.npy" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_shape_all_atoms(self, trajectory_files, tmp_path): + """全原子選択時の形状が (n_frames, n_atoms, 3) であること""" + from src.cv.xyz import run + + out = tmp_path / "xyz_all.npy" + run(self._make_args(trajectory_files, out)) + xyz = np.load(str(out)) + + n_frames = trajectory_files["traj"].n_frames + n_atoms = trajectory_files["traj"].n_atoms + assert xyz.shape == (n_frames, n_atoms, 3) + + def test_output_shape_subset(self, trajectory_files, tmp_path): + """部分選択時の形状が (n_frames, n_selected, 3) であること""" + from src.cv.xyz import run + + out = tmp_path / "xyz_subset.npy" + run(self._make_args(trajectory_files, out, selection="resid 0")) + xyz = np.load(str(out)) + + n_frames = trajectory_files["traj"].n_frames + n_selected = len(trajectory_files["traj"].topology.select("resid 0")) + assert xyz.shape == (n_frames, n_selected, 3) + + def test_coordinates_are_finite(self, trajectory_files, tmp_path): + """出力座標に NaN / Inf が含まれないこと""" + from src.cv.xyz import run + + out = tmp_path / "xyz_finite.npy" + run(self._make_args(trajectory_files, out)) + xyz = np.load(str(out)) + assert np.all(np.isfinite(xyz)) diff --git a/tests/test_trajectory/__init__.py b/tests/test_trajectory/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_trajectory/test_fit.py b/tests/test_trajectory/test_fit.py new file mode 100644 index 0000000..6a92d0a --- /dev/null +++ b/tests/test_trajectory/test_fit.py @@ -0,0 +1,55 @@ +""" +trajectory/fit のユニットテスト + +MDtraj を使った軌跡フィッティングをテストする(gmx=False パス)。 +""" + +import types + + +class TestFitRun: + def _make_args(self, traj_files, output, selection="all"): + return 
types.SimpleNamespace( + file=traj_files["xtc"], + topology=traj_files["pdb"], + output=str(output), + selection=selection, + gmx=False, + pbc="mol", + index="index.ndx", + ) + + def test_output_file_created(self, trajectory_files, tmp_path): + """run() がフィット済み軌跡ファイルを生成すること""" + + from src.trajectory.fit import run + + out = tmp_path / "fitted.xtc" + run(self._make_args(trajectory_files, out)) + assert out.exists() + + def test_output_has_same_frames(self, trajectory_files, tmp_path): + """フィット後の軌跡のフレーム数が元と同じであること""" + import mdtraj as md + + from src.trajectory.fit import run + + out = tmp_path / "fitted.xtc" + run(self._make_args(trajectory_files, out)) + + fitted = md.load(str(out), top=trajectory_files["pdb"]) + original_n_frames = trajectory_files["traj"].n_frames + assert fitted.n_frames == original_n_frames + + def test_output_has_same_atoms(self, trajectory_files, tmp_path): + """フィット後の軌跡の原子数が元と同じであること""" + import mdtraj as md + + from src.trajectory.fit import run + + out = tmp_path / "fitted.xtc" + run(self._make_args(trajectory_files, out)) + + fitted = md.load(str(out), top=trajectory_files["pdb"]) + original_n_atoms = trajectory_files["traj"].n_atoms + assert fitted.n_atoms == original_n_atoms diff --git a/tests/test_trajectory/test_print_perf.py b/tests/test_trajectory/test_print_perf.py new file mode 100644 index 0000000..19c0bb5 --- /dev/null +++ b/tests/test_trajectory/test_print_perf.py @@ -0,0 +1,86 @@ +""" +trajectory/print_perf のユニットテスト + +parse_log_file() の純粋なログパースロジックをテストする。 +外部ツール・gmx に依存しない。 +""" + +import pytest + +from src.trajectory.print_perf import parse_log_file + +# GROMACS ログの最小限サンプル +_SAMPLE_LOG = """\ + GROMACS version: 2023.3 + Executable: /usr/local/bin/gmx +Hardware detected on host myhost01: +GPU info: + Number of GPUs detected: 2 + GPU 0: NVIDIA A100 +CPU info: + Vendor: Intel + Model name: Xeon Gold 6338 +Command line: + gmx mdrun -deffnm prd -ntmpi 4 -ntomp 8 + + Performance: 123.45 0.194 +""" + + +class TestParseLogFile: + 
def test_parses_performance(self, tmp_path): + """`Performance:` 行から ns/day 値を正しく取得すること""" + log = tmp_path / "prd.log" + log.write_text(_SAMPLE_LOG) + data = parse_log_file(log) + assert data["performance"] == pytest.approx(123.45) + + def test_parses_version(self, tmp_path): + """`GROMACS version:` から版数文字列を取得すること""" + log = tmp_path / "prd.log" + log.write_text(_SAMPLE_LOG) + data = parse_log_file(log) + assert data["version"] == "2023.3" + + def test_parses_executable(self, tmp_path): + """`Executable:` からパスを取得すること""" + log = tmp_path / "prd.log" + log.write_text(_SAMPLE_LOG) + data = parse_log_file(log) + assert data["executable"] == "/usr/local/bin/gmx" + + def test_parses_hostname(self, tmp_path): + """`Hardware detected on host` からホスト名を取得すること""" + log = tmp_path / "prd.log" + log.write_text(_SAMPLE_LOG) + data = parse_log_file(log) + assert data["hostname"] == "myhost01" + + def test_cmd_strips_deffnm(self, tmp_path): + """`-deffnm` オプションが除去されたコマンド文字列を返すこと""" + log = tmp_path / "prd.log" + log.write_text(_SAMPLE_LOG) + data = parse_log_file(log) + assert "-deffnm" not in data["cmd"] + assert "gmx mdrun" in data["cmd"] + + def test_nonexistent_file_returns_none(self, tmp_path): + """存在しないファイルは None を返すこと""" + data = parse_log_file(tmp_path / "nonexistent.log") + assert data is None + + def test_empty_log_has_no_performance(self, tmp_path): + """Performance 行がないログでは performance が None になること""" + log = tmp_path / "empty.log" + log.write_text("Some GROMACS output without performance line\n") + data = parse_log_file(log) + assert data is not None + assert data["performance"] is None + + def test_default_values_for_missing_fields(self, tmp_path): + """フィールドが見つからない場合は 'N/A' がデフォルトであること""" + log = tmp_path / "minimal.log" + log.write_text(" Performance: 50.0 0.5\n") + data = parse_log_file(log) + assert data["version"] == "N/A" + assert data["hostname"] == "N/A" diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/test_utils/test_atom_selection_parser.py b/tests/test_utils/test_atom_selection_parser.py new file mode 100644 index 0000000..81acab5 --- /dev/null +++ b/tests/test_utils/test_atom_selection_parser.py @@ -0,0 +1,305 @@ +""" +atom_selection_parser のユニットテスト + +SelectionParser / AtomSelector の純粋なパースロジックをテストする。 +外部ツール・ファイル I/O に依存しない。 +""" + +import pytest + +from src.utils.atom_selection_parser import ( + All, + And, + AtomSelector, + Backbone, + Bracket, + Chain, + Index, + Name, + Not, + Or, + Protein, + ResId, + ResName, + SelectionParser, + Sidechain, + Water, + parse_selection, +) + +# --------------------------------------------------------------------------- +# ヘルパー: テスト用の原子辞書 +# --------------------------------------------------------------------------- + +ALA_CA = {"resname": "ALA", "name": "CA", "resid": 1, "index": 1} +ALA_CB = {"resname": "ALA", "name": "CB", "resid": 1, "index": 2} +GLY_N = {"resname": "GLY", "name": "N", "resid": 2, "index": 3} +HOH_O = {"resname": "HOH", "name": "O", "resid": 3, "index": 4} +NA_ION = {"resname": "NA", "name": "NA", "resid": 4, "index": 5} + +# --------------------------------------------------------------------------- +# SelectionNode の基本評価 +# --------------------------------------------------------------------------- + + +class TestAll: + def test_always_true(self): + node = All() + assert node.eval(ALA_CA) + assert node.eval(HOH_O) + assert node.eval({}) + + +class TestProtein: + def test_ala_is_protein(self): + assert Protein().eval(ALA_CA) + + def test_water_is_not_protein(self): + assert not Protein().eval(HOH_O) + + def test_unknown_resname_is_not_protein(self): + assert not Protein().eval({"resname": "LIG", "name": "C1"}) + + +class TestWater: + def test_hoh_is_water(self): + assert Water().eval(HOH_O) + + def test_sol_is_water(self): + assert Water().eval({"resname": "SOL", "name": "OW"}) + + def test_protein_is_not_water(self): + assert not Water().eval(ALA_CA) + + +class 
TestBackbone: + def test_ca_is_backbone(self): + assert Backbone().eval(ALA_CA) + + def test_cb_is_not_backbone(self): + assert not Backbone().eval(ALA_CB) + + def test_n_is_backbone(self): + assert Backbone().eval({"resname": "ALA", "name": "N"}) + + def test_water_o_is_not_backbone(self): + assert not Backbone().eval(HOH_O) + + +class TestSidechain: + def test_cb_is_sidechain(self): + assert Sidechain().eval(ALA_CB) + + def test_ca_is_not_sidechain(self): + assert not Sidechain().eval(ALA_CA) + + def test_water_is_not_sidechain(self): + assert not Sidechain().eval(HOH_O) + + +class TestResName: + def test_match_single(self): + assert ResName(["ALA"]).eval(ALA_CA) + + def test_match_multiple(self): + assert ResName(["ALA", "GLY"]).eval(GLY_N) + + def test_no_match(self): + assert not ResName(["ALA"]).eval(HOH_O) + + def test_wildcard(self): + assert ResName(["A*"]).eval(ALA_CA) + assert not ResName(["A*"]).eval(GLY_N) + + +class TestResId: + def test_match(self): + assert ResId([1]).eval(ALA_CA) + + def test_range(self): + assert ResId([1, 2, 3]).eval(GLY_N) + + def test_no_match(self): + assert not ResId([5, 6]).eval(ALA_CA) + + +class TestName: + def test_match(self): + assert Name(["CA"]).eval(ALA_CA) + + def test_wildcard_prefix(self): + assert Name(["C*"]).eval(ALA_CA) + assert Name(["C*"]).eval(ALA_CB) + assert not Name(["C*"]).eval(GLY_N) + + def test_wildcard_all(self): + assert Name(["*"]).eval(ALA_CA) + + +class TestIndex: + def test_match(self): + assert Index([1]).eval(ALA_CA) + + def test_no_match(self): + assert not Index([99]).eval(ALA_CA) + + +class TestNot: + def test_inverts_protein(self): + assert Not(Protein()).eval(HOH_O) + assert not Not(Protein()).eval(ALA_CA) + + +class TestAnd: + def test_both_true(self): + assert And([Protein(), Backbone()]).eval(ALA_CA) + + def test_one_false(self): + assert not And([Protein(), Backbone()]).eval(ALA_CB) + + def test_empty_is_true(self): + assert And([]).eval(ALA_CA) + + +class TestOr: + def 
test_one_true(self): + assert Or([Protein(), Water()]).eval(HOH_O) + assert Or([Protein(), Water()]).eval(ALA_CA) + + def test_both_false(self): + assert not Or([Protein(), Water()]).eval(NA_ION) + + def test_empty_is_false(self): + assert not Or([]).eval(ALA_CA) + + +# --------------------------------------------------------------------------- +# SelectionParser パース結果の検証 +# --------------------------------------------------------------------------- + + +class TestSelectionParser: + def test_protein(self): + result = SelectionParser("protein").parse() + assert isinstance(result, Protein) + + def test_all(self): + result = SelectionParser("all").parse() + assert isinstance(result, All) + + def test_resname_single(self): + result = SelectionParser("resname ALA").parse() + assert isinstance(result, ResName) + assert result.names == ["ALA"] + + def test_resname_multiple(self): + result = SelectionParser("resname ALA GLY").parse() + assert isinstance(result, ResName) + assert result.names == ["ALA", "GLY"] + + def test_resid_single(self): + result = SelectionParser("resid 5").parse() + assert isinstance(result, ResId) + assert result.ids == [5] + + def test_resid_range(self): + result = SelectionParser("resid 1 to 5").parse() + assert isinstance(result, ResId) + assert result.ids == [1, 2, 3, 4, 5] + + def test_and_expression(self): + result = SelectionParser("protein and backbone").parse() + assert isinstance(result, And) + + def test_or_expression(self): + result = SelectionParser("protein or water").parse() + assert isinstance(result, Or) + + def test_not_expression(self): + result = SelectionParser("not water").parse() + assert isinstance(result, Not) + + def test_bracket(self): + result = SelectionParser("(protein and backbone)").parse() + assert isinstance(result, Bracket) + + def test_chain(self): + result = SelectionParser("chain A").parse() + assert isinstance(result, Chain) + + def test_name_wildcard(self): + result = SelectionParser("name C*").parse() + 
assert isinstance(result, Name) + + def test_index_range(self): + result = SelectionParser("index 1 to 3").parse() + assert isinstance(result, Index) + assert result.indices == [1, 2, 3] + + +# --------------------------------------------------------------------------- +# AtomSelector の end-to-end テスト +# --------------------------------------------------------------------------- + + +class TestAtomSelector: + @pytest.mark.parametrize( + "selection, mol, expected", + [ + ("protein and backbone", ALA_CA, True), + ("protein and backbone", ALA_CB, False), + ("protein and sidechain", ALA_CB, True), + ("protein and sidechain", HOH_O, False), + ("water", HOH_O, True), + ("not water", HOH_O, False), + ("resid 1 to 3 and name CA", ALA_CA, True), + ("resid 1 to 3 and name CA", GLY_N, False), + ("all", {"resname": "XYZ"}, True), + ("name C*", ALA_CA, True), + ("name C*", ALA_CB, True), + ("name C*", GLY_N, False), + ("resname A*", ALA_CA, True), + ("resname A*", HOH_O, False), + ("(resname ALA GLY) or (resname HOH and name O)", ALA_CA, True), + ("(resname ALA GLY) or (resname HOH and name O)", HOH_O, True), + ("(resname ALA GLY) or (resname HOH and name O)", NA_ION, False), + ], + ) + def test_selection_result(self, selection, mol, expected): + selector = AtomSelector(selection) + assert selector.eval(mol) == expected + + def test_double_not(self): + selector = AtomSelector("not not protein") + assert selector.eval(ALA_CA) + + def test_chain_selection(self): + mol_a = {"resname": "ALA", "name": "CA", "chain": "A"} + mol_b = {"resname": "ALA", "name": "CA", "chain": "B"} + assert AtomSelector("chain A and protein").eval(mol_a) + assert not AtomSelector("chain B and protein").eval(mol_a) + assert AtomSelector("chain B and protein").eval(mol_b) + + +# --------------------------------------------------------------------------- +# エラー処理 +# --------------------------------------------------------------------------- + + +class TestParseErrors: + @pytest.mark.parametrize( + 
"bad_selection", + [ + "resid 10 to", + "resname and", + "( resid 1", + "name ca and", + ], + ) + def test_invalid_selection_raises(self, bad_selection): + with pytest.raises(ValueError): + AtomSelector(bad_selection) + + def test_parse_selection_returns_error_string_on_failure(self): + result = parse_selection("resid 10 to") + assert isinstance(result, str) diff --git a/tests/test_utils/test_mod_mdp.py b/tests/test_utils/test_mod_mdp.py new file mode 100644 index 0000000..5f3e96c --- /dev/null +++ b/tests/test_utils/test_mod_mdp.py @@ -0,0 +1,124 @@ +""" +mod_mdp のユニットテスト + +mod_mdp() 関数(ファイル I/O)と run()(ディレクトリ走査)をテストする。 +tmp_path fixture により、テスト後に一時ファイルは自動削除される。 +""" + +from src.utils.mod_mdp import mod_mdp, run + + +class TestModMdp: + def test_replace_existing_value(self, tmp_path): + """既存キーの値が書き換わること""" + mdp = tmp_path / "md.mdp" + mdp.write_text("nsteps = 1000\n") + + mod_mdp("nsteps", "2000", mdp, ljust=23) + + content = mdp.read_text() + assert "nsteps" in content + assert "2000" in content + assert "1000" not in content + + def test_add_new_key(self, tmp_path): + """存在しないキーが末尾に追加されること""" + mdp = tmp_path / "md.mdp" + mdp.write_text("nsteps = 1000\n") + + mod_mdp("dt", "0.002", mdp, ljust=23) + + content = mdp.read_text() + assert "dt" in content + assert "0.002" in content + # 元の内容は保持される + assert "nsteps" in content + + def test_comment_lines_preserved(self, tmp_path): + """コメント行(; で始まる)は変更されないこと""" + mdp = tmp_path / "md.mdp" + original = "; this is a comment\nnsteps = 1000\n" + mdp.write_text(original) + + mod_mdp("nsteps", "2000", mdp, ljust=23) + + content = mdp.read_text() + assert "; this is a comment" in content + + def test_ljust_controls_padding(self, tmp_path): + """ljust パラメータが新規キーの整形幅を制御すること""" + mdp = tmp_path / "md.mdp" + mdp.write_text("") + + mod_mdp("dt", "0.002", mdp, ljust=30) + + content = mdp.read_text() + # フォーマット: f"{key.ljust(30)} = {value}" → "dt" + 28空白 + " = 0.002" + assert "dt" + " " * 28 + " = 0.002" in content + + def 
test_multiple_keys(self, tmp_path): + """複数のキーを順番に書き換えられること""" + mdp = tmp_path / "md.mdp" + mdp.write_text("nsteps = 1000\ndt = 0.002\n") + + mod_mdp("nsteps", "5000", mdp, ljust=7) + mod_mdp("dt", "0.001", mdp, ljust=7) + + content = mdp.read_text() + assert "5000" in content + assert "0.001" in content + + +class TestModMdpRun: + def test_run_modifies_all_mdp_files(self, tmp_path): + """run() がディレクトリ内の全 .mdp ファイルを更新すること""" + for name in ["em.mdp", "nvt.mdp", "npt.mdp"]: + (tmp_path / name).write_text("nsteps = 100\n") + + import types + + args = types.SimpleNamespace( + path=str(tmp_path), + target_variable="nsteps", + new_value="50000", + exclude=None, + ljust=23, + ) + run(args) + + for name in ["em.mdp", "nvt.mdp", "npt.mdp"]: + content = (tmp_path / name).read_text() + assert "50000" in content + + def test_run_respects_exclude(self, tmp_path): + """exclude オプションで指定したファイルは変更されないこと""" + (tmp_path / "nvt.mdp").write_text("nsteps = 100\n") + (tmp_path / "npt.mdp").write_text("nsteps = 100\n") + + import types + + args = types.SimpleNamespace( + path=str(tmp_path), + target_variable="nsteps", + new_value="50000", + exclude=["npt"], + ljust=23, + ) + run(args) + + assert "50000" in (tmp_path / "nvt.mdp").read_text() + assert "100" in (tmp_path / "npt.mdp").read_text() # 変更されていない + + def test_run_nonexistent_dir(self, tmp_path): + """存在しないディレクトリを指定した場合でもエラーにならないこと(.mdp ファイルが 0 件)""" + import types + + args = types.SimpleNamespace( + path=str(tmp_path / "nodir"), + target_variable="nsteps", + new_value="1000", + exclude=None, + ljust=23, + ) + # glob が空を返すだけでエラーにならない + run(args) diff --git a/tests/test_utils/test_parse_top.py b/tests/test_utils/test_parse_top.py new file mode 100644 index 0000000..6a38f69 --- /dev/null +++ b/tests/test_utils/test_parse_top.py @@ -0,0 +1,79 @@ +""" +parse_top (GromacsTopologyParser) のユニットテスト + +tests/fixtures/sample.top を使ってトポロジーパーサーをテストする。 +""" + +import pytest + +from src.utils.parse_top import GromacsTopologyParser + + 
+
+@pytest.fixture(scope="module")
+def parser(sample_top_path):
+    return GromacsTopologyParser(str(sample_top_path))
+
+
+class TestGromacsTopologyParser:
+    def test_get_all_moleculetypes(self, parser):
+        """moleculetype 名が正しく取得できること"""
+        moltypes = parser.get_all_moleculetypes()
+        assert "Protein" in moltypes
+        assert "SOL" in moltypes
+
+    def test_moleculetype_order(self, parser):
+        """定義順が保持されること"""
+        moltypes = parser.get_all_moleculetypes()
+        assert moltypes.index("Protein") < moltypes.index("SOL")
+
+    def test_get_atoms_in_protein(self, parser):
+        """Protein の原子リストが取得できること"""
+        atoms = parser.get_atoms_in("Protein")
+        assert len(atoms) > 0
+
+    def test_protein_atom_fields(self, parser):
+        """原子辞書に必須フィールドが含まれること"""
+        atoms = parser.get_atoms_in("Protein")
+        first = atoms[0]
+        assert "atom_type" in first
+        assert "index" in first
+        assert "resid" in first
+        assert "resname" in first
+        assert "name" in first
+
+    def test_protein_residue_names(self, parser):
+        """ALA と GLY の残基が含まれること(fixture の内容に対応)"""
+        atoms = parser.get_atoms_in("Protein")
+        resnames = {a["resname"] for a in atoms}
+        assert "ALA" in resnames
+        assert "GLY" in resnames
+
+    def test_get_atoms_in_sol(self, parser):
+        """SOL の原子リストが取得できること"""
+        atoms = parser.get_atoms_in("SOL")
+        assert len(atoms) == 3  # OW, HW1, HW2
+
+    def test_sol_atom_names(self, parser):
+        """SOL の原子名が正しいこと"""
+        atoms = parser.get_atoms_in("SOL")
+        names = [a["name"] for a in atoms]
+        assert "OW" in names
+        assert "HW1" in names
+        assert "HW2" in names
+
+    def test_get_insert_linenumber(self, parser):
+        """挿入行番号が整数で返ること"""
+        lineno = parser.get_insert_linenumber_in("Protein")
+        assert isinstance(lineno, int)
+        assert lineno > 0
+
+    def test_atom_index_sequential(self, parser):
+        """原子のインデックスが連続していること"""
+        atoms = parser.get_atoms_in("Protein")
+        indices = [a["index"] for a in atoms]
+        assert indices == list(range(1, len(atoms) + 1))
+
+    def test_invalid_moleculetype_raises(self, parser):
+        """存在しない moleculetype 名は KeyError になること"""
+        with 
pytest.raises(KeyError): + parser.get_atoms_in("NONEXISTENT") diff --git a/tests/test_utils/test_partial_tempering.py b/tests/test_utils/test_partial_tempering.py new file mode 100644 index 0000000..d986843 --- /dev/null +++ b/tests/test_utils/test_partial_tempering.py @@ -0,0 +1,82 @@ +""" +utils/partial_tempering のユニットテスト + +sample.top を使って atom_type へのアンダースコア付与をテストする。 +""" + +import types + + +class TestPartialTemperingRun: + def _make_args(self, topology_path, output_path, selection): + return types.SimpleNamespace( + topology=str(topology_path), + selection=selection, + output=str(output_path), + ) + + def test_output_file_created(self, sample_top_path, tmp_path): + """出力ファイルが生成されること""" + from src.utils.partial_tempering import run + + out = tmp_path / "output.top" + run(self._make_args(sample_top_path, out, selection="resname ALA")) + assert out.exists() + + def test_selected_atoms_get_underscore(self, sample_top_path, tmp_path): + """選択された ALA 原子の atom_type に _ が付加されること""" + from src.utils.partial_tempering import run + + out = tmp_path / "output.top" + run(self._make_args(sample_top_path, out, selection="resname ALA")) + + content = out.read_text() + # ALA の atom_type CT が CT_ に変更されているはず + assert "CT_" in content + + def test_unselected_atoms_unchanged(self, sample_top_path, tmp_path): + """選択されていない GLY 原子の atom_type は変更されないこと""" + from src.utils.partial_tempering import run + + out = tmp_path / "output.top" + # ALA のみを選択 + run(self._make_args(sample_top_path, out, selection="resname ALA")) + + content = out.read_text() + # GLY の行(residue 2)は CT のまま残る + lines = content.splitlines() + gly_lines = [ + line + for line in lines + if "GLY" in line and line.strip() and line.strip()[0].isdigit() + ] + for line in gly_lines: + tokens = line.split() + if len(tokens) >= 2: + atom_type = tokens[1] + assert not atom_type.endswith("_"), ( + f"GLY atom_type should not have underscore: {line}" + ) + + def test_original_file_not_modified(self, sample_top_path, tmp_path): + 
"""入力ファイルが変更されないこと""" + from src.utils.partial_tempering import run + + original_content = sample_top_path.read_text() + out = tmp_path / "output.top" + run(self._make_args(sample_top_path, out, selection="resname ALA")) + + assert sample_top_path.read_text() == original_content + + def test_no_match_selection_produces_unchanged_output( + self, sample_top_path, tmp_path + ): + """マッチしない選択ではアンダースコアが付かないこと""" + from src.utils.partial_tempering import run + + out = tmp_path / "output_nomatch.top" + run(self._make_args(sample_top_path, out, selection="resname LIG")) + + output = out.read_text() + # CT_ は含まれないはず + assert "CT_" not in output diff --git a/tests/test_utils/test_rmfile.py b/tests/test_utils/test_rmfile.py new file mode 100644 index 0000000..82564ad --- /dev/null +++ b/tests/test_utils/test_rmfile.py @@ -0,0 +1,79 @@ +""" +rmfile のユニットテスト + +対象パターン(#*#, *cpt, mdout.mdp)のファイルが削除され、 +それ以外のファイルは残ることを確認する。 +""" + +import types + + +from src.utils.rmfile import run + + +class TestRmfile: + def _make_args(self, path): + return types.SimpleNamespace(path=str(path)) + + def test_removes_cpt_files(self, tmp_path): + cpt = tmp_path / "state.cpt" + cpt.touch() + run(self._make_args(tmp_path)) + assert not cpt.exists() + + def test_removes_mdout_mdp(self, tmp_path): + mdout = tmp_path / "mdout.mdp" + mdout.touch() + run(self._make_args(tmp_path)) + assert not mdout.exists() + + def test_removes_backup_files(self, tmp_path): + backup = tmp_path / "#step.trr.1#" + backup.touch() + run(self._make_args(tmp_path)) + assert not backup.exists() + + def test_preserves_other_files(self, tmp_path): + mdp = tmp_path / "md.mdp" + top = tmp_path / "topol.top" + tpr = tmp_path / "topol.tpr" + for f in [mdp, top, tpr]: + f.touch() + + run(self._make_args(tmp_path)) + + assert mdp.exists() + assert top.exists() + assert tpr.exists() + + def test_removes_cpt_recursively(self, tmp_path): + """サブディレクトリ内の .cpt も再帰的に削除されること""" + subdir = tmp_path / "run1" + subdir.mkdir() + cpt = subdir / 
"state.cpt" + cpt.touch() + + run(self._make_args(tmp_path)) + + assert not cpt.exists() + + def test_mixed_files(self, tmp_path): + """削除対象と保持対象が混在していても正しく処理されること""" + targets = [ + tmp_path / "state.cpt", + tmp_path / "mdout.mdp", + tmp_path / "#backup#", + ] + keepfiles = [ + tmp_path / "topol.top", + tmp_path / "md.mdp", + ] + for f in targets + keepfiles: + f.touch() + + run(self._make_args(tmp_path)) + + for f in targets: + assert not f.exists(), f"{f.name} should have been deleted" + for f in keepfiles: + assert f.exists(), f"{f.name} should be preserved" diff --git a/tests/test_utils/test_shell_hook.py b/tests/test_utils/test_shell_hook.py new file mode 100644 index 0000000..b677aee --- /dev/null +++ b/tests/test_utils/test_shell_hook.py @@ -0,0 +1,53 @@ +""" +utils/shell_hook のユニットテスト + +run() が適切なシェルスクリプトのテンプレートを標準出力に出力することをテストする。 +""" + +import types + + +class TestShellHookRun: + def _make_args(self): + return types.SimpleNamespace() + + def test_output_contains_mdtbx_alias(self, capsys): + """mdtbx の alias / function 定義が含まれること""" + from src.utils.shell_hook import run + + run(self._make_args()) + captured = capsys.readouterr() + assert "mdtbx" in captured.out + + def test_output_contains_pymol(self, capsys): + """pymol の設定が含まれること""" + from src.utils.shell_hook import run + + run(self._make_args()) + captured = capsys.readouterr() + assert "pymol" in captured.out + + def test_output_contains_begin_end_markers(self, capsys): + """BEGIN / END マーカーが含まれること""" + from src.utils.shell_hook import run + + run(self._make_args()) + captured = capsys.readouterr() + assert "BEGIN OF MDTBX SHELL HOOK" in captured.out + assert "END OF MDTBX SHELL HOOK" in captured.out + + def test_output_contains_path_export(self, capsys): + """PATH 設定が含まれること""" + from src.utils.shell_hook import run + + run(self._make_args()) + captured = capsys.readouterr() + assert "PATH" in captured.out + + def test_output_is_nonempty(self, capsys): + """出力が空でないこと""" + from src.utils.shell_hook 
import run + + run(self._make_args()) + captured = capsys.readouterr() + assert len(captured.out.strip()) > 0 diff --git a/tests/test_utils/test_show_npy.py b/tests/test_utils/test_show_npy.py new file mode 100644 index 0000000..5cdf1b6 --- /dev/null +++ b/tests/test_utils/test_show_npy.py @@ -0,0 +1,62 @@ +""" +utils/show_npy のユニットテスト + +run() が .npy ファイルの内容と shape を標準出力に表示することをテストする。 +""" + +import types + +import numpy as np + + +class TestShowNpyRun: + def _make_args(self, npy_path): + return types.SimpleNamespace(npy=str(npy_path)) + + def test_prints_array_content(self, tmp_path, capsys): + """配列の内容が標準出力に表示されること""" + from src.utils.show_npy import run + + arr = np.array([1.0, 2.0, 3.0]) + npy = tmp_path / "test.npy" + np.save(str(npy), arr) + + run(self._make_args(npy)) + captured = capsys.readouterr() + assert "1." in captured.out or "1.0" in captured.out + + def test_prints_shape(self, tmp_path, capsys): + """配列の shape が標準出力に表示されること""" + from src.utils.show_npy import run + + arr = np.zeros((3, 4)) + npy = tmp_path / "shape_test.npy" + np.save(str(npy), arr) + + run(self._make_args(npy)) + captured = capsys.readouterr() + assert "(3, 4)" in captured.out + + def test_2d_array(self, tmp_path, capsys): + """2D 配列でも正常に動作すること""" + from src.utils.show_npy import run + + arr = np.arange(6).reshape(2, 3) + npy = tmp_path / "2d.npy" + np.save(str(npy), arr) + + run(self._make_args(npy)) + captured = capsys.readouterr() + assert "(2, 3)" in captured.out + + def test_scalar_array(self, tmp_path, capsys): + """スカラー配列でも正常に動作すること""" + from src.utils.show_npy import run + + arr = np.float64(42.0) + npy = tmp_path / "scalar.npy" + np.save(str(npy), arr) + + run(self._make_args(npy)) + captured = capsys.readouterr() + assert len(captured.out.strip()) > 0