# Feed RL Training — GitHub Actions workflow (text captured from run #21 in the GitHub UI).

name: Feed RL Training

# The workflow can be started three ways:
#   1. schedule            - every day at 02:00 UTC
#   2. repository_dispatch - event type "trigger-training" (sent by the debug endpoint)
#   3. workflow_dispatch   - manual run from the GitHub UI, mainly for testing
on: # yamllint disable-line rule:truthy
  schedule:
    - cron: '0 2 * * *' # 02:00 UTC daily
  repository_dispatch:
    types:
      - trigger-training
  workflow_dispatch:
    # Input order below is the order the GitHub UI displays them in.
    inputs:
      batch_id:
        type: string
        required: false
        description: 'Training batch ID (optional, auto-generated if not provided)'
      window_id:
        type: string
        required: false
        description: 'Window ID to train (optional, auto-detected if not provided)'
      force:
        type: boolean
        required: false
        default: false
        description: 'Force training even if not ready (for testing)'
      base_model:
        type: string
        required: false
        description: 'Base model to use (e.g., Qwen/Qwen2.5-7B-Instruct)'
env:
  # Interpreter version consumed by the actions/setup-python step.
  PYTHON_VERSION: '3.11'

# Serialize training runs: all runs share one concurrency group, and an
# in-progress run is never cancelled — a newly queued run waits for it
# (GitHub keeps at most one pending run per group).
concurrency:
  group: training-pipeline
  cancel-in-progress: false

# Least privilege for GITHUB_TOKEN: read-only repository contents.
permissions:
  contents: read
jobs:
  train:
    name: Train RL Model
    # Opt-in switch: the job only runs when the Actions configuration
    # variable ENABLE_RL_TRAINING is the string 'true'.
    if: ${{ vars.ENABLE_RL_TRAINING == 'true' }}
    runs-on: ubuntu-latest
    # Hard cap of 6 hours for the whole job.
    timeout-minutes: 360
    steps:
- name: Checkout code
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5
- name: Setup Python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405
with:
python-version: ${{ env.PYTHON_VERSION }}
cache: 'pip'
cache-dependency-path: 'packages/training/scripts/rl/requirements.txt'
- name: Install dependencies
working-directory: packages/training/scripts/rl
run: |
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
- name: Verify installation
run: |
python --version
pip list | grep -E "(art|asyncpg)"
- name: Check training readiness
id: readiness
env:
DATABASE_URL: ${{ secrets.DATABASE_URL }}
run: |
python -c "
import asyncio
import asyncpg
import os
import sys
async def check():
try:
pool = await asyncpg.create_pool(
os.getenv('DATABASE_URL'),
min_size=1,
max_size=2,
timeout=30
)
# Count scored trajectories ready for training
count = await pool.fetchval('''
SELECT COUNT(*) FROM trajectories
WHERE \"isTrainingData\" = true
AND \"usedInTraining\" = false
AND \"aiJudgeReward\" IS NOT NULL
AND \"stepsJson\" IS NOT NULL
AND \"stepsJson\"::text != 'null'
AND \"stepsJson\"::text != '[]'
''')
print(f'✅ Database connected')
print(f'📊 Trajectories ready for training: {count}')
# Need 100 bundles minimum
min_required = 100
ready = count >= min_required
if ready:
print(f'✅ READY: {count} >= {min_required}')
else:
print(f'⏳ NOT READY: {count} < {min_required} (need {min_required - count} more)')
with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
f.write(f'ready={str(ready).lower()}\\n')
f.write(f'count={count}\\n')
await pool.close()
except Exception as e:
print(f'❌ Error checking readiness: {e}', file=sys.stderr)
with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
f.write('ready=false\\n')
f.write('count=0\\n')
sys.exit(1)
asyncio.run(check())
"
- name: Get batch info
id: batch
env:
# From workflow_dispatch inputs:
BATCH_ID_INPUT: ${{ inputs.batch_id || '' }}
WINDOW_ID_INPUT: ${{ inputs.window_id || '' }}
FORCE_INPUT: ${{ inputs.force || 'false' }}
BASE_MODEL_INPUT: ${{ inputs.base_model || '' }}
# From repository_dispatch payload:
BATCH_ID_PAYLOAD: ${{ github.event.client_payload.batch_id || '' }}
WINDOW_ID_PAYLOAD: ${{ github.event.client_payload.window_id || '' }}
FORCE_PAYLOAD: ${{ github.event.client_payload.force || 'false' }}
DATABASE_URL: ${{ secrets.DATABASE_URL }}
run: |
# Determine batch_id
if [ -n "$BATCH_ID_PAYLOAD" ]; then
echo "batch_id=$BATCH_ID_PAYLOAD" >> $GITHUB_OUTPUT
elif [ -n "$BATCH_ID_INPUT" ]; then
echo "batch_id=$BATCH_ID_INPUT" >> $GITHUB_OUTPUT
else
echo "batch_id=batch-$(date +%s)" >> $GITHUB_OUTPUT
fi
# Determine window_id
if [ -n "$WINDOW_ID_PAYLOAD" ]; then
echo "window_id=$WINDOW_ID_PAYLOAD" >> $GITHUB_OUTPUT
elif [ -n "$WINDOW_ID_INPUT" ]; then
echo "window_id=$WINDOW_ID_INPUT" >> $GITHUB_OUTPUT
else
WINDOW=$(date -u +"%Y-%m-%dT%H:00")
echo "window_id=$WINDOW" >> $GITHUB_OUTPUT
fi
# Determine force flag
if [ "$FORCE_PAYLOAD" = "true" ] || [ "$FORCE_INPUT" = "true" ]; then
echo "force=true" >> $GITHUB_OUTPUT
else
echo "force=false" >> $GITHUB_OUTPUT
fi
# Determine base model
if [ -n "$BASE_MODEL_INPUT" ]; then
echo "base_model=$BASE_MODEL_INPUT" >> $GITHUB_OUTPUT
else
echo "base_model=Qwen/Qwen2.5-7B-Instruct" >> $GITHUB_OUTPUT
fi
# Generate model version
MODEL_VERSION=$(python -c "
import asyncio
import asyncpg
import os
import sys
async def get_version():
try:
pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'), timeout=10)
latest = await pool.fetchval('''
SELECT version FROM trained_models
WHERE status IN ('ready', 'deployed')
ORDER BY \"createdAt\" DESC
LIMIT 1
''')
if latest:
parts = latest.strip('v').split('.')
patch = int(parts[2]) + 1
version = f'v{parts[0]}.{parts[1]}.{patch}'
else:
version = 'v1.0.0'
print(version)
await pool.close()
except Exception as e:
import time
version = f'v1.0.{int(time.time()) % 10000}'
print(version)
asyncio.run(get_version())
" 2>/dev/null)
echo "model_version=$MODEL_VERSION" >> $GITHUB_OUTPUT
echo "source=github_cron" >> $GITHUB_OUTPUT
- name: Skip if not ready (unless forced)
if: steps.readiness.outputs.ready != 'true' && steps.batch.outputs.force != 'true'
run: |
echo "⏭️ Not ready for training and force=false"
echo "Trajectories: ${{ steps.readiness.outputs.count }}"
echo "Required: 100 (minimum bundles)"
exit 0
- name: Update batch status to training
if: steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true'
env:
DATABASE_URL: ${{ secrets.DATABASE_URL }}
BATCH_ID: ${{ steps.batch.outputs.batch_id }}
run: |
python -c "
import asyncio
import asyncpg
import os
async def update():
pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
batch_id = os.getenv('BATCH_ID')
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute('''
INSERT INTO training_batches (
\"batchId\", id, status, \"startedAt\", \"createdAt\"
) VALUES (
\$1, \$1, 'training', NOW(), NOW()
)
ON CONFLICT (\"batchId\")
DO UPDATE SET status = 'training', \"startedAt\" = NOW()
''', batch_id)
print(f'✅ Batch {batch_id} status: training')
await pool.close()
asyncio.run(update())
"
- name: Select base model
id: model_selection
if: steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true'
env:
DATABASE_URL: ${{ secrets.DATABASE_URL }}
BASE_MODEL_OVERRIDE: ${{ steps.batch.outputs.base_model }}
run: |
python -c "
import asyncio
import asyncpg
import os
import sys
async def select_model():
try:
pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
base_model_override = os.getenv('BASE_MODEL_OVERRIDE', '')
# Count training bundles
bundle_count = await pool.fetchval('''
SELECT COUNT(*) FROM trajectories
WHERE \"isTrainingData\" = true
AND \"usedInTraining\" = false
AND \"aiJudgeReward\" IS NOT NULL
AND \"stepsJson\" IS NOT NULL
AND \"stepsJson\"::text != 'null'
AND \"stepsJson\"::text != '[]'
''')
print(f'📊 Bundle count: {bundle_count}')
# Use override if provided, otherwise use default
if base_model_override:
base_model = base_model_override
strategy = 'override'
else:
base_model = 'Qwen/Qwen2.5-7B-Instruct'
strategy = 'default'
print(f'📦 Selected model: {base_model}')
print(f'📋 Strategy: {strategy}')
with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
f.write(f'base_model={base_model}\\n')
f.write(f'strategy={strategy}\\n')
f.write(f'bundle_count={bundle_count}\\n')
await pool.close()
except Exception as e:
print(f'❌ Model selection failed: {e}', file=sys.stderr)
with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
f.write('base_model=Qwen/Qwen2.5-7B-Instruct\\n')
f.write('strategy=fallback\\n')
sys.exit(1)
asyncio.run(select_model())
"
- name: Run RL Training
id: training
if: steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true'
env:
DATABASE_URL: ${{ secrets.DATABASE_URL }}
BATCH_ID: ${{ steps.batch.outputs.batch_id }}
WINDOW_ID: ${{ steps.batch.outputs.window_id }}
MODEL_VERSION: ${{ steps.batch.outputs.model_version }}
BASE_MODEL: ${{ steps.model_selection.outputs.base_model }}
MODE: single
MAX_EXAMPLES: "2000"
MAX_STEPS_PER_TRAJECTORY: "20"
MAX_SEQ_LENGTH: "8192"
working-directory: packages/training/scripts/rl
run: |
echo "🚀 Starting training"
echo "Batch ID: $BATCH_ID"
echo "Window ID: $WINDOW_ID"
echo "Model Version: $MODEL_VERSION"
echo "Strategy: ${{ steps.model_selection.outputs.strategy }}"
echo "Base Model: $BASE_MODEL"
echo "Trajectories available: ${{ steps.readiness.outputs.count }}"
echo "Bundle count: ${{ steps.model_selection.outputs.bundle_count }}"
# Run trainer
python src/training/feed_trainer.py
- name: Update batch status to completed
if: success() && (steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true')
env:
DATABASE_URL: ${{ secrets.DATABASE_URL }}
BATCH_ID: ${{ steps.batch.outputs.batch_id }}
run: |
python -c "
import asyncio
import asyncpg
import os
async def update():
pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
batch_id = os.getenv('BATCH_ID')
await pool.execute('''
UPDATE training_batches
SET status = 'completed', \"completedAt\" = NOW()
WHERE \"batchId\" = \$1
''', batch_id)
print(f'✅ Batch {batch_id} completed')
await pool.close()
asyncio.run(update())
"
- name: Update batch status to failed
if: failure() && (steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true')
env:
DATABASE_URL: ${{ secrets.DATABASE_URL }}
BATCH_ID: ${{ steps.batch.outputs.batch_id }}
run: |
python -c "
import asyncio
import asyncpg
import os
async def update():
pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
batch_id = os.getenv('BATCH_ID')
await pool.execute('''
UPDATE training_batches
SET status = 'failed', error = 'GitHub Actions workflow failed'
WHERE \"batchId\" = \$1
''', batch_id)
print(f'❌ Batch {batch_id} failed')
await pool.close()
asyncio.run(update())
"
- name: Upload training logs
if: always()
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a
with:
name: training-logs-${{ steps.batch.outputs.batch_id }}
path: |
packages/training/scripts/rl/logs/
packages/training/scripts/rl/*.log
retention-days: 7
- name: Report status
if: always()
run: |
echo "Training Status: ${{ job.status }}"
echo "Batch ID: ${{ steps.batch.outputs.batch_id }}"
echo "Window ID: ${{ steps.batch.outputs.window_id }}"
echo "Ready: ${{ steps.readiness.outputs.ready }}"
echo "Trajectories: ${{ steps.readiness.outputs.count }}"
echo "Model: ${{ steps.model_selection.outputs.base_model }}"