This repo contains an experimentation workflow for training multi-label emotion classifiers on the Ekman emotions dataset (`thethinkmachine/ekman-emotions`). The pipeline fine-tunes a transformer encoder either with full-parameter updates or with Parameter-Efficient Fine-Tuning (LoRA) adapters. Experiments can be tracked with Weights & Biases, evaluation metrics are computed via sklearn, and trained artefacts can be pushed directly to the Hugging Face Hub.
Key components:
- `train.py` – orchestrates configuration loading, data preparation, model initialisation, training, evaluation, and hub upload.
- `config.yaml` – single source of truth for model, LoRA, training, data, logging, and hub settings.
- `utils.py` – helper utilities for run naming, hub identifier generation, config sanity checks, and dataset re-splitting.
- `eval.py` – defines `benchmark`, the metric callback used by the trainer.
```text
ekman-emotions/
├── config.yaml         # Main experiment configuration
├── train.py            # Training / evaluation entry point
├── eval.py             # Metric computation helper
├── utils.py            # Run-name, hub-id, sanity checks, dataset utilities
├── requirements.txt    # Python dependencies
├── data/               # Optional scripts + prepared datasets
├── checkpoints/, logs/ # Trainer outputs (created at runtime)
└── notebooks/          # Exploratory work and reporting
```
All experiments are driven entirely by `config.yaml`. Edit values there to change behaviour; no direct code modifications are required.
`wandb` settings:

- `entity`, `project`: Identifiers for logging to Weights & Biases. Credentials must be available in the environment (e.g. `WANDB_API_KEY`).
Model settings:

- `base_checkpoint`: Hugging Face model ID used as the starting point. `train.py` loads its tokenizer and sequence classification head (with `ignore_mismatched_sizes=True` to tolerate label-count changes).
`data` settings:

- `dataset_name`: Hugging Face dataset ID used by `datasets.load_dataset`.
- `labels`: Ordered emotion list for multi-label classification.
- Optional keys mirroring the legacy `dataset` block can also live here:
  - `resplit`: When `true`, `train.py` will call `utils.resplit_dataset`.
  - `custom_split_ratio`: Colon-delimited ratios (e.g. `"0.8:0.1:0.1"`).
  - `shuffle_before_resplit`: Shuffle before re-splitting.
  - `random_seed`: Controls deterministic shuffles.
`lora` settings:

- `use_lora`: Toggles adapter-based fine-tuning.
- `target_modules`: List of module names to receive LoRA adapters. `utils.validate_config_sanity` currently enforces this list even if `use_lora=False`; keep it populated or adjust the sanity check.
- `r`, `lora_alpha`, `lora_dropout`, `bias`: Fed into `peft.LoraConfig` when adapters are enabled. The helper functions also bake `r` and `alpha` into run names and hub IDs for traceability.
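As a rough illustration, these keys map onto `peft.LoraConfig` along the following lines (the values shown are placeholders, not the repo defaults, and the exact wiring in `train.py` may differ):

```python
from peft import LoraConfig, TaskType, get_peft_model

# Hypothetical `lora` section as parsed from config.yaml (values are placeholders).
lora_cfg = {
    "use_lora": True,
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "bias": "none",
    "target_modules": ["query", "value"],
}

if lora_cfg["use_lora"]:
    peft_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,               # sequence classification task
        r=lora_cfg["r"],
        lora_alpha=lora_cfg["lora_alpha"],
        lora_dropout=lora_cfg["lora_dropout"],
        bias=lora_cfg["bias"],
        target_modules=lora_cfg["target_modules"],
    )
    model = get_peft_model(model, peft_config)    # `model` is the loaded base model
    model.print_trainable_parameters()            # confirm how few weights are trainable
```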
`training` settings are mapped directly into `transformers.TrainingArguments` (see the call in `train.py`). Highlights:
- `num_train_epochs`, `learning_rate`, `weight_decay`, `adam_epsilon`, `max_grad_norm`: Core optimisation knobs.
- `warmup_ratio`, `warmup_steps`, `lr_scheduler_type`: Scheduler behaviour.
- `auto_find_batch_size`: When `true`, the trainer will try the provided `per_device_*_batch_size` (64 by default) and automatically halve it on out-of-memory (OOM) errors during an internal warm-up step. Start with an upper bound you believe your GPU can handle; the `Trainer` handles the back-off.
- `optim`: Optimiser choice. Defaults to `adamw_torch`; alternatives such as `adamw_bnb_8bit` require matching dependencies (`bitsandbytes`).
- `gradient_accumulation_steps`, `gradient_checkpointing`, `group_by_length`: Memory/performance trade-offs.
- `fp16`, `bf16`, `tf32`: Mixed-precision toggles. Only enable one of `fp16` or `bf16`.
- `eval_strategy` / `save_strategy`: Evaluation/checkpoint cadence (supports `"no"`, `"steps"`, `"epoch"`). When set to `"steps"`, ensure the matching `eval_steps` / `save_steps` are > 0. If `load_best_model_at_end=True`, the two strategies must match and cannot be `"no"`.
- `per_device_train_batch_size`, `per_device_eval_batch_size`: Maximum values tried before auto tuning; also used directly when `auto_find_batch_size=False`.
- `push_to_hub`, `hub_private_repo`, `hub_strategy`, `hub_model_id`: Configure Hub uploads; a blank `hub_model_id` lets the utilities autogenerate a descriptive ID.
- `run_name`: Optional override for the W&B run name. Defaults to the auto-generated string from `utils.make_run_name`.
`logging` settings:

- `logging_steps`, `logging_strategy`, `report_to`: Logging cadence and destinations. `report_to` typically includes `"wandb"` when W&B tracking is required.
`train.py` walks through the following steps:

**1. Configuration load & validation**

- Loads `config.yaml` and runs `utils.validate_config_sanity`. The current sanity check ensures LoRA settings are well-formed and basic training parameters are positive.
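In outline, this step amounts to the following (a sketch assuming `config.yaml` sits beside `train.py` and that `validate_config_sanity` takes the parsed dict):

```python
from pathlib import Path

import yaml

from utils import validate_config_sanity

# Parse the YAML file into a plain dict (assumes config.yaml lives beside train.py).
config_path = Path(__file__).parent / "config.yaml"
with config_path.open("r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Fail fast on malformed LoRA settings or non-positive training parameters.
validate_config_sanity(config)
```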
**2. Experiment tracking**

- Initialises Weights & Biases via `wandb.init`, with a descriptive run name from `utils.make_run_name` (encodes base model, LoRA/full fine-tune, LR, epochs, weight decay, warmup, timestamp).
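A sketch of the tracking call, assuming the `wandb` section exposes `entity` and `project` as described above:

```python
import wandb

from utils import make_run_name

# The run name encodes the base model, adaptation mode, LR, epochs, and more.
run = wandb.init(
    entity=config["wandb"]["entity"],
    project=config["wandb"]["project"],
    name=make_run_name(config),
    config=config,  # log the full experiment config alongside the run
)
```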
**3. Tokenizer & model**

- Loads the tokenizer and sequence classification head using `AutoTokenizer` and `AutoModelForSequenceClassification`. The classification head is configured for `problem_type="multi_label_classification"` and inherits the label mappings.
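A condensed sketch of the load, assuming the checkpoint and label list live under `model.base_checkpoint` and `data.labels` respectively:

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

labels = config["data"]["labels"]
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}

tokenizer = AutoTokenizer.from_pretrained(config["model"]["base_checkpoint"])
model = AutoModelForSequenceClassification.from_pretrained(
    config["model"]["base_checkpoint"],
    num_labels=len(labels),
    problem_type="multi_label_classification",  # BCE-with-logits loss for multi-label
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,               # allow a freshly sized classification head
)
```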
**4. Dataset ingestion**

- Pulls the dataset referenced in `config['data']['dataset_name']` via `datasets.load_dataset`.
- Applies tokenisation and optional re-splitting (`utils.resplit_dataset`).
- Uses `DataCollatorWithPadding` for dynamic padding.
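In outline (the `text` column name and the `resplit_dataset` keyword names are assumptions about the dataset schema and helper signature):

```python
from datasets import load_dataset
from transformers import DataCollatorWithPadding

from utils import resplit_dataset

dataset = load_dataset(config["data"]["dataset_name"])

# Optionally rebuild the train/validation/test splits from the configured ratios.
if config["data"].get("resplit", False):
    dataset = resplit_dataset(
        dataset,
        ratios=config["data"]["custom_split_ratio"],          # e.g. "0.8:0.1:0.1"
        shuffle=config["data"].get("shuffle_before_resplit", True),
        seed=config["data"].get("random_seed", 42),
    )

def tokenize(batch):
    # Truncate to the model's maximum length; padding is deferred to the collator.
    return tokenizer(batch["text"], truncation=True)

tokenized = dataset.map(tokenize, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
```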
**5. LoRA adapters (optional)**

- When `lora.use_lora=True`, wraps the base model with PEFT's `get_peft_model`. Adapter hyperparameters mirror the YAML values, and console logs confirm the configuration.
**6. TrainingArguments assembly**

- Builds `TrainingArguments` from the `training` section, with file-system paths rooted under the project directory (`checkpoints/`, `logs/`). `save_safetensors=True` enforces safetensors checkpoints.
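Because the `training` keys mirror `TrainingArguments` parameters, the assembly can be as simple as unpacking the section with a few overrides. A sketch (the real call in `train.py` may organise paths, names, and the `logging` keys differently):

```python
from transformers import TrainingArguments

from utils import make_run_name

train_cfg = dict(config["training"])                    # copy so we can pop overrides
run_name = train_cfg.pop("run_name", None) or make_run_name(config)

training_args = TrainingArguments(
    output_dir="checkpoints",   # trainer checkpoints rooted under the project directory
    logging_dir="logs",         # TensorBoard/W&B-synced logs
    save_safetensors=True,      # enforce safetensors checkpoints
    run_name=run_name,
    **train_cfg,                # remaining keys map 1:1 onto TrainingArguments
    **config["logging"],        # logging_steps, logging_strategy, report_to, ...
)
```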
**7. Trainer setup**

- Instantiates `transformers.Trainer` with:
  - the `benchmark` metric function from `eval.py` (computes macro/micro F1, per-label F1s, and Hamming loss);
  - the tokenizer and data collator for consistent padding.
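A sketch of the `Trainer` setup (the `train`/`validation` split names are assumptions about the prepared `DatasetDict`):

```python
from transformers import Trainer

from eval import benchmark

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,            # keeps padding and checkpoint saving consistent
    data_collator=data_collator,
    compute_metrics=benchmark,      # macro/micro F1, per-label F1, Hamming loss
)
```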
**8. Train / evaluate / hub upload**

- Runs `.train()` and `.evaluate()` on the held-out test split.
- Pushes the model and tokenizer to the Hugging Face Hub using the descriptive ID from `utils.make_hub_id`.
- Finishes the W&B run with `wandb.finish()`.
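The final phase, in outline (whether the upload goes through `Trainer.push_to_hub` or direct `push_to_hub` calls on the model and tokenizer is an implementation detail; the ID comes from `utils.make_hub_id`):

```python
from utils import make_hub_id

trainer.train()
metrics = trainer.evaluate(tokenized["test"])   # held-out test split
print(metrics)

hub_id = make_hub_id(config)
model.push_to_hub(hub_id)                       # pushes full weights or LoRA adapters
tokenizer.push_to_hub(hub_id)

wandb.finish()
```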
The `benchmark` function receives raw logits and one-hot label vectors, applies a sigmoid followed by a 0.5 threshold, and returns:

- `f1_macro`, `f1_micro`
- `hamming_loss`
- `f1_<label>` for each Ekman emotion

Use these keys for `metric_for_best_model` in `config.yaml`.
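A sketch of what this computes, assuming the standard `EvalPrediction` layout and the sklearn metrics named above (the label tuple shown is only illustrative; the real function uses the `labels` list from `config.yaml`):

```python
import numpy as np
from sklearn.metrics import f1_score, hamming_loss

def benchmark(eval_pred, label_names=("anger", "disgust", "fear", "joy", "sadness", "surprise")):
    logits = eval_pred.predictions
    y_true = eval_pred.label_ids.astype(int)     # one-hot label vectors
    probs = 1.0 / (1.0 + np.exp(-logits))        # sigmoid over raw logits
    y_pred = (probs >= 0.5).astype(int)          # 0.5 decision threshold per label

    metrics = {
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "f1_micro": f1_score(y_true, y_pred, average="micro", zero_division=0),
        "hamming_loss": hamming_loss(y_true, y_pred),
    }
    # One F1 score per emotion column, reported as f1_<label>.
    for name, score in zip(label_names, f1_score(y_true, y_pred, average=None, zero_division=0)):
        metrics[f"f1_{name}"] = float(score)
    return metrics
```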
Helper utilities in `utils.py`:

- `make_run_name(config)`: Generates consistent W&B/HF run names embedding base model, adaptation mode, LR, epochs, weight decay, warmup ratio, timestamp, and (if applicable) LoRA `r`/`alpha`.
- `make_hub_id(config)`: Mirrors the above to create unique, self-describing Hub repo IDs.
- `validate_config_sanity(cfg)`: Lightweight assertions for early config errors (positive epochs, matching eval/save strategies, etc.). Adjust as your config schema evolves; the current implementation assumes the LoRA fields are always present.
- `resplit_dataset(...)`: Concatenates the existing splits and re-divides them according to custom ratios, with an optional shuffle (see the sketch below).
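To make the re-splitting behaviour concrete, a simplified sketch of what `resplit_dataset` does (the argument names here are illustrative, not the actual signature):

```python
from datasets import DatasetDict, concatenate_datasets

def resplit_dataset(dataset: DatasetDict, ratios: str = "0.8:0.1:0.1",
                    shuffle: bool = True, seed: int = 42) -> DatasetDict:
    """Merge all existing splits and re-divide them by colon-delimited ratios."""
    train_frac, val_frac, _ = (float(x) for x in ratios.split(":"))  # test takes the remainder

    merged = concatenate_datasets([dataset[split] for split in dataset])
    if shuffle:
        merged = merged.shuffle(seed=seed)

    n = len(merged)
    n_train, n_val = int(n * train_frac), int(n * val_frac)
    return DatasetDict({
        "train": merged.select(range(n_train)),
        "validation": merged.select(range(n_train, n_train + n_val)),
        "test": merged.select(range(n_train + n_val, n)),
    })
```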
Quick start:

**1. Create environment**

```
python -m venv .venv
.\.venv\Scripts\Activate.ps1
pip install -r requirements.txt
```
**2. Prepare credentials**

- Hugging Face Hub token (`HUGGINGFACE_HUB_TOKEN`) for pushing models.
- W&B API key (`WANDB_API_KEY`) if logging is enabled.
- Store them in `.env` (loaded via `python-dotenv`) or export them in the shell.
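If you go the `.env` route, the credentials can be picked up at the top of `train.py` roughly like this:

```python
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory into os.environ

assert os.getenv("HUGGINGFACE_HUB_TOKEN"), "Hub token missing; pushes to the Hub will fail"
assert os.getenv("WANDB_API_KEY"), "W&B key missing; set it or drop 'wandb' from report_to"
```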
**3. Edit `config.yaml`**

- Update dataset, training, logging, or LoRA parameters as needed.
- Ensure `labels` matches the dataset's column order; `problem_type` is multi-label.
**4. Launch training**

```
python train.py
```

**5. Outputs**

- Checkpoints: `checkpoints/`
- Logs (for TensorBoard/W&B syncing): `logs/`
- Final evaluation metrics: printed to stdout and available via `Trainer` logs.
- Model artefacts: pushed to the Hub if enabled; the run name encodes settings for reproducibility.
- Weights & Biases: Controlled by the
logging.report_tolist andwandbsection. Run names are auto-generated unless overridden (training.run_name). - Gradient accumulation & auto batch sizing: Large effective batch sizes can be achieved through
gradient_accumulation_steps;auto_find_batch_sizehalves the per-device batch on OOM and retries until success. - Hub pushes:
hub_strategycontrols cadence ("end","every_save", etc.). Make sure your token has write access; private repos are supported whenhub_private_repo=True.
- LoRA checks when disabled: The sanity validator still requires
target_modulesto be populated. Either keep the list filled or relax the check if you foresee pure full-finetune runs. - Adapter target names: Ensure entries in
lora.target_modulesmatch modules inside the chosen transformer (inspectmodel.named_modules()as needed).
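To list candidate module names for `target_modules`, a quick one-off snippet such as the following works (the checkpoint here is only an example):

```python
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilroberta-base", num_labels=6)

# LoRA matches on the final path component of each module name (e.g. "query", "value").
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name)
```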
- Author & maintainer: Shreyan Chaubey – Kaggle profile