# Feed RL Training #21
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
name: Feed RL Training

# Three ways to start this workflow:
#   1. schedule            - daily training run at 2 AM UTC
#   2. repository_dispatch - fired by the debug endpoint
#   3. workflow_dispatch   - manual run from the GitHub UI (testing)
on:
  schedule:
    # Every day at 02:00 UTC
    - cron: "0 2 * * *"

  repository_dispatch:
    types:
      - trigger-training

  workflow_dispatch:
    # Every input is optional; the job fills in defaults for anything omitted.
    inputs:
      batch_id:
        description: "Training batch ID (optional, auto-generated if not provided)"
        type: string
        required: false
      window_id:
        description: "Window ID to train (optional, auto-detected if not provided)"
        type: string
        required: false
      force:
        description: "Force training even if not ready (for testing)"
        type: boolean
        required: false
        default: false
      base_model:
        description: "Base model to use (e.g., Qwen/Qwen2.5-7B-Instruct)"
        type: string
        required: false
# Workflow-wide environment; quoted so the version stays a string.
env:
  PYTHON_VERSION: "3.11"

# Prevent concurrent training runs: all runs share one concurrency group and
# an in-progress run is never cancelled, so a new run waits for the current
# one to finish.
concurrency:
  group: training-pipeline
  cancel-in-progress: false

# The workflow token only gets read access to repository contents.
permissions:
  contents: read
jobs:
  # Single job: checks whether enough scored trajectories exist, resolves the
  # batch parameters, runs the trainer and records the batch status in the
  # database.
  train:
    # Runs only when the ENABLE_RL_TRAINING variable is set to 'true'.
    if: ${{ vars.ENABLE_RL_TRAINING == 'true' }}
    name: Train RL Model
    runs-on: ubuntu-latest
    timeout-minutes: 360  # 6 hours max

    steps:
      - name: Checkout code
        uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5

      - name: Setup Python
        uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405
        with:
          python-version: ${{ env.PYTHON_VERSION }}
          cache: 'pip'
          cache-dependency-path: 'packages/training/scripts/rl/requirements.txt'

      - name: Install dependencies
        working-directory: packages/training/scripts/rl
        run: |
          pip install --upgrade pip
          pip install -r requirements.txt
          pip install -e .

      - name: Verify installation
        run: |
          python --version
          pip list | grep -E "(art|asyncpg)"

      # Counts unused, scored trajectories and publishes two step outputs:
      #   ready - 'true' when at least 100 trajectories are available
      #   count - the number found
      # On any error it writes ready=false / count=0 and fails the step.
      - name: Check training readiness
        id: readiness
        env:
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
        run: |
          python -c "
          import asyncio
          import asyncpg
          import os
          import sys

          async def check():
              try:
                  pool = await asyncpg.create_pool(
                      os.getenv('DATABASE_URL'),
                      min_size=1,
                      max_size=2,
                      timeout=30
                  )

                  # Count scored trajectories ready for training
                  count = await pool.fetchval('''
                      SELECT COUNT(*) FROM trajectories
                      WHERE \"isTrainingData\" = true
                      AND \"usedInTraining\" = false
                      AND \"aiJudgeReward\" IS NOT NULL
                      AND \"stepsJson\" IS NOT NULL
                      AND \"stepsJson\"::text != 'null'
                      AND \"stepsJson\"::text != '[]'
                  ''')

                  print(f'✅ Database connected')
                  print(f'📊 Trajectories ready for training: {count}')

                  # Need 100 bundles minimum
                  min_required = 100
                  ready = count >= min_required

                  if ready:
                      print(f'✅ READY: {count} >= {min_required}')
                  else:
                      print(f'⏳ NOT READY: {count} < {min_required} (need {min_required - count} more)')

                  with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
                      f.write(f'ready={str(ready).lower()}\\n')
                      f.write(f'count={count}\\n')

                  await pool.close()

              except Exception as e:
                  print(f'❌ Error checking readiness: {e}', file=sys.stderr)
                  with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
                      f.write('ready=false\\n')
                      f.write('count=0\\n')
                  sys.exit(1)

          asyncio.run(check())
          "

      # Resolves the run parameters and publishes them as step outputs:
      # batch_id, window_id, force, base_model, model_version, source.
      # A repository_dispatch payload value wins over a workflow_dispatch
      # input; when neither is given a default is generated.
      # All caller-supplied values reach the script through env vars, never
      # through ${{ }} interpolation inside the script text.
      - name: Get batch info
        id: batch
        env:
          # From workflow_dispatch inputs:
          BATCH_ID_INPUT: ${{ inputs.batch_id || '' }}
          WINDOW_ID_INPUT: ${{ inputs.window_id || '' }}
          FORCE_INPUT: ${{ inputs.force || 'false' }}
          BASE_MODEL_INPUT: ${{ inputs.base_model || '' }}
          # From repository_dispatch payload:
          BATCH_ID_PAYLOAD: ${{ github.event.client_payload.batch_id || '' }}
          WINDOW_ID_PAYLOAD: ${{ github.event.client_payload.window_id || '' }}
          FORCE_PAYLOAD: ${{ github.event.client_payload.force || 'false' }}
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
        run: |
          # Determine batch_id
          if [ -n "$BATCH_ID_PAYLOAD" ]; then
            echo "batch_id=$BATCH_ID_PAYLOAD" >> "$GITHUB_OUTPUT"
          elif [ -n "$BATCH_ID_INPUT" ]; then
            echo "batch_id=$BATCH_ID_INPUT" >> "$GITHUB_OUTPUT"
          else
            echo "batch_id=batch-$(date +%s)" >> "$GITHUB_OUTPUT"
          fi

          # Determine window_id (defaults to the current UTC hour)
          if [ -n "$WINDOW_ID_PAYLOAD" ]; then
            echo "window_id=$WINDOW_ID_PAYLOAD" >> "$GITHUB_OUTPUT"
          elif [ -n "$WINDOW_ID_INPUT" ]; then
            echo "window_id=$WINDOW_ID_INPUT" >> "$GITHUB_OUTPUT"
          else
            WINDOW=$(date -u +"%Y-%m-%dT%H:00")
            echo "window_id=$WINDOW" >> "$GITHUB_OUTPUT"
          fi

          # Determine force flag
          if [ "$FORCE_PAYLOAD" = "true" ] || [ "$FORCE_INPUT" = "true" ]; then
            echo "force=true" >> "$GITHUB_OUTPUT"
          else
            echo "force=false" >> "$GITHUB_OUTPUT"
          fi

          # Determine base model
          if [ -n "$BASE_MODEL_INPUT" ]; then
            echo "base_model=$BASE_MODEL_INPUT" >> "$GITHUB_OUTPUT"
          else
            echo "base_model=Qwen/Qwen2.5-7B-Instruct" >> "$GITHUB_OUTPUT"
          fi

          # Generate model version: bump the patch number of the newest
          # ready/deployed model, or start at v1.0.0. If the lookup raises,
          # the script falls back to a timestamp-derived patch number.
          MODEL_VERSION=$(python -c "
          import asyncio
          import asyncpg
          import os
          import sys

          async def get_version():
              try:
                  pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'), timeout=10)
                  latest = await pool.fetchval('''
                      SELECT version FROM trained_models
                      WHERE status IN ('ready', 'deployed')
                      ORDER BY \"createdAt\" DESC
                      LIMIT 1
                  ''')

                  if latest:
                      parts = latest.strip('v').split('.')
                      patch = int(parts[2]) + 1
                      version = f'v{parts[0]}.{parts[1]}.{patch}'
                  else:
                      version = 'v1.0.0'

                  print(version)
                  await pool.close()
              except Exception as e:
                  import time
                  version = f'v1.0.{int(time.time()) % 10000}'
                  print(version)

          asyncio.run(get_version())
          " 2>/dev/null)

          echo "model_version=$MODEL_VERSION" >> "$GITHUB_OUTPUT"
          # NOTE(review): 'github_cron' is written for every trigger type,
          # including repository_dispatch and workflow_dispatch - confirm
          # whether any consumer of this value needs the real trigger.
          echo "source=github_cron" >> "$GITHUB_OUTPUT"

      # Informational only: the steps below are gated by their own `if:`.
      - name: Skip if not ready (unless forced)
        if: steps.readiness.outputs.ready != 'true' && steps.batch.outputs.force != 'true'
        env:
          TRAJECTORY_COUNT: ${{ steps.readiness.outputs.count }}
        run: |
          echo "⏭️ Not ready for training and force=false"
          echo "Trajectories: $TRAJECTORY_COUNT"
          echo "Required: 100 (minimum bundles)"
          exit 0

      # Upserts the training_batches row for this batch with status 'training'.
      - name: Update batch status to training
        if: steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true'
        env:
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
          BATCH_ID: ${{ steps.batch.outputs.batch_id }}
        run: |
          python -c "
          import asyncio
          import asyncpg
          import os

          async def update():
              pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
              batch_id = os.getenv('BATCH_ID')

              async with pool.acquire() as conn:
                  async with conn.transaction():
                      await conn.execute('''
                          INSERT INTO training_batches (
                              \"batchId\", id, status, \"startedAt\", \"createdAt\"
                          ) VALUES (
                              \$1, \$1, 'training', NOW(), NOW()
                          )
                          ON CONFLICT (\"batchId\")
                          DO UPDATE SET status = 'training', \"startedAt\" = NOW()
                      ''', batch_id)

              print(f'✅ Batch {batch_id} status: training')
              await pool.close()

          asyncio.run(update())
          "

      # Publishes base_model, strategy and bundle_count as step outputs.
      # BASE_MODEL_OVERRIDE is the raw workflow_dispatch input: the batch
      # step's base_model output is never empty (it already falls back to the
      # default model), so reading it here made strategy always 'override'.
      - name: Select base model
        id: model_selection
        if: steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true'
        env:
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
          BASE_MODEL_OVERRIDE: ${{ inputs.base_model || '' }}
        run: |
          python -c "
          import asyncio
          import asyncpg
          import os
          import sys

          async def select_model():
              try:
                  pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
                  base_model_override = os.getenv('BASE_MODEL_OVERRIDE', '')

                  # Count training bundles
                  bundle_count = await pool.fetchval('''
                      SELECT COUNT(*) FROM trajectories
                      WHERE \"isTrainingData\" = true
                      AND \"usedInTraining\" = false
                      AND \"aiJudgeReward\" IS NOT NULL
                      AND \"stepsJson\" IS NOT NULL
                      AND \"stepsJson\"::text != 'null'
                      AND \"stepsJson\"::text != '[]'
                  ''')

                  print(f'📊 Bundle count: {bundle_count}')

                  # Use override if provided, otherwise use default
                  if base_model_override:
                      base_model = base_model_override
                      strategy = 'override'
                  else:
                      base_model = 'Qwen/Qwen2.5-7B-Instruct'
                      strategy = 'default'

                  print(f'📦 Selected model: {base_model}')
                  print(f'📋 Strategy: {strategy}')

                  with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
                      f.write(f'base_model={base_model}\\n')
                      f.write(f'strategy={strategy}\\n')
                      f.write(f'bundle_count={bundle_count}\\n')

                  await pool.close()

              except Exception as e:
                  print(f'❌ Model selection failed: {e}', file=sys.stderr)
                  with open(os.getenv('GITHUB_OUTPUT'), 'a') as f:
                      f.write('base_model=Qwen/Qwen2.5-7B-Instruct\\n')
                      f.write('strategy=fallback\\n')
                  sys.exit(1)

          asyncio.run(select_model())
          "

      # Runs the trainer script; its configuration is passed via env vars.
      # The ${{ }} values echoed below are produced by earlier steps of this
      # workflow (fixed strategy labels and SQL COUNT(*) results), not by the
      # caller, so they are left inline to keep the trainer's environment
      # exactly as before.
      - name: Run RL Training
        id: training
        if: steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true'
        env:
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
          BATCH_ID: ${{ steps.batch.outputs.batch_id }}
          WINDOW_ID: ${{ steps.batch.outputs.window_id }}
          MODEL_VERSION: ${{ steps.batch.outputs.model_version }}
          BASE_MODEL: ${{ steps.model_selection.outputs.base_model }}
          MODE: single
          MAX_EXAMPLES: "2000"
          MAX_STEPS_PER_TRAJECTORY: "20"
          MAX_SEQ_LENGTH: "8192"
        working-directory: packages/training/scripts/rl
        run: |
          echo "🚀 Starting training"
          echo "Batch ID: $BATCH_ID"
          echo "Window ID: $WINDOW_ID"
          echo "Model Version: $MODEL_VERSION"
          echo "Strategy: ${{ steps.model_selection.outputs.strategy }}"
          echo "Base Model: $BASE_MODEL"
          echo "Trajectories available: ${{ steps.readiness.outputs.count }}"
          echo "Bundle count: ${{ steps.model_selection.outputs.bundle_count }}"

          # Run trainer
          python src/training/feed_trainer.py

      # Marks the batch row 'completed' once every previous step succeeded.
      - name: Update batch status to completed
        if: success() && (steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true')
        env:
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
          BATCH_ID: ${{ steps.batch.outputs.batch_id }}
        run: |
          python -c "
          import asyncio
          import asyncpg
          import os

          async def update():
              pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
              batch_id = os.getenv('BATCH_ID')

              await pool.execute('''
                  UPDATE training_batches
                  SET status = 'completed', \"completedAt\" = NOW()
                  WHERE \"batchId\" = \$1
              ''', batch_id)

              print(f'✅ Batch {batch_id} completed')
              await pool.close()

          asyncio.run(update())
          "

      # Marks the batch row 'failed'. Also runs when the run is cancelled so
      # that a cancelled run does not leave the batch stuck in 'training'.
      - name: Update batch status to failed
        if: (failure() || cancelled()) && (steps.readiness.outputs.ready == 'true' || steps.batch.outputs.force == 'true')
        env:
          DATABASE_URL: ${{ secrets.DATABASE_URL }}
          BATCH_ID: ${{ steps.batch.outputs.batch_id }}
        run: |
          python -c "
          import asyncio
          import asyncpg
          import os

          async def update():
              pool = await asyncpg.create_pool(os.getenv('DATABASE_URL'))
              batch_id = os.getenv('BATCH_ID')

              await pool.execute('''
                  UPDATE training_batches
                  SET status = 'failed', error = 'GitHub Actions workflow failed'
                  WHERE \"batchId\" = \$1
              ''', batch_id)

              print(f'❌ Batch {batch_id} failed')
              await pool.close()

          asyncio.run(update())
          "

      # Always attempt to keep the trainer logs, whatever the job outcome.
      - name: Upload training logs
        if: always()
        uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a
        with:
          name: training-logs-${{ steps.batch.outputs.batch_id }}
          path: |
            packages/training/scripts/rl/logs/
            packages/training/scripts/rl/*.log
          retention-days: 7

      # Final summary. batch_id, window_id and base_model can originate from
      # dispatch inputs or the repository_dispatch payload, so every value is
      # handed to the shell through env vars instead of being interpolated
      # into the script text.
      - name: Report status
        if: always()
        env:
          JOB_STATUS: ${{ job.status }}
          BATCH_ID: ${{ steps.batch.outputs.batch_id }}
          WINDOW_ID: ${{ steps.batch.outputs.window_id }}
          READY: ${{ steps.readiness.outputs.ready }}
          TRAJECTORY_COUNT: ${{ steps.readiness.outputs.count }}
          BASE_MODEL: ${{ steps.model_selection.outputs.base_model }}
        run: |
          echo "Training Status: $JOB_STATUS"
          echo "Batch ID: $BATCH_ID"
          echo "Window ID: $WINDOW_ID"
          echo "Ready: $READY"
          echo "Trajectories: $TRAJECTORY_COUNT"
          echo "Model: $BASE_MODEL"