Skip to content

Fix: long import times #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import time
import logging
import sys
import tempfile
from .utils import Version, is_main_process
import triton
from .peft_utils import get_lora_layer_modules
Expand Down Expand Up @@ -63,7 +64,7 @@ def filter(self, x): return not (self.text in x.getMessage())
global UNSLOTH_COMPILE_LOCATION
global UNSLOTH_CREATED_FUNCTIONS
COMBINED_UNSLOTH_NAME = "unsloth_compiled_module"
UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache"
UNSLOTH_COMPILE_LOCATION = os.path.join(tempfile.gettempdir(), "unsloth_compiled_cache")
UNSLOTH_CREATED_FUNCTIONS = []


Expand Down Expand Up @@ -234,9 +235,6 @@ def create_new_function(
new_source = imports + "\n\n" + new_source
new_source = prepend + new_source + append

# Fix super() Not necessary anymore!
# new_source = new_source.replace("super()", "super(type(self), self)")

# Check location
if is_main_process():
if not os.path.exists(UNSLOTH_COMPILE_LOCATION):
Expand All @@ -259,32 +257,44 @@ def create_new_function(
if overwrite or not os.path.isfile(function_location):
while not os.path.isfile(function_location): continue
pass

# Try loading new module


# First try adding to sys.path and using import_module
new_module = None
while True:
old_path = None

try:
# Add directory to sys.path temporarily if it's not already there
if UNSLOTH_COMPILE_LOCATION not in sys.path:
old_path = list(sys.path)
sys.path.insert(0, UNSLOTH_COMPILE_LOCATION)

# Try standard import
new_module = importlib.import_module(name)
except Exception as e:
print(f"Standard import failed for {name}: {e}")

# Fallback to direct module loading
try:
new_module = importlib.import_module(UNSLOTH_COMPILE_LOCATION + "." + name)
break
except:
# Instead use sys modules for dynamic loading
module_name = f"unsloth_cache_{name}"
file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py"
spec = importlib.util.spec_from_file_location(module_name, file_location)
new_module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = new_module
spec.loader.exec_module(new_module)

time.sleep(0.01)
pass
pass
except Exception as e:
print(f"Direct module loading failed for {name}: {e}")
finally:
# Restore original sys.path if we modified it
if old_path is not None:
sys.path = old_path

if new_module is None:
raise ImportError(f'Unsloth: Cannot import {UNSLOTH_COMPILE_LOCATION + "." + name}')
raise ImportError(f'Unsloth: Cannot import {name} from {UNSLOTH_COMPILE_LOCATION}')

# Must save to global state or else temp file closes
UNSLOTH_CREATED_FUNCTIONS.append(location)
return new_module
pass


def create_standalone_class(
Expand Down