smasoudrezvani/simple_news_classification_llm


News Headline Classification — Fine-Tuning DistilBERT and Deploying to AWS

Python PyTorch HuggingFace AWS

An end-to-end MLOps project: custom fine-tuning of DistilBERT on the UCI News Aggregator dataset for four-class headline classification, followed by deployment of the model as a live serverless REST API on AWS. On the two sample headlines shown under Results, the deployed model assigns more than 99% softmax probability to its predicted class.


Architecture

Client (Postman / curl)
        │
        │  HTTP POST  {"query": {"headline": "..."}}
        ▼
┌──────────────────────────────┐
│      AWS API Gateway         │  REST API — POST /dev/invoke-model
│      Region: eu-north-1      │
└──────────────┬───────────────┘
               │  Lambda Proxy Integration
               ▼
┌──────────────────────────────┐
│    AWS Lambda                │  llm-endpoint-invoke-function
│    Python runtime            │  Parses body → wraps as {"inputs": headline}
└──────────────┬───────────────┘
               │  boto3  sagemaker-runtime.invoke_endpoint()
               ▼
┌─────────────────────────────────────────────────────┐
│  Amazon SageMaker Endpoint                          │
│  multiclass-text-classification-endpointv1          │  ml.m5.xlarge
│  HuggingFace Inference DLC (PyTorch 1.7)            │
└──────────────┬──────────────────────────────────────┘
               │  Loads model artifacts
               ▼
┌──────────────────────────────┐
│        Amazon S3             │  Training data + model.tar.gz
└──────────────────────────────┘
               ▲
               │  Writes output artifacts
┌──────────────────────────────┐
│  SageMaker Training Job      │  ml.p2.xlarge (NVIDIA K80 GPU)
│  HuggingFace Estimator       │  entry_point: script.py
└──────────────────────────────┘

Model

| Detail | Value |
|---|---|
| Base model | distilbert-base-uncased |
| Task | 4-class text classification |
| Classes | Business · Science & Technology · Entertainment · Health |
| Custom head | Linear(768→768) → ReLU → Dropout(0.3) → Linear(768→4) |
| Pooling | [CLS] token hidden state (index 0 of last hidden layer) |
| Optimizer | Adam, lr = 1e-5 |
| Loss | CrossEntropyLoss |
| Epochs | 2 |
| Batch size | 4 (train) / 2 (validation) |
| Max sequence length | 512 tokens |
| Train / test split | 80 / 20, random_state=200 |
| Training hardware | ml.p2.xlarge (GPU) |

The model extracts the [CLS] token embedding from DistilBERT's final layer and passes it through a custom classification head. The entire network is fine-tuned end-to-end — both the DistilBERT encoder weights and the head.
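The exact code of DistilBERTClass in script.py is not reproduced in this README. The following is a minimal PyTorch sketch of the architecture described above, under a few assumptions. The encoder is passed in to keep the example self-contained; in practice it would be something like `DistilBertModel.from_pretrained("distilbert-base-uncased")`. The helper `train_step` and the batch keys `ids`, `mask` and `targets` are hypothetical names chosen for illustration.

```python
import torch
from torch import nn


class DistilBERTClass(nn.Module):
    """Encoder + head: Linear(768→768) → ReLU → Dropout(0.3) → Linear(768→4)."""

    def __init__(self, encoder: nn.Module, hidden_size: int = 768, num_classes: int = 4):
        super().__init__()
        # Assumption: any module that returns the last hidden layer as output[0],
        # e.g. DistilBertModel.from_pretrained("distilbert-base-uncased").
        self.encoder = encoder
        self.pre_classifier = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Last hidden layer, shape (batch, seq_len, hidden_size).
        hidden_state = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
        pooled = hidden_state[:, 0]  # [CLS] token sits at index 0
        x = self.dropout(self.relu(self.pre_classifier(pooled)))
        return self.classifier(x)  # raw logits; CrossEntropyLoss applies log-softmax itself


def train_step(model: nn.Module, batch: dict, optimizer: torch.optim.Optimizer) -> float:
    """One optimisation step (hypothetical helper); gradients reach head and encoder."""
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    optimizer.zero_grad()
    logits = model(batch["ids"], batch["mask"])
    loss = loss_fn(logits, batch["targets"])
    loss.backward()
    optimizer.step()
    return loss.item()
```

With `optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)`, `model.parameters()` includes the encoder's weights as well as the head's, which is what makes the fine-tuning end-to-end rather than head-only.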


AWS Infrastructure

IAM Roles and Policies

Two separate IAM roles enforce least-privilege separation between training/hosting and inference invocation.

SageMaker Execution Role (attached to the SageMaker Studio domain user profile)

  • AmazonSageMakerFullAccess — allows launching training jobs, creating/updating endpoints, and pulling HuggingFace DLC images from Amazon ECR
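  • AmazonS3FullAccess: grants S3 access to all buckets in the account (this AWS-managed policy cannot be scoped to a single bucket; a custom policy restricted to the project bucket would be the least-privilege alternative). Used by the training job container to read input data and write model artifacts, and by the endpoint container to load model.tar.gz at startup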

Lambda Execution Role (llm-endpoint-invoke-function)

  • AWSLambdaBasicExecutionRole — grants logs:CreateLogGroup, logs:CreateLogStream, and logs:PutLogEvents to CloudWatch Logs
  • Custom inline policy: sagemaker:InvokeEndpoint scoped to the specific endpoint ARN — Lambda can invoke the endpoint but cannot create, update, or delete any SageMaker resource
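The repository does not include the inline policy document itself. Here is a minimal sketch of what a policy scoped this way could look like, assuming the standard SageMaker endpoint ARN format. The account ID `123456789012` is a placeholder, and the helper name is hypothetical.

```python
import json


def invoke_endpoint_policy(region: str, account_id: str, endpoint_name: str) -> str:
    """Build an inline IAM policy allowing only sagemaker:InvokeEndpoint on one endpoint."""
    endpoint_arn = f"arn:aws:sagemaker:{region}:{account_id}:endpoint/{endpoint_name}"
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": "sagemaker:InvokeEndpoint",  # no Create/Update/Delete actions
                "Resource": endpoint_arn,              # this endpoint only, not "*"
            }
        ],
    }
    return json.dumps(policy, indent=2)


# 123456789012 is a placeholder account ID, not the project's.
print(invoke_endpoint_policy("eu-north-1", "123456789012",
                             "multiclass-text-classification-endpointv1"))
```

Because the statement lists a single action and a single resource ARN, the Lambda role is denied everything else by default.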

API Gateway → Lambda

  • API Gateway is granted lambda:InvokeFunction via a resource-based policy on the Lambda function, added automatically when the API Gateway trigger is configured in the Lambda console.
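The console adds this permission automatically. If the trigger were wired up from a script instead, the equivalent Boto3 call is `add_permission` on the Lambda client. The following sketch assumes placeholder values for the account ID and REST API ID, and the statement ID and helper names are hypothetical.

```python
def api_gateway_permission_kwargs(region: str, account_id: str, api_id: str,
                                  function_name: str = "llm-endpoint-invoke-function") -> dict:
    """Arguments for Lambda AddPermission letting one REST API invoke the function."""
    # Restrict the grant to POST /invoke-model on any stage of this specific API.
    source_arn = f"arn:aws:execute-api:{region}:{account_id}:{api_id}/*/POST/invoke-model"
    return {
        "FunctionName": function_name,
        "StatementId": "apigateway-invoke-model",  # hypothetical; must be unique per function
        "Action": "lambda:InvokeFunction",
        "Principal": "apigateway.amazonaws.com",
        "SourceArn": source_arn,
    }


def grant_api_gateway_invoke(**kwargs):
    import boto3  # lazy import so the builder above works without AWS credentials
    return boto3.client("lambda").add_permission(**api_gateway_permission_kwargs(**kwargs))
```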

S3 Bucket Layout

s3://your-bucket/
├── data/
│   └── uci-news-aggregator.csv        # Tab-separated: ID, TITLE, URL, PUBLISHER,
│                                      # CATEGORY, STORY, HOSTNAME, TIMESTAMP
└── output/
    └── <sagemaker-training-job-name>/
        └── output/
            └── model.tar.gz           # pytorch_distilbert_news.bin + tokenizer vocab
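The following is a stdlib-only sketch of reading the dataset file in the layout above and producing an 80 / 20 split. It assumes the file is tab-separated with no header row, as described. The UCI category codes (b = business, t = science and technology, e = entertainment, m = health) come from the dataset. The integer order assigned to them is an assumption and only has to match the label list used at inference. Seeding with 200 mirrors `random_state=200`, but stdlib `random` will not reproduce the exact rows chosen by script.py's own split.

```python
import csv
import random

# Column order as documented for uci-news-aggregator.csv.
COLUMNS = ["ID", "TITLE", "URL", "PUBLISHER", "CATEGORY", "STORY", "HOSTNAME", "TIMESTAMP"]
# Assumed integer encoding of the UCI category codes.
CATEGORY_TO_ID = {"b": 0, "t": 1, "e": 2, "m": 3}


def load_headlines(lines):
    """Yield (title, label_id) pairs from tab-separated, header-less lines."""
    reader = csv.reader(lines, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        record = dict(zip(COLUMNS, row))
        yield record["TITLE"], CATEGORY_TO_ID[record["CATEGORY"]]


def train_test_split(rows, train_frac=0.8, seed=200):
    """Shuffle with a fixed seed, then cut into train / test portions."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    cut = int(len(rows) * train_frac)
    return rows[:cut], rows[cut:]
```

`load_headlines` accepts any iterable of lines, so an open file object (`open(path, newline="", encoding="utf-8")`) works the same way as an in-memory list.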

SageMaker Training Job

Launched from train.ipynb using the SageMaker HuggingFace estimator. SageMaker provisions a GPU instance, pulls the managed HuggingFace DLC, injects script.py as the entry point, and handles all container lifecycle management.

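import sagemaker
from sagemaker.huggingface import HuggingFace

# Execution role of the current SageMaker Studio / notebook environment
role = sagemaker.get_execution_role()

estimator = HuggingFace(
    entry_point='script.py',
    instance_type='ml.p2.xlarge',   # NVIDIA K80 GPU
    instance_count=1,               # single-instance training job
    transformers_version='4.6',
    pytorch_version='1.8',
    py_version='py36',
    role=role,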
    hyperparameters={
        'epochs': 2,
        'train_batch_size': 4,
        'valid_batch_size': 2,
        'learning_rate': 1e-5,
    }
)
estimator.fit({'training': 's3://your-bucket/data/'})

SageMaker injects the output path via the SM_MODEL_DIR environment variable. script.py saves pytorch_distilbert_news.bin and the tokenizer vocabulary there; SageMaker then packages the directory as model.tar.gz and uploads it to S3 automatically.
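SageMaker passes the estimator's `hyperparameters` to the entry point as command-line flags (for example `--epochs 2`) and exposes paths through `SM_*` environment variables: `SM_MODEL_DIR` for outputs, and one `SM_CHANNEL_<NAME>` per `fit()` channel, so `{'training': ...}` becomes `SM_CHANNEL_TRAINING`. A sketch of how script.py's argument handling might look follows. The flag names mirror the hyperparameter keys above, but the parser and the `output_paths` helper are assumptions, not the repository's code.

```python
import argparse
import os


def parse_args(argv=None):
    """Hyperparameters arrive as CLI flags; input/output paths arrive as SM_* env vars."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--train_batch_size", type=int, default=4)
    parser.add_argument("--valid_batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    # Defaults fall back to the standard container paths when run outside SageMaker.
    parser.add_argument("--model_dir",
                        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--train_dir",
                        default=os.environ.get("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training"))
    return parser.parse_args(argv)


def output_paths(model_dir):
    """Hypothetical helper: where artifacts go before SageMaker tars model_dir."""
    return {
        "weights": os.path.join(model_dir, "pytorch_distilbert_news.bin"),
        "tokenizer_dir": model_dir,
    }
```

After training, something like `torch.save(model.state_dict(), output_paths(args.model_dir)["weights"])` followed by `tokenizer.save_vocabulary(args.model_dir)` would place both artifacts in the directory SageMaker packages into model.tar.gz.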


SageMaker Endpoint

Deployed from deploy.ipynb. The HuggingFace Inference DLC handles HTTP serving; a custom inference.py (bundled inside model.tar.gz) manages model loading and response formatting.

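import sagemaker
from sagemaker.huggingface import HuggingFaceModel

# Same SageMaker execution role used for training
role = sagemaker.get_execution_role()

model = HuggingFaceModel(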
    model_data='s3://your-bucket/output/<job-name>/output/model.tar.gz',
    transformers_version='4.6',
    pytorch_version='1.7',
    py_version='py36',
    entry_point='inference.py',
    role=role,
)

predictor = model.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge',   # CPU — cost-effective for synchronous inference
    endpoint_name='multiclass-text-classification-endpointv1',
)

SageMaker endpoint in-service status
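The Hugging Face inference toolkit lets a custom inference.py override handler functions such as `model_fn`, `predict_fn` and `output_fn`. The repository's formatting code is not shown here, so the following is only a stdlib sketch of turning a row of logits into the response shape documented under API Gateway. The label order is an assumption. Indices 0 and 1 agree with the two observed responses (Business, Science), while the positions of Entertainment and Health are guesses.

```python
import json
import math

# Assumed index → label order; only indices 0 and 1 are confirmed by observed responses.
LABELS = ["Business", "Science", "Entertainment", "Health"]


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def format_prediction(batch_logits):
    """Turn logits for a batch of one headline into the documented JSON response body."""
    probs = [softmax(row) for row in batch_logits]
    best = max(range(len(probs[0])), key=lambda i: probs[0][i])
    return json.dumps([{"predicted_label": LABELS[best], "probabilities": probs}])
```

Keeping the nested `probabilities` list (one row per input) matches the `[[...]]` shape in the documented responses, even though only one headline is sent per request.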


Lambda Function (llm-endpoint-invoke-function)

The function acts as a translation layer between API Gateway's proxy event format and the SageMaker Runtime invoke_endpoint call, which sends the request and returns the response as raw bytes:

  1. Receives the raw API Gateway proxy event
  2. Parses event['body'] (a JSON string) and extracts body['query']['headline']
  3. Re-packages the headline as {"inputs": headline} — the schema expected by the HuggingFace Inference DLC
  4. Calls boto3.client('sagemaker-runtime').invoke_endpoint() with ContentType: application/json
  5. Reads and decodes the binary response['Body'], then returns it as an HTTP 200 JSON response
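The five steps above can be sketched as a handler like the following. This is not necessarily identical to lambda_function.py. The `ENDPOINT_NAME` environment variable and the optional `runtime` parameter are assumptions. The client is created lazily and can be injected so the parsing logic can be exercised without AWS credentials.

```python
import json
import os

# Hypothetical env var; falls back to the endpoint name used in this project.
ENDPOINT_NAME = os.environ.get("ENDPOINT_NAME", "multiclass-text-classification-endpointv1")
_runtime = None


def _get_runtime():
    """Create the sagemaker-runtime client once per Lambda container."""
    global _runtime
    if _runtime is None:
        import boto3
        _runtime = boto3.client("sagemaker-runtime")
    return _runtime


def lambda_handler(event, context, runtime=None):
    runtime = runtime or _get_runtime()
    body = json.loads(event["body"])              # 1-2: proxy event body is a JSON string
    headline = body["query"]["headline"]
    payload = json.dumps({"inputs": headline})    # 3: schema expected by the HF Inference DLC
    response = runtime.invoke_endpoint(           # 4: synchronous endpoint call
        EndpointName=ENDPOINT_NAME,
        ContentType="application/json",
        Body=payload,
    )
    result = json.loads(response["Body"].read().decode("utf-8"))  # 5: decode the byte stream
    return {
        "statusCode": 200,
        "headers": {"Content-Type": "application/json"},
        "body": json.dumps(result),
    }
```

With Lambda Proxy Integration, API Gateway expects exactly this `statusCode` / `headers` / `body` shape back, with `body` as a string.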

The function averages about 1 second per invocation, most of which is presumably spent waiting on SageMaker inference, with 0 throttles across all 7 recorded invocations.

CloudWatch metrics (live test run):

| Metric | Value |
|---|---|
| Total invocations | 7 |
| Average duration | 1,035 ms |
| Min duration | 3.04 ms |
| Max duration | 2,819 ms |
| Throttles | 0 |
| Max concurrent executions | 1 |
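These figures come from the CloudWatch console. To pull the same duration statistics programmatically, CloudWatch's `GetMetricStatistics` API can query the `AWS/Lambda` namespace. The following sketch assumes a 24-hour window collapsed into a single datapoint, and the helper names are hypothetical.

```python
from datetime import datetime, timedelta, timezone


def lambda_duration_query(function_name="llm-endpoint-invoke-function", hours=24, now=None):
    """Keyword arguments for CloudWatch GetMetricStatistics on a Lambda's Duration metric."""
    end = now or datetime.now(timezone.utc)
    return {
        "Namespace": "AWS/Lambda",
        "MetricName": "Duration",
        "Dimensions": [{"Name": "FunctionName", "Value": function_name}],
        "StartTime": end - timedelta(hours=hours),
        "EndTime": end,
        "Period": hours * 3600,  # one datapoint spanning the whole window
        "Statistics": ["SampleCount", "Average", "Minimum", "Maximum"],
    }


def fetch_duration_stats(**kwargs):
    import boto3  # lazy import: the query builder above has no AWS dependency
    cloudwatch = boto3.client("cloudwatch")
    return cloudwatch.get_metric_statistics(**lambda_duration_query(**kwargs))["Datapoints"]
```

Swapping `MetricName` to `"Invocations"` or `"Throttles"` with `Statistics=["Sum"]` would give the invocation and throttle counts in the table.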

Lambda CloudWatch monitoring: invocations, duration, error rate, throttles
Lambda function overview: API Gateway trigger wired to llm-endpoint-invoke-function


API Gateway

A REST API with a POST /dev/invoke-model resource connected to llm-endpoint-invoke-function via Lambda Proxy Integration. API Gateway packages the whole HTTP request (method, path, headers, body) into a proxy event and forwards it without any mapping template, so the request body reaches Lambda unchanged as a JSON string in event['body'].

Request format:

{
  "query": {
    "headline": "your news headline here"
  }
}

Response format:

[
  {
    "predicted_label": "Business",
    "probabilities": [[0.9959, 0.0023, 0.0015, 0.0003]]
  }
]
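As an alternative to Postman, a minimal stdlib client could call the API. This sketch assumes the usual REST API invoke URL format. `YOUR_API_ID` is a placeholder for the real API ID, which is not given in this README.

```python
import json
import urllib.request

# Placeholder URL: substitute your own REST API ID.
API_URL = "https://YOUR_API_ID.execute-api.eu-north-1.amazonaws.com/dev/invoke-model"


def build_request(headline, url=API_URL):
    """Wrap a headline in the {"query": {"headline": ...}} envelope as a POST request."""
    body = json.dumps({"query": {"headline": headline}}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, method="POST", headers={"Content-Type": "application/json"}
    )


def classify(headline, url=API_URL, timeout=30):
    """POST the headline and return the parsed list of prediction objects."""
    with urllib.request.urlopen(build_request(headline, url), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

`classify("The stock market hit an all time low", url=...)` would return a list shaped like the response format above.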

Results

The model returns the predicted class label and the full softmax probability distribution over all four categories.

Live API call via Postman:

Input:   "Elon Musk tweeted about a controversial matters again"
Output:  {"predicted_label": "Science", "probabilities": [[0.0032, 0.9965, 0.0002, 0.0001]]}
         → Science & Technology — 99.65% confidence
         Status: 200 OK  |  Latency: 3.83 s end-to-end

Deployment notebook smoke test:

Input:   "The stock market hit an all time low"
Output:  {"predicted_label": "Business", "probabilities": [[0.9959, 0.0023, 0.0015, 0.0003]]}
         → Business — 99.59% confidence

Live POST request and response in Postman — 200 OK, 3.83 s


Repository

| File | Description |
|---|---|
| script.py | SageMaker training entry point: NewsDataset (tokenization pipeline), DistilBERTClass (custom nn.Module), training loop, validation loop, model saving to SM_MODEL_DIR |
| train.ipynb | SageMaker HuggingFace estimator setup and training job launch |
| deploy.ipynb | SageMaker endpoint deployment and smoke-test inference |
| lambda_function.py | AWS Lambda handler: translates API Gateway proxy events into SageMaker invoke_endpoint calls via Boto3 |
| load-test.py | Batch test data generator: serialises sample headlines to individual JSON files and packages them into a .tar.gz for SageMaker Batch Transform jobs |
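The packaging step described for load-test.py could look roughly like the following stdlib sketch. The `{"inputs": ...}` payload shape is an assumption based on the endpoint's input schema, and the `headline_NNNN.json` naming scheme and function name are hypothetical.

```python
import json
import os
import tarfile
import tempfile


def package_headlines(headlines, archive_path):
    """Write each headline to its own JSON file and bundle them into a .tar.gz."""
    with tempfile.TemporaryDirectory() as tmp:
        with tarfile.open(archive_path, "w:gz") as tar:
            for i, headline in enumerate(headlines):
                name = f"headline_{i:04d}.json"  # hypothetical naming scheme
                path = os.path.join(tmp, name)
                with open(path, "w", encoding="utf-8") as f:
                    json.dump({"inputs": headline}, f)  # assumed endpoint payload shape
                tar.add(path, arcname=name)              # store without the temp-dir prefix
    return archive_path
```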

Skills Demonstrated

Machine Learning / NLP

  • Transfer learning with distilbert-base-uncased; custom torch.nn.Module classification head
  • PyTorch Dataset / DataLoader pipeline with HuggingFace tokenizer (padding, truncation, attention masks)
  • Full fine-tuning loop: forward pass, CrossEntropyLoss, Adam optimizer, per-step and per-epoch metrics

MLOps / AWS

  • SageMaker managed training jobs with the HuggingFace DLC on a GPU instance (ml.p2.xlarge)
  • SageMaker endpoint deployment and lifecycle management (ml.m5.xlarge)
  • IAM least-privilege design: separate execution roles for training vs. Lambda inference invocation
  • S3 artifact lifecycle: training data → SM_MODEL_DIR → model.tar.gz → endpoint

Serverless Architecture

  • AWS Lambda with Boto3 for cross-service integration
  • API Gateway REST API with Lambda Proxy Integration
  • CloudWatch observability: invocation counts, duration statistics (min / average / max), throttle tracking

Python / Tools

  • PyTorch, Hugging Face Transformers, Boto3, Pandas
  • SageMaker Python SDK, Jupyter, Postman

License

This project is licensed under the MIT License.


About This README

This documentation was written by Claude Code, Anthropic's AI coding assistant.

Ironically, a large language model wrote the documentation for a project about fine-tuning and deploying a large language model. Somewhere in that loop there's a thesis topic waiting to happen.

In the process, Claude Code also: spotted a main()p syntax error you probably would have caught eventually, renamed four image files with spaces in their names (a war crime against URLs), corrected "cloakwatch" to "CloudWatch", and rewrote the entire README — all without once asking for its own SageMaker endpoint.
