Commit 0bda800

implement lazy loading (#316)
1 parent 25ca5a8 commit 0bda800

17 files changed, +285 -181 lines

convokit/__init__.py

Lines changed: 102 additions & 39 deletions
@@ -1,41 +1,104 @@
 import warnings
+from typing import Any
 
-try:
-    from .model import *
-    from .util import *
-    from .coordination import *
-    from .politenessStrategies import *
-    from .transformer import *
-    from .convokitPipeline import *
-    from .hyperconvo import *
-    from .speakerConvoDiversity import *
-    from .text_processing import *
-    from .phrasing_motifs import *
-    from .prompt_types import *
-    from .classifier.classifier import *
-    from .ranker import *
-    from .forecaster import *
-    from .fighting_words import *
-    from .paired_prediction import *
-    from .bag_of_words import *
-    from .expected_context_framework import *
-    from .surprise import *
-    from .convokitConfig import *
-    from .redirection import *
-    from .pivotal_framework import *
-    from .utterance_simulator import *
-except ModuleNotFoundError as e:
-    # Don't print ModuleNotFoundError messages as they're handled by individual modules
-    if "not currently installed" not in str(e):
-        print(f"An error occurred: {e}")
-        warnings.warn(
-            "If you are using ConvoKit with Google Colab, incorrect versions of some packages (ex. scipy) may be imported while runtime start. To fix the issue, restart the session and run all codes again. Thank you!"
-        )
-except Exception as e:
-    print(f"An error occurred: {e}")
-    warnings.warn(
-        "If you are using ConvoKit with Google Colab, incorrect versions of some packages (ex. scipy) may be imported while runtime start. To fix the issue, restart the session and run all codes again. Thank you!"
-    )
-
-
-# __path__ = __import__('pkgutil').extend_path(__path__, __name__)
+# Core modules - always imported immediately
+from .model import *
+from .util import *
+from .transformer import *
+from .convokitConfig import *
+from .convokitPipeline import *
+
+# Module mapping for lazy loading
+# Each entry maps module_name -> import_path
+_LAZY_MODULES = {
+    "coordination": ".coordination",
+    "politenessStrategies": ".politenessStrategies",
+    "hyperconvo": ".hyperconvo",
+    "speakerConvoDiversity": ".speakerConvoDiversity",
+    "text_processing": ".text_processing",
+    "phrasing_motifs": ".phrasing_motifs",
+    "prompt_types": ".prompt_types",
+    "classifier": ".classifier",
+    "ranker": ".ranker",
+    "forecaster": ".forecaster",
+    "fighting_words": ".fighting_words",
+    "paired_prediction": ".paired_prediction",
+    "bag_of_words": ".bag_of_words",
+    "expected_context_framework": ".expected_context_framework",
+    "surprise": ".surprise",
+    "redirection": ".redirection",
+    "pivotal_framework": ".pivotal_framework",
+    "utterance_simulator": ".utterance_simulator",
+    "utterance_likelihood": ".utterance_likelihood",
+    "speaker_convo_helpers": ".speaker_convo_helpers",
+    "politeness_collections": ".politeness_collections",
+}
+
+# Cache for loaded modules
+_loaded_modules = {}
+
+
+def _lazy_import(module_name: str) -> Any:
+    """Import a module lazily and cache the result."""
+    if module_name in _loaded_modules:
+        return _loaded_modules[module_name]
+
+    if module_name not in _LAZY_MODULES:
+        raise AttributeError(f"module '{__name__}' has no attribute '{module_name}'")
+
+    import_path = _LAZY_MODULES[module_name]
+
+    try:
+        import importlib
+
+        module = importlib.import_module(import_path, package=__name__)
+        _loaded_modules[module_name] = module
+
+        globals_dict = globals()
+        if hasattr(module, "__all__"):
+            for name in module.__all__:
+                if hasattr(module, name):
+                    globals_dict[name] = getattr(module, name)
+        else:
+            for name in dir(module):
+                if not name.startswith("_"):
+                    globals_dict[name] = getattr(module, name)
+
+        return module
+
+    except Exception as e:
+        # Simply re-raise whatever the module throws
+        # Let each module handle its own error messaging
+        raise
+
+
+def __getattr__(name: str) -> Any:
+    """Handle attribute access for lazy-loaded modules."""
+    # Check if it's a module we can lazy load
+    if name in _LAZY_MODULES:
+        return _lazy_import(name)
+
+    # Check if it's an exported symbol from a lazy module
+    # We need to check each module to see if it exports this symbol
+    for module_name in _LAZY_MODULES:
+        if module_name not in _loaded_modules:
+            # Try to import the module to see if it has the requested attribute
+            try:
+                import importlib
+
+                import_path = _LAZY_MODULES[module_name]
+                module = importlib.import_module(import_path, package=__name__)
+
+                # Check if this module has the requested attribute
+                if hasattr(module, name):
+                    # Import the full module (which will add all symbols to globals)
+                    _lazy_import(module_name)
+                    # Return the requested attribute
+                    return getattr(module, name)
+
+            except Exception:
+                # If module fails to import, just skip it and try next module
+                # The module's own error handling will take care of proper error messages
+                continue
+
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
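
For context on how this behaves from the user's side: the new __init__.py relies on a module-level __getattr__ (PEP 562), so the heavy submodules listed in _LAZY_MODULES are imported only the first time one of their names is requested and are then cached in _loaded_modules. A rough usage sketch (illustrative only; it assumes FightingWords is exported by the fighting_words module, as in current ConvoKit):

import convokit  # fast path: only the core modules (model, util, transformer, ...) load here

# Attribute access triggers convokit.__getattr__("coordination"), which imports
# convokit.coordination on demand and caches it in _loaded_modules.
coord = convokit.coordination

# A symbol exported by a lazy module also resolves: __getattr__ probes the registered
# modules until one provides the requested name, then loads that module fully.
FightingWords = convokit.FightingWords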

convokit/forecaster/TransformerDecoderModel.py

Lines changed: 7 additions & 17 deletions
@@ -1,20 +1,10 @@
-try:
-    import unsloth
-    from unsloth import FastLanguageModel, is_bfloat16_supported
-    from unsloth.chat_templates import get_chat_template
-    import torch
-    import torch.nn.functional as F
-    from trl import SFTTrainer, SFTConfig
-    from datasets import Dataset
-
-    UNSLOTH_AVAILABLE = True
-except (ModuleNotFoundError, ImportError) as e:
-    if "Unsloth GPU requirement not met" in str(e):
-        raise ImportError("Unsloth GPU requirement not met") from e
-    else:
-        raise ModuleNotFoundError(
-            "unsloth, torch, trl, or datasets is not currently installed. Run 'pip install convokit[llm]' if you would like to use the TransformerDecoderModel."
-        ) from e
+import unsloth
+from unsloth import FastLanguageModel, is_bfloat16_supported
+from unsloth.chat_templates import get_chat_template
+import torch
+import torch.nn.functional as F
+from trl import SFTTrainer, SFTConfig
+from datasets import Dataset
 
 import json
 import os

convokit/forecaster/TransformerEncoderModel.py

Lines changed: 10 additions & 17 deletions
@@ -1,20 +1,13 @@
-try:
-    import torch
-    import torch.nn.functional as F
-    from datasets import Dataset, DatasetDict
-    from transformers import (
-        AutoConfig,
-        AutoModelForSequenceClassification,
-        AutoTokenizer,
-        TrainingArguments,
-        Trainer,
-    )
-
-    TRANSFORMERS_AVAILABLE = True
-except (ModuleNotFoundError, ImportError) as e:
-    raise ModuleNotFoundError(
-        "torch, transformers, or datasets is not currently installed. Run 'pip install convokit[llm]' if you would like to use the TransformerEncoderModel."
-    ) from e
+import torch
+import torch.nn.functional as F
+from datasets import Dataset, DatasetDict
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    TrainingArguments,
+    Trainer,
+)
 
 import os
 import pandas as pd

convokit/forecaster/__init__.py

Lines changed: 21 additions & 10 deletions
@@ -11,25 +11,36 @@
 # Import Transformer models with proper error handling
 try:
     from .TransformerDecoderModel import *
-except ImportError as e:
+except (ImportError, ModuleNotFoundError) as e:
     if "Unsloth GPU requirement not met" in str(e):
-        print(
+        raise ImportError(
             "Error from Unsloth: NotImplementedError: Unsloth currently only works on NVIDIA GPUs and Intel GPUs."
-        )
-    elif "not currently installed" in str(e):
-        print(
+        ) from e
+    elif (
+        "not currently installed" in str(e)
+        or "torch" in str(e)
+        or "unsloth" in str(e)
+        or "trl" in str(e)
+        or "datasets" in str(e)
+    ):
+        raise ImportError(
             "TransformerDecoderModel requires ML dependencies. Run 'pip install convokit[llm]' to install them."
-        )
+        ) from e
     else:
         raise
 
 try:
     from .TransformerEncoderModel import *
-except ImportError as e:
-    if "not currently installed" in str(e):
-        print(
+except (ImportError, ModuleNotFoundError) as e:
+    if (
+        "not currently installed" in str(e)
+        or "torch" in str(e)
+        or "transformers" in str(e)
+        or "datasets" in str(e)
+    ):
+        raise ImportError(
             "TransformerEncoderModel requires ML dependencies. Run 'pip install convokit[llm]' to install them."
-        )
+        ) from e
     else:
         raise
 
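
The pattern applied in this and the following wrapper files is to translate a missing optional dependency into a single chained ImportError that carries an install hint, instead of printing and continuing. A minimal standalone sketch of the same idea (not code from this commit; torch stands in for any of the optional ML dependencies):

try:
    import torch  # optional heavy dependency
except (ImportError, ModuleNotFoundError) as e:
    # Chain the original error via "from e" so the root cause (e.g. "No module
    # named 'torch'") stays visible in the traceback alongside the install hint.
    raise ImportError(
        "This feature requires ML dependencies. Run 'pip install convokit[llm]' to install them."
    ) from e

Callers can then inspect the exception's __cause__ to see which package was actually missing.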

convokit/pivotal_framework/__init__.py

Lines changed: 11 additions & 6 deletions
@@ -1,13 +1,18 @@
 try:
     from .pivotal import *
-except ImportError as e:
+except (ImportError, ModuleNotFoundError) as e:
     if "Unsloth GPU requirement not met" in str(e):
-        print(
+        raise ImportError(
             "Error from Unsloth: NotImplementedError: Unsloth currently only works on NVIDIA GPUs and Intel GPUs."
-        )
-    elif "not currently installed" in str(e):
-        print(
+        ) from e
+    elif (
+        "not currently installed" in str(e)
+        or "torch" in str(e)
+        or "unsloth" in str(e)
+        or "transformers" in str(e)
+    ):
+        raise ImportError(
             "Pivotal framework requires ML dependencies. Run 'pip install convokit[llm]' to install them."
-        )
+        ) from e
     else:
         raise

convokit/pivotal_framework/pivotal.py

Lines changed: 1 addition & 4 deletions
@@ -8,10 +8,7 @@
 from convokit.forecaster.forecasterModel import ForecasterModel
 from convokit.forecaster.forecaster import Forecaster
 
-try:
-    from convokit.utterance_simulator.utteranceSimulatorModel import UtteranceSimulatorModel
-except NotImplementedError as e:
-    raise ImportError("Unsloth GPU requirement not met") from e
+from convokit.utterance_simulator.utteranceSimulatorModel import UtteranceSimulatorModel
 from convokit.utterance_simulator.utteranceSimulator import UtteranceSimulator
 from .util import ContextTuple, DEFAULT_LABELER
 

convokit/redirection/__init__.py

Lines changed: 12 additions & 4 deletions
@@ -1,11 +1,19 @@
-from .redirection import *
+try:
+    from .redirection import *
+except (ImportError, ModuleNotFoundError) as e:
+    if "torch" in str(e) or "not currently installed" in str(e):
+        raise ImportError(
+            "Redirection module requires ML dependencies. Run 'pip install convokit[llm]' to install them."
+        ) from e
+    else:
+        raise
 
 try:
     from .likelihoodModel import *
-except ImportError as e:
+except (ImportError, ModuleNotFoundError) as e:
     if "not currently installed" in str(e):
-        print(
+        raise ImportError(
             "LikelihoodModel requires ML dependencies. Run 'pip install convokit[llm]' to install them."
-        )
+        ) from e
     else:
         raise

convokit/redirection/config.py

Lines changed: 3 additions & 10 deletions
@@ -1,13 +1,6 @@
-try:
-    from peft import LoraConfig
-    from transformers import BitsAndBytesConfig
-    import torch
-
-    REDIRECTION_ML_AVAILABLE = True
-except (ModuleNotFoundError, ImportError) as e:
-    raise ModuleNotFoundError(
-        "peft, transformers, or torch is not currently installed. Run 'pip install convokit[llm]' if you would like to use the redirection module."
-    ) from e
+from peft import LoraConfig
+from transformers import BitsAndBytesConfig
+import torch
 
 DEFAULT_BNB_CONFIG = BitsAndBytesConfig(
     load_in_4bit=True,

convokit/redirection/gemmaLikelihoodModel.py

Lines changed: 10 additions & 17 deletions
@@ -1,20 +1,13 @@
-try:
-    import torch
-    from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
-    from transformers import (
-        AutoTokenizer,
-        AutoModelForCausalLM,
-        BitsAndBytesConfig,
-        DataCollatorForLanguageModeling,
-        TrainingArguments,
-    )
-    from trl import SFTTrainer
-
-    GEMMA_ML_AVAILABLE = True
-except (ModuleNotFoundError, ImportError) as e:
-    raise ModuleNotFoundError(
-        "torch, peft, transformers, or trl is not currently installed. Run 'pip install convokit[llm]' if you would like to use the GemmaLikelihoodModel."
-    ) from e
+import torch
+from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    BitsAndBytesConfig,
+    DataCollatorForLanguageModeling,
+    TrainingArguments,
+)
+from trl import SFTTrainer
 
 from .likelihoodModel import LikelihoodModel
 from .config import DEFAULT_TRAIN_CONFIG, DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG

convokit/redirection/preprocessing.py

Lines changed: 1 addition & 8 deletions
@@ -1,11 +1,4 @@
-try:
-    from datasets import Dataset
-
-    DATASETS_AVAILABLE = True
-except (ModuleNotFoundError, ImportError) as e:
-    raise ModuleNotFoundError(
-        "datasets is not currently installed. Run 'pip install convokit[llm]' if you would like to use the redirection preprocessing functionality."
-    ) from e
+from datasets import Dataset
 
 
 def default_speaker_prefixes(roles):
