diff --git a/tests/test_base.py b/tests/test_base.py index c27ba3694..13729e7ac 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -142,20 +142,31 @@ def setUpClass(cls): if env_coverage.value < cls.TEST_LEVEL.value: raise unittest.SkipTest(f"Skipped test : Test Coverage {env_coverage.name} < {cls.TEST_LEVEL.name}") - if os.path.exists(cls.get_rbln_local_dir()): - shutil.rmtree(cls.get_rbln_local_dir()) - - cls.model = cls.RBLN_CLASS.from_pretrained( - cls.HF_MODEL_ID, - model_save_dir=cls.get_rbln_local_dir(), - rbln_device=cls.DEVICE, - **cls.RBLN_CLASS_KWARGS, - **cls.HF_CONFIG_KWARGS, - ) + REUSE_ARTIFACTS_PATH = os.environ.get("REUSE_ARTIFACTS_PATH", None) + if REUSE_ARTIFACTS_PATH is None: + if os.path.exists(cls.get_rbln_local_dir()): + shutil.rmtree(cls.get_rbln_local_dir()) + + cls.model = cls.RBLN_CLASS.from_pretrained( + cls.HF_MODEL_ID, + model_save_dir=cls.get_rbln_local_dir(), + rbln_device=cls.DEVICE, + **cls.RBLN_CLASS_KWARGS, + **cls.HF_CONFIG_KWARGS, + ) + else: + if os.path.exists(REUSE_ARTIFACTS_PATH): + compiled_model_path = os.path.join(REUSE_ARTIFACTS_PATH, cls.get_rbln_local_dir()) + if os.path.exists(compiled_model_path): + cls.model = cls.RBLN_CLASS.from_pretrained(compiled_model_path) + else: + raise ValueError(f"Compiled model not found at: {compiled_model_path}") + else: + raise ValueError(f"REUSE_ARTIFACTS_PATH does not exist: {REUSE_ARTIFACTS_PATH}") @classmethod def get_rbln_local_dir(cls): - return os.path.basename(cls.HF_MODEL_ID) + "-local" + return os.path.basename(cls.__name__) + "-artifact" @classmethod def get_hf_auto_class(cls): @@ -173,8 +184,16 @@ def is_diffuser(self): @classmethod def tearDownClass(cls): - if os.path.exists(cls.get_rbln_local_dir()): - shutil.rmtree(cls.get_rbln_local_dir()) + rbln_local_dir = cls.get_rbln_local_dir() + if not os.path.exists(rbln_local_dir): + return + + SAVE_ARTIFACTS_PATH = os.environ.get("SAVE_ARTIFACTS_PATH", None) + if SAVE_ARTIFACTS_PATH is None: + shutil.rmtree(rbln_local_dir) + else: + os.makedirs(SAVE_ARTIFACTS_PATH, exist_ok=True) + shutil.move(rbln_local_dir, os.path.join(SAVE_ARTIFACTS_PATH, rbln_local_dir)) def test_model_save_dir(self): self.assertTrue(os.path.exists(self.get_rbln_local_dir()), "model_save_dir does not work.") @@ -240,10 +259,16 @@ def test_save_load(self): self._inner_test_save_load(tmpdir) def test_model_save_dir_load(self): + REUSE_ARTIFACTS_PATH = os.environ.get("REUSE_ARTIFACTS_PATH", None) + if REUSE_ARTIFACTS_PATH is None: + rbln_local_dir = self.get_rbln_local_dir() + else: + rbln_local_dir = os.path.join(REUSE_ARTIFACTS_PATH, self.get_rbln_local_dir()) + with ContextRblnConfig(create_runtimes=False): # Test model_save_dir _ = self.RBLN_CLASS.from_pretrained( - self.get_rbln_local_dir(), + rbln_local_dir, rbln_create_runtimes=False, **self.HF_CONFIG_KWARGS, )