Skip to content

Build model pipeline #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,5 @@ node_modules/
.model
checkpoint-*
runs

_model
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ AI writing assistant (prototype).

3. Install the extension in VSCode

code --install-extension extension/rht-text-generator/rht-text-generator-0.0.2.vsix
code --install-extension extension/rht-text-generator/rht-text-generator-0.0.3.vsix

## Retrain the model

Expand All @@ -20,7 +20,7 @@ AI writing assistant (prototype).

2. Train:

./train


## Rebuild the model server image

Expand Down
31 changes: 17 additions & 14 deletions build_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,28 @@
VALIDATION_PATH = "data/dataset_validation.txt"


def parse_sections(f):
sections = []
def parse_section(f):
clean_lines = []
for line in f:
line = line.rstrip()

line = line.rstrip().replace(":gls_prefix:", "")

if (line.startswith("//")
or line.startswith("ifndef")
or line.startswith(":experiment")):
continue

if re.match(r"^=+ \w+", line):
sections.append(line)
else:
try:
sections[-1] += "\n" + line
except IndexError:
pass
clean_lines.append(line)

# if re.match(r"^=+ \w+", line):
# sections.append(line)
# else:
# try:
# sections[-1] += "\n" + line
# except IndexError:
# pass

return sections
return "\n".join(clean_lines)


sections = []
Expand All @@ -51,14 +54,14 @@ def parse_sections(f):
filepath = os.path.join(dirpath, f)
print(filepath)
with open(filepath, "r") as f:
sections += parse_sections(f)
sections.append(parse_section(f))

random.Random(42).shuffle(sections)
num_sections = len(sections)
train_size = int(num_sections * 0.8)

with open(TRAIN_PATH, "w") as f:
f.write("\n".join(sections[:train_size]))
f.write("\n<|endoftext|>\n".join(sections[:train_size]))

with open(VALIDATION_PATH, "w") as f:
f.write("\n".join(sections[train_size:]))
f.write("\n<|endoftext|>\n".join(sections[train_size:]))
5 changes: 3 additions & 2 deletions extension/rht-text-generator/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"publisher": "red-hat-training",
"description": "VSCode RHT extension for writing suggestions",
"repository": "https://github.com/RedHatTraining/ai-text-generation",
"version": "0.0.2",
"version": "0.0.3",
"engines": {
"vscode": "^1.55.0"
},
Expand All @@ -24,7 +24,8 @@
"watch": "tsc -watch -p ./",
"pretest": "npm run compile && npm run lint",
"lint": "eslint src --ext ts",
"test": "node ./out/test/runTest.js"
"test": "node ./out/test/runTest.js",
"publish": "npx vsce package"
},
"devDependencies": {
"@types/vscode": "^1.55.0",
Expand Down
Binary file not shown.
8 changes: 5 additions & 3 deletions extension/rht-text-generator/src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ function getText(document: vscode.TextDocument, position: vscode.Position): stri

async function generateSuggestions(line: string, predictionLength: number, server: string) {
let suggestions: string[] = [];
const body = { text: line, length: predictionLength };

try {
const response = await Axios.get<[string]>(
`http://${server}/?text=${line}&length=${predictionLength}`
);
console.log(line);
console.log(predictionLength);
const response = await Axios.post<[string]>(`http://${server}/`, body);
suggestions = response.data;
} catch (error) {
if (error.response) {
Expand Down
1 change: 1 addition & 0 deletions model_serving/cpu/.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
MODEL_PATH=/path/to/model
12 changes: 12 additions & 0 deletions model_serving/cpu/Containerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# CPU-only model-serving runtime image for the RHT text generator.
FROM quay.io/modh/runtime-images:runtime-minimal-ubi9-python-3.9-20240614

# Install dependencies in their own layer so it is cached independently of
# code changes to runtime.py.
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r requirements.txt

COPY runtime.py .

# MODEL_PATH is read by runtime.py at startup; the serving platform is
# expected to mount the model files at this path.
ENV MODEL_PATH=/mnt/models/
EXPOSE 8000

CMD ["fastapi", "run", "runtime.py"]
7 changes: 7 additions & 0 deletions model_serving/cpu/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
--extra-index-url https://download.pytorch.org/whl/cpu
fastapi==0.111.0
pydantic==2.7.4
python-dotenv==1.0.1
numpy==1.24.4
transformers==4.39.3
torch==2.1.2
69 changes: 69 additions & 0 deletions model_serving/cpu/runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
from typing import Union
from fastapi import FastAPI
from dotenv import load_dotenv
from pydantic import BaseModel
from transformers import pipeline, set_seed, GPT2Tokenizer


# Load settings from a local .env file, if one exists (see .env.example).
load_dotenv()


# Directory containing the fine-tuned GPT-2 model. Required: the server
# refuses to start without it (the Containerfile defaults it to /mnt/models/).
MODEL_PATH = os.getenv("MODEL_PATH")
if not MODEL_PATH:
    print("MODEL_PATH env variable is required")
    exit(1)


# Fix the RNG seed so the sampled generations are reproducible across runs.
set_seed(42)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
generator = pipeline("text-generation", model=MODEL_PATH)
app = FastAPI()


@app.get("/")
async def predict_get(text: str, length: int = 3, no_top: bool = False):
    """Serve an inference request sent as GET query parameters.

    Args:
        text: prompt to generate continuations for.
        length: number of tokens to predict (default 3).
        no_top: when true, drop the top-k/top-p sampling arguments.
    """
    return serve_predict_request(text, length, no_top)


class InferencePostRequest(BaseModel):
    """JSON body of a POST / inference request.

    Fields:
        text: prompt to generate continuations for.
        length: number of tokens to predict (default 3).
    """

    text: str
    # Fix: this field was annotated Union[str, None] with an int default (3).
    # The VSCode extension sends a numeric "length" in the JSON body, and
    # pydantic v2 does not coerce int -> str, so numeric values failed
    # validation; serve_predict_request also does integer arithmetic with it.
    length: Union[int, None] = 3


@app.post("/")
async def predict_post(body: InferencePostRequest):
    """Serve an inference request sent as a JSON POST body."""
    return serve_predict_request(body.text, body.length)


def serve_predict_request(text: str, num_predicted_tokens: int, no_topp=False):
    """Generate 5 candidate continuations of *text* with the GPT-2 pipeline.

    Args:
        text: prompt to continue.
        num_predicted_tokens: how many tokens to generate beyond the prompt.
        no_topp: when truthy, drop the sampling kwargs entirely (greedy-ish
            pipeline defaults instead of top-k/top-p sampling).

    Returns:
        A list of generated-text strings (prompt excluded, since
        return_full_text=False).
    """
    # return_length=True adds a "length" entry holding the prompt's token
    # count. NOTE(review): assumed to be an int for a single non-batched
    # input; if the tokenizer returns a list here, the addition below would
    # misbehave — confirm against the pinned transformers version (4.39.3).
    tokens = tokenizer(text, return_length=True)
    num_tokens = tokens["length"]
    max_length = num_tokens + num_predicted_tokens

    # NOTE(review): top_k is tied to max_length, so the candidate-pool size
    # grows with the prompt length — presumably unintentional; verify whether
    # a fixed top_k (e.g. 50) was meant.
    kargs = {"do_sample": True, "top_k": max_length, "top_p": 0.92}

    if no_topp:
        kargs = {}

    # For info about the args: https://huggingface.co/blog/how-to-generate
    predictions = generator(
        text,
        max_length=max_length,
        num_return_sequences=5,
        output_scores=True,
        return_full_text=False,
        **kargs
    )

    result = [p["generated_text"] for p in predictions]

    print("Predictions:", result)

    return result
19 changes: 19 additions & 0 deletions model_serving/cpu/serving-runtime.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
apiVersion: serving.kserve.io/v1alpha1
kind: ServingRuntime
metadata:
  annotations:
    openshift.io/display-name: RHT text generation runtime
  labels:
    opendatahub.io/dashboard: "true"
  name: rht-text-generator-runtime
spec:
containers:
- image: quay.io/jramcast/rht-text-generator:model-serving-runtime-cpu-0.1.1
name: kserve-container
ports:
- containerPort: 8000
protocol: TCP
multiModel: false
supportedModelFormats:
- autoSelect: true
name: pytorch
2 changes: 2 additions & 0 deletions pipelines/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
projects.csv
*.tar.gz
6 changes: 6 additions & 0 deletions pipelines/build_model_server/.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
S3_MODEL_PATH=path/to/model/.model
LOCAL_MODEL_PATH=.model
AWS_S3_ENDPOINT=https://minio-api-rht-text-generation.apps.rhods-internal.61tk.p1.openshiftapps.com
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_S3_BUCKET=ai-text-generation
84 changes: 84 additions & 0 deletions pipelines/build_model_server/download_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!/usr/bin/env python3

import os
import boto3
from pathlib import Path
from dotenv import load_dotenv


load_dotenv()


def download_s3_folder(
    bucket_name: str,
    s3_dir: str,
    local_dir: Path,
    endpoint: str,
    key_id: str,
    access_key: str,
):
    """Download every object under *s3_dir* in *bucket_name* into *local_dir*.

    The S3 folder layout is mirrored locally: each key's path relative to
    *s3_dir* becomes its path under *local_dir*.

    Args:
        bucket_name: name of the S3 bucket.
        s3_dir: key prefix ("folder") to download.
        local_dir: destination directory (created if missing).
        endpoint: S3 endpoint URL (supports non-AWS endpoints such as MinIO).
        key_id: AWS access key id.
        access_key: AWS secret access key.
    """
    # Create an S3 client with the provided connection parameters.
    s3 = boto3.client(
        "s3",
        endpoint_url=endpoint,
        aws_access_key_id=key_id,
        aws_secret_access_key=access_key,
    )

    # exist_ok=True replaces the original exists()/makedirs() pair, which
    # was racy and redundant.
    os.makedirs(local_dir, exist_ok=True)

    # Paginate so prefixes with more than 1000 objects are fully listed.
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket_name, Prefix=s3_dir)

    for page in pages:
        # "Contents" is absent on empty pages; default to no objects.
        for obj in page.get("Contents", []):
            key = obj["Key"]
            # Remove the folder prefix to get the relative path.
            relative_path = os.path.relpath(key, s3_dir)
            local_file_path = os.path.join(local_dir, relative_path)

            # Ensure the local directory exists for the file.
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

            # Download the file.
            s3.download_file(bucket_name, key, local_file_path)
            print(f"Downloaded {key} to {local_file_path}")


if __name__ == "__main__":
    # Where the model lives in the bucket and where to place it locally.
    s3_dir_path = os.getenv("S3_MODEL_PATH")
    local_model_path = os.getenv("LOCAL_MODEL_PATH")

    s3_endpoint_url = os.environ.get("AWS_S3_ENDPOINT")
    s3_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    s3_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    s3_bucket_name = os.environ.get("AWS_S3_BUCKET")

    # Validate *every* required variable before doing any work. The original
    # only checked the AWS_* settings: a missing LOCAL_MODEL_PATH crashed
    # with TypeError inside Path(None) before the check ran, and a missing
    # S3_MODEL_PATH was never reported.
    required = [
        s3_dir_path,
        local_model_path,
        s3_endpoint_url,
        s3_access_key,
        s3_secret_key,
        s3_bucket_name,
    ]
    if not all(required):
        print(
            "Please set all the required environment variables: "
            "S3_MODEL_PATH, LOCAL_MODEL_PATH, "
            "AWS_S3_ENDPOINT, AWS_ACCESS_KEY_ID, "
            "AWS_SECRET_ACCESS_KEY, AWS_S3_BUCKET"
        )
    else:
        # Only build the Path once we know the value is present.
        local_dir_path = Path(local_model_path)

        print(
            f"Downloading {s3_dir_path} dir "
            f"from bucket {s3_bucket_name} "
            f"at {s3_endpoint_url} "
            f"to {local_dir_path}"
        )

        download_s3_folder(
            s3_bucket_name,
            s3_dir_path,
            local_dir_path,
            s3_endpoint_url,
            s3_access_key,
            s3_secret_key,
        )
2 changes: 2 additions & 0 deletions pipelines/build_model_server/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
boto3==1.34.126
python-dotenv==1.0.1
Loading