Commit 3558104

Merge pull request #12 from NVIDIA/bbonev/v0.1.1: bugfixes addressing issues with imports

2 parents 974661f + 80ed4b4

6 files changed: +9, -5 lines


.github/workflows/tests.yml
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip setuptools wheel
-          python -m pip install tqdm numpy parameterized xarray xskillscore timm jsbeautifier pynvml h5py wandb ruamel.yaml moviepy tensorly tensorly-torch
+          python -m pip install tqdm numpy numba parameterized xarray xskillscore timm jsbeautifier pynvml h5py wandb ruamel.yaml moviepy tensorly tensorly-torch
           python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
           python -m pip install torch_harmonics
           python -m pip install git+https://github.com/NVIDIA/modulus.git

makani/inference.py
Lines changed: 2 additions & 0 deletions

@@ -53,6 +53,7 @@
 parser.add_argument("--epsilon_factor", default=0, type=float)
 parser.add_argument("--split_data_channels", action="store_true")
 parser.add_argument("--mode", default="score", type=str, choices=["score", "ensemble"], help="Select inference mode")
+parser.add_argument("--enable_odirect", action="store_true")

 # checkpoint format
 parser.add_argument("--checkpoint_format", default="legacy", choices=["legacy", "flexible"], type=str, help="Format in which to load checkpoints.")
@@ -124,6 +125,7 @@
 params["amp_mode"] = args.amp_mode
 params["jit_mode"] = args.jit_mode
 params["cuda_graph_mode"] = args.cuda_graph_mode
+params["enable_odirect"] = args.enable_odirect
 params["enable_benchy"] = args.enable_benchy
 params["disable_ddp"] = args.disable_ddp
 params["enable_nhwc"] = args.enable_nhwc

makani/utils/comm.py
Lines changed: 1 addition & 1 deletion

@@ -93,7 +93,7 @@ def is_distributed(name: str):
     return False


-# initialization routine
+# initialization routine
 def init(model_parallel_sizes=[1, 1, 1, 1],
          model_parallel_names=["h", "w", "fin", "fout"],
          verbose=False):

(The comment text is identical on both sides; the change is whitespace-only.)

makani/utils/dataloader.py
Lines changed: 3 additions & 3 deletions

@@ -72,7 +72,7 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False):
     from makani.utils.dataloaders.data_loader_multifiles import MultifilesDataset as MultifilesDataset2D
     from torch.utils.data.distributed import DistributedSampler

-    # multifiles dataset
+    # multifiles
     dataset = MultifilesDataset2D(params, files_pattern, train)

     sampler = DistributedSampler(dataset, shuffle=train, num_replicas=params.data_num_shards, rank=params.data_shard_id) if (params.data_num_shards > 1) else None
@@ -81,8 +81,8 @@ def get_dataloader(params, files_pattern, device, train=True, final_eval=False):
         dataset,
         batch_size=int(params.batch_size),
         num_workers=params.num_data_workers,
-        shuffle=False, # (sampler is None),
-        sampler=sampler if train else None,
+        shuffle=(sampler is None) and train,
+        sampler=sampler,
         drop_last=True,
         pin_memory=torch.cuda.is_available(),
     )
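Background for the dataloader change above: torch.utils.data.DataLoader rejects shuffle=True when an explicit sampler is passed, so shuffling must be requested through exactly one of the two. The old arguments hard-coded shuffle=False, so single-process training batches were drawn in file order; the new expression shuffles locally when no sampler is used, while the DistributedSampler (constructed with shuffle=train) keeps owning the ordering in the sharded case. A minimal sketch of that pattern follows; the toy dataset and variable names are assumptions, not makani code.

# Stand-in dataset illustrating the shuffle/sampler interaction.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(16, dtype=torch.float32))

train = True
sampler = None  # a DistributedSampler would be supplied when the data is sharded across ranks

loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=(sampler is None) and train,  # shuffle locally only when no sampler owns the ordering
    sampler=sampler,                      # pass the sampler through whenever one exists
    drop_last=True,
)

for (batch,) in loader:
    print(batch)

# Passing both would fail:
# DataLoader(dataset, shuffle=True, sampler=some_sampler)
# raises ValueError: sampler option is mutually exclusive with shuffle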

makani/utils/visualize.py
Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import io
 import numpy as np
 import concurrent.futures as cf

pyproject.toml
Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ dynamic = ["version"]
 dependencies = [
     "torch>=2.0.0",
     "numpy>=1.22.4,<1.25",
+    "numba>=0.50.0",
     "nvidia_dali_cuda110>=1.16.0",
     "nvidia-modulus>=0.5.0a0",
     "torch-harmonics>=0.6.5",

0 commit comments
