Skip to content

Commit 1625986

Browse files
committed
added change to fix device num checkpoint
1 parent 0a452a0 commit 1625986

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

kan/MultKAN.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,9 @@ def saveckpt(self, path='model'):
534534
round = model.round,
535535
device = str(model.device)
536536
)
537+
538+
if dic["device"].isdigit():
539+
dic["device"] = int(model.device)
537540

538541
for i in range (model.depth):
539542
dic[f'symbolic.funs_name.{i}'] = model.symbolic_fun[i].funs_name

0 commit comments

Comments
 (0)