Skip to content

Commit da53057

Browse files
committed
edited some comments
1 parent 85b5e2b commit da53057

2 files changed

Lines changed: 9 additions & 15 deletions

File tree

examples/cat_breeds/models/for_cats.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import torch
21
import torch.nn as nn
3-
import torch.nn.functional as F
42

53
class SmallCNN(nn.Module):
64
def __init__(self):
@@ -17,8 +15,4 @@ def __init__(self):
1715
)
1816

1917
def forward(self, x):
20-
return self.layers(x)
21-
22-
# NOTES
23-
# Must use nn.Flatten instead of x.view()
24-
# Must use pool1 and pool2 rather than one pool.
18+
return self.layers(x)

examples/cat_breeds/to_chai.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
model = torch.load("./cat_breeds/models/pretest.pt")
1111
model.chai_dump("./cat_breeds/models/chai_model", "SmallCNN")
1212

13-
# load_path = "./cat_breeds/data/catbreeds/images"
14-
# for i, item in enumerate(os.listdir(load_path)):
15-
# if "item" in item: # check file name
16-
# img = np.load(f"{load_path}/{item}")
17-
# img = torch.Tensor(img)
18-
# img.chai_save("./cat_breeds/data/catbreeds/chai_images", f"item{i}", verbose=False)
19-
# if i > 20:
20-
# break
13+
load_path = "./cat_breeds/data/catbreeds/images"
14+
for i, item in enumerate(os.listdir(load_path)):
15+
if "item" in item: # check file name
16+
img = np.load(f"{load_path}/{item}")
17+
img = torch.Tensor(img)
18+
img.chai_save("./cat_breeds/data/catbreeds/chai_images", f"item{i}", verbose=False)
19+
if i > 20:
20+
break

0 commit comments

Comments
 (0)