Skip to content

Commit 13bac0c

Browse files
committed
Codebase release
1 parent 2d0f26e commit 13bac0c

22 files changed

Lines changed: 5951 additions & 5 deletions

.DS_Store

6 KB
Binary file not shown.

README.md

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,56 @@ Our model was rigorously validated on the public CAMUS and EchoNet-Dynamic datas
2929
3030
---
3131

32-
## 📦 Code Availability
33-
34-
> 🚧 Code coming soon...
35-
36-
We are currently in the process of cleaning up and preparing the codebase and model weights for public release. We aim to make them available in this repository to facilitate further research and reproducibility. Please stay tuned for updates!
32+
## ⚙️ Setup & Installation
33+
34+
1. **Clone the repository:**
35+
```bash
36+
git clone https://github.com/abdur75648/Echo-DND.git
37+
cd Echo-DND
38+
```
39+
40+
2. **Install Dependencies:**
41+
```bash
42+
pip install -r requirements.txt
43+
```
44+
45+
3. **Prepare the Dataset:**
46+
- Download the CAMUS and EchoNet-Dynamic datasets.
47+
- Organize them into a root data directory with the following structure:
48+
```
49+
<your_data_root_dir>/
50+
├── CAMUS/
51+
│ ├── patient0001/
52+
│ │ ├── patient0001_4CH_ED.mhd
53+
│ │ ├── patient0001_4CH_ED_gt.mhd
54+
│ │ └── ... (other patient files)
55+
│ └── ... (other patient folders)
56+
└── EchoNet-Dynamic/
57+
├── Train/
58+
│ ├── Frames/
59+
│ │ └── 0X100037609D9A4939_image0001.png
60+
│ └── Masks/
61+
│ └── 0X100037609D9A4939_image0001.png (corresponding mask)
62+
├── Val/
63+
└── ...
64+
```
65+
- The `echo_dnd_dataset.py` script is configured to load data assuming this structure.
3766

3867
---
3968

69+
## 🏃‍♂️ Training & Inference
70+
### Training
71+
To train the Echo-DND model, run the following command:
72+
```bash
73+
python training_echo_dnd.py --data_dir /path/to/your_data_root_dir --batch_size 4 --lr 1e-4 --out_dir ./results/training_run1
74+
```
75+
76+
### Inference
77+
To perform inference on a single image, use the following command:
78+
```bash
79+
python inference_echo_dnd.py --image_path /path/to/your/test_image.png --model_path /path/to/your/pretrained_echodnd_model.pt --out_dir ./results/inference_output
80+
```
81+
4082
## 📄 Citation
4183

4284
If you find this work useful, please consider citing:

guided_diffusion/.DS_Store

6 KB
Binary file not shown.

guided_diffusion/__init__.py

Whitespace-only changes.

guided_diffusion/dist_util.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""
2+
Helpers for distributed training.
3+
"""
4+
5+
import io
6+
import os
7+
import socket
8+
9+
import blobfile as bf
10+
#from mpi4py import MPI
11+
import torch as th
12+
import torch.distributed as dist
13+
14+
# Change this to reflect your cluster layout.
15+
# The GPU for a given rank is (rank % GPUS_PER_NODE).
16+
GPUS_PER_NODE = 8
17+
18+
SETUP_RETRY_COUNT = 3
19+
20+
21+
def setup_dist(args):
22+
"""
23+
Setup a distributed process group.
24+
"""
25+
if dist.is_initialized():
26+
return
27+
if not args.multi_gpu:
28+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_dev
29+
30+
backend = "gloo" if not th.cuda.is_available() else "nccl"
31+
32+
if backend == "gloo":
33+
hostname = "localhost"
34+
else:
35+
hostname = socket.gethostbyname(socket.getfqdn())
36+
os.environ["MASTER_ADDR"] = '127.0.1.1'#comm.bcast(hostname, root=0)
37+
os.environ["RANK"] = '0'#str(comm.rank)
38+
os.environ["WORLD_SIZE"] = '1'#str(comm.size)
39+
40+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
41+
s.bind(("", 0))
42+
s.listen(1)
43+
port = s.getsockname()[1]
44+
s.close()
45+
os.environ["MASTER_PORT"] = str(port)
46+
dist.init_process_group(backend=backend, init_method="env://")
47+
48+
def dev():
49+
"""
50+
Get the device to use for torch.distributed.
51+
"""
52+
if th.cuda.is_available():
53+
return th.device(f"cuda")
54+
return th.device("cpu")
55+
56+
57+
def load_state_dict(path, **kwargs):
58+
"""
59+
Load a PyTorch file without redundant fetches across MPI ranks.
60+
"""
61+
mpigetrank=0
62+
if mpigetrank==0:
63+
with bf.BlobFile(path, "rb") as f:
64+
data = f.read()
65+
else:
66+
data = None
67+
68+
return th.load(io.BytesIO(data), **kwargs)
69+
70+
71+
def sync_params(params):
72+
"""
73+
Synchronize a sequence of Tensors across ranks from rank 0.
74+
"""
75+
for p in params:
76+
with th.no_grad():
77+
dist.broadcast(p, 0)
78+
79+
80+
def _find_free_port():
81+
try:
82+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
83+
s.bind(("", 0))
84+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
85+
return s.getsockname()[1]
86+
finally:
87+
s.close()

0 commit comments

Comments
 (0)