Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,31 @@ def setUpClass(cls):
if env_coverage.value < cls.TEST_LEVEL.value:
raise unittest.SkipTest(f"Skipped test : Test Coverage {env_coverage.name} < {cls.TEST_LEVEL.name}")

if os.path.exists(cls.get_rbln_local_dir()):
shutil.rmtree(cls.get_rbln_local_dir())

cls.model = cls.RBLN_CLASS.from_pretrained(
cls.HF_MODEL_ID,
model_save_dir=cls.get_rbln_local_dir(),
rbln_device=cls.DEVICE,
**cls.RBLN_CLASS_KWARGS,
**cls.HF_CONFIG_KWARGS,
)
REUSE_ARTIFACTS_PATH = os.environ.get("REUSE_ARTIFACTS_PATH", None)
if REUSE_ARTIFACTS_PATH is None:
if os.path.exists(cls.get_rbln_local_dir()):
shutil.rmtree(cls.get_rbln_local_dir())

cls.model = cls.RBLN_CLASS.from_pretrained(
cls.HF_MODEL_ID,
model_save_dir=cls.get_rbln_local_dir(),
rbln_device=cls.DEVICE,
**cls.RBLN_CLASS_KWARGS,
**cls.HF_CONFIG_KWARGS,
)
else:
if os.path.exists(REUSE_ARTIFACTS_PATH):
compiled_model_path = os.path.join(REUSE_ARTIFACTS_PATH, cls.get_rbln_local_dir())
if os.path.exists(compiled_model_path):
cls.model = cls.RBLN_CLASS.from_pretrained(compiled_model_path)
else:
raise ValueError(f"Compiled model not found at: {compiled_model_path}")
else:
raise ValueError(f"REUSE_ARTIFACTS_PATH does not exist: {REUSE_ARTIFACTS_PATH}")

@classmethod
def get_rbln_local_dir(cls):
return os.path.basename(cls.HF_MODEL_ID) + "-local"
return os.path.basename(cls.__name__) + "-artifact"

@classmethod
def get_hf_auto_class(cls):
Expand All @@ -173,8 +184,16 @@ def is_diffuser(self):

@classmethod
def tearDownClass(cls):
if os.path.exists(cls.get_rbln_local_dir()):
shutil.rmtree(cls.get_rbln_local_dir())
rbln_local_dir = cls.get_rbln_local_dir()
if not os.path.exists(rbln_local_dir):
return

SAVE_ARTIFACTS_PATH = os.environ.get("SAVE_ARTIFACTS_PATH", None)
if SAVE_ARTIFACTS_PATH is None:
shutil.rmtree(rbln_local_dir)
else:
os.makedirs(SAVE_ARTIFACTS_PATH, exist_ok=True)
shutil.move(rbln_local_dir, os.path.join(SAVE_ARTIFACTS_PATH, rbln_local_dir))

def test_model_save_dir(self):
self.assertTrue(os.path.exists(self.get_rbln_local_dir()), "model_save_dir does not work.")
Expand Down Expand Up @@ -240,10 +259,16 @@ def test_save_load(self):
self._inner_test_save_load(tmpdir)

def test_model_save_dir_load(self):
REUSE_ARTIFACTS_PATH = os.environ.get("REUSE_ARTIFACTS_PATH", None)
if REUSE_ARTIFACTS_PATH is None:
rbln_local_dir = self.get_rbln_local_dir()
else:
rbln_local_dir = os.path.join(REUSE_ARTIFACTS_PATH, self.get_rbln_local_dir())

with ContextRblnConfig(create_runtimes=False):
# Test model_save_dir
_ = self.RBLN_CLASS.from_pretrained(
self.get_rbln_local_dir(),
rbln_local_dir,
rbln_create_runtimes=False,
**self.HF_CONFIG_KWARGS,
)
Expand Down
Loading