Commit d8a96d9

Added more detailed logging.
1 parent: cb70e5f

7 files changed: +40 -14 lines changed

src/fmcore/algorithm/huggingface/transformers.py (+11 -8)
@@ -1,4 +1,5 @@
 import gc
+import os
 import time
 from abc import ABC
 from collections import OrderedDict
@@ -35,6 +36,8 @@
 from fmcore.constants import MLType
 
 with optional_dependency("torch", "sentencepiece", "transformers", "tokenizers", "huggingface_hub"):
+    os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"
+
     import huggingface_hub
     import torch
     from torch import Tensor
@@ -70,14 +73,6 @@
     )
 
     from fmcore.framework import Dataset
-    from fmcore.framework.dl.torch import (
-        Loss,
-        Optimizer,
-        PyTorch,
-        PyTorchBaseModel,
-        PyTorchClassifierMixin,
-        PyTorchMultiLabelClassifierMixin,
-    )
     from fmcore.framework._task.text_generation import (
         GENERATED_TEXTS_COL,
         GenerationOutputScoresFormat,
@@ -86,6 +81,14 @@
         TextGenerationParams,
         TextGenerationParamsMapper,
     )
+    from fmcore.framework.dl.torch import (
+        Loss,
+        Optimizer,
+        PyTorch,
+        PyTorchBaseModel,
+        PyTorchClassifierMixin,
+        PyTorchMultiLabelClassifierMixin,
+    )
 
     def mapping_to_auto_model_classes(mapping_names: Union[List, Dict, OrderedDict]) -> Dict[str, str]:
         if isinstance(mapping_names, (dict, OrderedDict)):
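
Ordering note (why the env var is set where it is): huggingface_hub snapshots HF_HUB_DOWNLOAD_TIMEOUT into its constants module at import time, so the assignment has to run before the `import huggingface_hub` line that follows it. A minimal standalone sketch of the same idea:

import os

## Must run before huggingface_hub is imported; the library reads this
## environment variable once, on import. The value is in seconds.
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "300"

import huggingface_hub  ## noqa: E402  -- deliberately placed after os.environ

## Hub calls such as hf_hub_download() and snapshot_download() now tolerate
## up to 300s per request instead of the library's much shorter default.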

src/fmcore/algorithm/vllm.py (+14 -1)
@@ -29,6 +29,7 @@ class Hyperparameters(GenerativeLM.Hyperparameters):
         gpu_memory_utilization: confloat(gt=0.0, le=1.0) = 0.95
         max_model_len: conint(ge=1)
         generation_params: Union[TextGenerationParams, Dict, str]
+        api_key: Optional[str] = None
 
         @model_validator(mode="before")
         @classmethod
@@ -46,14 +47,25 @@ def set_params(cls, params: Dict) -> Dict:
                 params,
                 param="max_model_len",
                 alias=[
+                    "max_length",
                     "max_len",
-                    "max_model_len",
                     "max_sequence_length",
                     "max_sequence_len",
                     "max_input_length",
                     "max_input_len",
+                    "max_model_length",
+                    "max_model_len",
+                ],
+            )
+            set_param_from_alias(
+                params,
+                param="api_key",
+                alias=[
+                    "token",
+                    "api_token",
                 ],
             )
+
             params["generation_params"] = TextGenerationParamsMapper.of(
                 params["generation_params"]
             ).initialize()
@@ -103,6 +115,7 @@ def predict_step(self, batch: Prompts, **kwargs) -> Dict:
         outputs = self.llm.generate(
             prompts,
             sampling_params=sampling_params,
+            use_tqdm=False,
         )
 
         result = {GENERATED_TEXTS_COL: [output.outputs[0].text for output in outputs]}
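
Two notes on this file. `use_tqdm=False` silences vLLM's per-batch progress bar, which otherwise clutters non-interactive logs. And `set_param_from_alias` canonicalizes user-supplied hyperparameter names before validation; a rough sketch of the assumed behavior (the real helper lives in this repo's utilities and may differ in detail):

from typing import Dict, List

def set_param_from_alias(params: Dict, param: str, alias: List[str]) -> None:
    ## Assumed behavior: move the first alias found in `params` over to the
    ## canonical name, so downstream code only ever sees `param`.
    if param not in params:
        for alias_name in alias:
            if alias_name in params:
                params[param] = params.pop(alias_name)
                break

hp = {"max_length": 4096, "token": "<hf-token>"}
set_param_from_alias(hp, param="max_model_len", alias=["max_length", "max_len"])
set_param_from_alias(hp, param="api_key", alias=["token", "api_token"])
assert hp == {"max_model_len": 4096, "api_key": "<hf-token>"}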

src/fmcore/framework/_algorithm.py (+1)
@@ -230,6 +230,7 @@ def create_hyperparams(cls, hyperparams: Optional[Dict] = None) -> Hyperparamete
     @classmethod
     def convert_params(cls, params: Dict) -> Dict:
         ## Convert and validate parameters for the algorithm
+        # print(f'params for {cls.class_name}=\n{params}')
         cls.set_default_param_values(params)
         ## This allows us to create a new Algorithm instance without specifying `hyperparams`.
         ## If it is specified, we will pick cls.Hyperparameters, which can be overridden by the subclass.

src/fmcore/framework/_chain/Chain.py (+2 -2)
@@ -291,8 +291,8 @@ def run(
         background: bool = False,
         tracker: Optional[Union[Tracker, Dict, str]] = None,
         notifier: Optional[Union[Notifier, Dict, str]] = None,
-        store_step_inputs: bool = False,
-        store_step_outputs: bool = False,
+        store_step_inputs: bool = True,
+        store_step_outputs: bool = True,
         after: Optional[ChainExecution] = None,
         after_wait: conint(ge=0) = 15,
         step_wait: confloat(ge=0.0) = 0.0,
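
Flipping these defaults is a behavior change for every caller of `Chain.run`, not a logging tweak: executions now retain each step's inputs and outputs unless explicitly disabled. A hypothetical call site (only the two keyword arguments come from the diff):

execution = chain.run(
    inputs,  ## placeholder for whatever the chain consumes
    store_step_inputs=False,   ## opt back in to the old, leaner default
    store_step_outputs=False,
)

Retaining intermediate results makes failed chains easier to inspect, at the cost of holding every step's data in memory for the lifetime of the execution.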

src/fmcore/framework/_dataset.py (+4)
@@ -85,6 +85,10 @@ def concat(
     @model_validator(mode="before")
     @classmethod
     def _set_dataset_params(cls, params: Dict) -> Dict:
+        if "data_schema" not in params:
+            raise ValueError(
+                f"Cannot create instance of class '{cls.class_name}' without passing `data_schema` parameter."
+            )
         data_schema: Union[Schema, MLTypeSchema] = params["data_schema"]
         if isinstance(data_schema, dict):
             ## We need to infer the schema:
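
Previously a missing `data_schema` surfaced one line later as a bare KeyError from `params["data_schema"]`; the guard turns it into a readable error. The same check is added to `_set_predictions_params` in `_predictions.py` below. A sketch of the new failure mode (`Dataset.of` and its arguments are illustrative assumptions, not from the commit):

from fmcore.framework import Dataset

try:
    ## Hypothetical construction with no data_schema supplied:
    Dataset.of({"name": "reviews", "data": [{"text": "hello"}]})
except ValueError as e:
    print(e)  ## Cannot create instance of class 'Dataset' without passing `data_schema` parameter.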

src/fmcore/framework/_predictions.py (+6 -1)
@@ -36,8 +36,8 @@
     MLTypeSchema,
     TaskOrStr,
 )
-from fmcore.framework._task_mixins import InputOutputDataMixin, SchemaValidationError
 from fmcore.framework._dataset import Dataset
+from fmcore.framework._task_mixins import InputOutputDataMixin, SchemaValidationError
 
 Predictions = "Predictions"
 Visualization = "Visualization"
@@ -69,6 +69,11 @@ def _pre_registration_hook(cls):
     @model_validator(mode="before")
     @classmethod
     def _set_predictions_params(cls, params: Dict) -> Dict:
+        if "data_schema" not in params:
+            raise ValueError(
+                f"Cannot create instance of class '{cls.class_name}' without passing `data_schema` parameter."
+            )
+
         params["data_schema"]: Schema = Schema.of(params["data_schema"], schema_template=cls.schema_template)
         # data_schema: Union[Schema, MLTypeSchema] = params['data_schema']
         # if isinstance(data_schema, dict):

src/fmcore/framework/_task/text_generation.py (+2 -2)
@@ -752,12 +752,12 @@ def _create_predictions(self, batch: Prompts, predictions: Dict, **kwargs) -> Ne
 
 
 class LanguageModelTaskMixin(Algorithm, ABC):
-    lm: Optional[Union[GenerativeLM, Any]] = None
+    lm: Optional[Any] = None
     icl_dataset: Optional[Dataset] = None
     icl_sampler: Optional[ICLSampler] = None  ## Will be not-None when icl_dataset is not-None.
 
     class Hyperparameters(Algorithm.Hyperparameters):
-        lm: Optional[Dict]  ## Params for llm
+        lm: Optional[Dict] = None  ## Params for llm
         batch_size: Optional[conint(ge=1)] = 1  ## By default, predict 1 row at a time.
         prompt_template: constr(min_length=1)
         icl_template: Optional[constr(min_length=1)] = None
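
The second `lm` change is the substantive one: under Pydantic v2, an Optional[...] annotation without a default is a required field that merely accepts None, so `lm: Optional[Dict]` forced every caller to pass it. The first change is equivalent in practice, since `Union[GenerativeLM, Any]` already validates as `Any`. A self-contained illustration (plain Pydantic, not fmcore code):

from typing import Dict, Optional

from pydantic import BaseModel, ValidationError

class Before(BaseModel):
    lm: Optional[Dict]  ## Pydantic v2: required -- may be None, but must be passed

class After(BaseModel):
    lm: Optional[Dict] = None  ## genuinely optional

try:
    Before()
except ValidationError as e:
    print(e.errors()[0]["type"])  ## 'missing'

After()  ## constructs fine without `lm`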
