
rajpurkarlab/Clinical-RLVR


Open-Ended Clinical Text Generation for Acute Care: Applying Reinforcement Learning with Clinically Grounded Rewards

Official code for the paper accepted at CHIL 2026 (7th Annual Conference on Health, Inference, and Learning, June 28–30, 2026, Seattle, WA).

Minjia Wang, Luyang Luo, Sung Eun Kim, Fang Cao, David A Kim, Pranav Rajpurkar

Forked from: huggingface/open-r1

Table of Contents

  1. Overview
  2. Installation
  3. Training Models
  4. Inference and Evaluation
  5. Medical Tasks
  6. Acknowledgements
  7. Citation
  8. License

Overview

This repository demonstrates that reinforcement learning with verifiable rewards (RLVR) can extend beyond closed-ended tasks to open-ended clinical text generation in emergency and critical care settings. Using Group Relative Policy Optimization (GRPO) with clinically grounded rewards, we train compact 7–8B parameter models on three acute care tasks:

  • Multiple Disease Diagnosis: Generate comprehensive diagnosis lists from ICU records (MIMIC-III), handling multiple concurrent conditions with clinical language variation
  • Treatment Plan Generation: Produce detailed treatment plans from clinical transcriptions (MTSamples), synthesizing patient-specific therapeutic recommendations
  • Discharge Instructions Generation: Create comprehensive discharge instructions from Electronic Health Records (DischargeMe), providing clear guidance for patient care after hospital discharge

Our approach uses two reward designs matched to different output types:

  • Equivalence-based rewards for diagnosis generation that account for medical synonyms, abbreviations, and varying specificity
  • Rubric-based rewards for treatment planning and discharge instructions that score outputs on task-specific quality dimensions (e.g., accuracy, completeness, and clarity for treatment plans)
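Both reward types produce one scalar per sampled completion. GRPO then compares each completion against the other completions sampled for the same prompt: it subtracts the group mean and divides by the group standard deviation. The sketch below shows that standard group-relative normalization in plain Python; the function name and the `eps` value are illustrative assumptions, and the repository itself inherits its GRPO implementation from Open R1 / TRL rather than using this code.

```python
from statistics import mean, stdev


def group_relative_advantages(rewards, group_size, eps=1e-4):
    """Standard GRPO advantage: (r - group mean) / (group std + eps).

    `rewards` is a flat list in which each consecutive block of `group_size`
    entries holds the rewards of completions sampled for one prompt.
    `eps` is an assumed small constant that avoids division by zero when
    every completion in a group receives the same reward.
    """
    if group_size < 2 or len(rewards) % group_size != 0:
        raise ValueError("rewards must split into groups of at least 2")
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)
        sigma = stdev(group)  # sample standard deviation within the group
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# Two prompts with four sampled completions each (e.g. diagnosis-list F1 scores).
print(group_relative_advantages([0.9, 0.5, 0.5, 0.1, 0.7, 0.7, 0.7, 0.7], 4))
```

A group where every completion scores the same yields zero advantage, so it contributes no learning signal; this is why a reward that grades partial credit (rather than pass/fail) matters for open-ended outputs.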

Key Features

  • GRPO Training: Train models with Group Relative Policy Optimization using clinically grounded rewards
  • LLM Judge Evaluation: Automated quality assessment using task-specific metrics and LLM judges
  • MIMIC-III Integration: Pre-processing and dataset preparation for ICU diagnosis generation
  • MTSamples Support: Treatment plan generation from multi-specialty clinical transcriptions
  • Multi-GPU Support: Single and multi-node training via SLURM

Installation

Caution

Libraries rely on CUDA 12.4. If you hit segmentation faults, check which CUDA version your system is using with nvcc --version. To install the CUDA 12.4 toolkit under your home directory, run:

wget https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run
sh cuda_12.4.0_550.54.14_linux.run --silent --toolkit --toolkitpath=$HOME/cuda-12.4
export CUDA_HOME=$HOME/cuda-12.4
export PATH=$CUDA_HOME/bin:$PATH
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
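To confirm that the toolkit on your PATH is the expected one before launching a job, a small pre-flight check can parse the nvcc output. This is a hypothetical helper, not part of the repository; it only assumes that nvcc --version prints a line of the form "Cuda compilation tools, release 12.4, V12.4.xx".

```python
import re
import shutil
import subprocess


def parse_nvcc_release(text):
    """Return (major, minor) from `nvcc --version` output, or None if absent."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def check_cuda(expected=(12, 4)):
    """Run nvcc if it is on PATH and compare its release with `expected`."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return "nvcc not found on PATH; did you export CUDA_HOME/bin?"
    out = subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout
    found = parse_nvcc_release(out)
    if found != expected:
        return f"expected CUDA {expected[0]}.{expected[1]}, found {found}"
    return "ok"


print(check_cuda())
```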

To run the code in this project, first create a Python virtual environment using uv:

uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip

Next, install vLLM and FlashAttention:

uv pip install vllm==0.8.5.post1
uv pip install setuptools && uv pip install flash_attn==2.7.4.post1 --no-build-isolation

Then install the remaining dependencies:

GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]"

Log into your Weights and Biases account:

wandb login

Finally, check whether your system has Git LFS installed:

git-lfs --version

If it isn't installed, run:

sudo apt-get install git-lfs

Environment Variables

Create a .env file in the project root (excluded from git) or export the variables in your shell profile:

# Required for LLM-judge evaluation (Azure OpenAI)
export AZURE_OPENAI_ENDPOINT="https://<your-resource>.openai.azure.com/"
export AZURE_OPENAI_API_KEY="<your-key>"
export AZURE_OPENAI_API_VERSION="<your-version>"

# Optional: cache directories (default to ~/.cache/huggingface and ~/.cache/vllm)
export HF_HOME="/path/to/huggingface/cache"
export VLLM_CACHE_ROOT="/path/to/vllm/cache"

# Optional: override SLURM defaults per-job
export VENV_PATH="/path/to/openr1"        # default: <project_root>/openr1
export CONFIG_FILE="recipes/MIMIC-III/disease_diagnosis.yaml"
export ACCELERATOR_FILE="recipes/accelerate_configs/zero3.yaml"

Note

AZURE_OPENAI_* variables are only needed for evaluation with LLM judges. Training runs do not require them.
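Because training runs without these variables, a missing key usually only surfaces once evaluation starts. A hypothetical pre-flight check like the one below can fail fast instead. It reads only the process environment, so a .env file must already be loaded (for example with source .env) for its values to be visible.

```python
import os

# Variables this README lists as required for LLM-judge evaluation.
JUDGE_ENV_VARS = (
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_API_VERSION",
)


def missing_judge_env(environ=os.environ):
    """Return the judge-related variables that are unset or empty."""
    return [name for name in JUDGE_ENV_VARS if not environ.get(name)]


print(missing_judge_env() or "Azure OpenAI judge variables are set.")
```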

Training Models

GRPO Training

To train a model with GRPO on the medical datasets, submit the provided SLURM script:

sbatch slurm/train.slurm

The training script (slurm/train.slurm) supports both single-node and multi-node training:

  • Single node mode: vLLM runs on GPU 0, training on remaining GPUs
  • Multi-node mode: vLLM runs on the last node, training on other nodes
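The two placement modes can be summarized as a device plan. The sketch below is illustrative only: it assumes that in multi-node mode the entire last node is reserved for vLLM, while the actual script decides how many of that node's GPUs the server uses.

```python
def plan_gpu_layout(num_nodes, gpus_per_node):
    """Illustrative split of GPUs between the vLLM server and the trainer.

    Single node: GPU 0 serves vLLM rollouts, GPUs 1..N-1 train.
    Multi node: the last node serves vLLM (assumed: all of its GPUs),
    and every other node trains.
    Returns (vllm_devices, training_devices) as lists of (node, gpu) pairs.
    """
    if num_nodes < 1 or gpus_per_node < 1:
        raise ValueError("need at least one node and one GPU")
    if num_nodes == 1:
        if gpus_per_node < 2:
            raise ValueError("single-node mode needs at least 2 GPUs")
        vllm = [(0, 0)]
        train = [(0, g) for g in range(1, gpus_per_node)]
    else:
        last = num_nodes - 1
        vllm = [(last, g) for g in range(gpus_per_node)]
        train = [(n, g) for n in range(last) for g in range(gpus_per_node)]
    return vllm, train


vllm, train = plan_gpu_layout(1, 4)
print("vLLM:", vllm, "training:", train)
```

With sbatch --nodes=1 --gres=gpu:4, this layout leaves three GPUs for the trainer, which matters when reasoning about the effective batch size.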

Configuration

Training configurations are stored in the recipes/ directory. For example, to train on MIMIC-III disease diagnosis:

# recipes/MIMIC-III/disease_diagnosis.yaml
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
dataset_name: MIMIC-III
task_name: multiple_disease_diagnosis
use_vllm: true
learning_rate: 1.0e-06
num_train_epochs: 10
per_device_train_batch_size: 4
gradient_accumulation_steps: 16
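The batch settings in this recipe multiply out across training processes. The arithmetic below uses the recipe's values and assumes the single-node 4-GPU layout, where GPU 0 is reserved for vLLM and three GPUs train. TRL's GRPOTrainer additionally checks that its generation batch divides evenly by num_generations; the exact rule depends on the TRL version you install.

```python
def effective_batch_size(per_device, grad_accum, num_training_gpus):
    """Samples consumed per optimizer step across all training processes."""
    return per_device * grad_accum * num_training_gpus


# per_device_train_batch_size=4 and gradient_accumulation_steps=16 come from
# recipes/MIMIC-III/disease_diagnosis.yaml; 3 training GPUs is an assumption
# matching the single-node, 4-GPU layout.
print(effective_batch_size(4, 16, 3))  # 192
```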

Single Node Training (1 node, 4 GPUs)

sbatch --nodes=1 --gres=gpu:4 slurm/train.slurm

Multi-Node Training (2+ nodes)

sbatch --nodes=2 slurm/train.slurm

The script automatically:

  • Configures the distributed training setup
  • Starts the vLLM server for policy rollouts
  • Runs health checks on the vLLM server
  • Handles cleanup after training
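Training cannot request rollouts until the vLLM server is up, so the launch sequence waits on a health check. The polling loop below is a minimal sketch of that idea, not the script's code. The host, port, and path are assumptions (TRL's vllm-serve and vLLM's OpenAI-compatible server both expose a health route, but check slurm/train.slurm for the exact URL), and the probe is injectable so the loop can be tested without a server.

```python
import time
import urllib.request


def http_ok(url, timeout_s=2.0):
    """Return True if a GET on `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout_s) as resp:
            return resp.status == 200
    except OSError:  # connection refused, timeout, or HTTP error status
        return False


def wait_until_healthy(url, deadline_s=600.0, interval_s=5.0, probe=http_ok):
    """Poll `probe(url)` until it succeeds or `deadline_s` seconds pass."""
    start = time.monotonic()
    while time.monotonic() - start < deadline_s:
        if probe(url):
            return True
        time.sleep(interval_s)
    return False
```

For example, wait_until_healthy("http://localhost:8000/health/") would block for up to ten minutes, checking every five seconds.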

Training Script Details

The main training configurations are in recipes/.

Inference and Evaluation

To generate outputs and evaluate models with LLM judges, use:

sbatch slurm/inference_and_evaluate.slurm

This script performs two main operations:

1. Inference

Generates model outputs on test datasets using:

  • vLLM: For local model inference
  • API: For OpenAI models
  • Batch: For batch API processing
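For the Batch mode, OpenAI's Batch API takes a JSONL file with one request per line. The sketch below shows that line format for chat completions; the example-{i} id scheme, the single-message prompt, and max_tokens are assumptions, and the repository's own batch code may structure requests differently. Azure OpenAI batch files use "url": "/chat/completions" and a deployment name as model.

```python
import json


def batch_request_line(custom_id, model, prompt, url="/v1/chat/completions",
                       max_tokens=1024):
    """Serialize one chat-completion request as a line of a Batch API input file."""
    request = {
        "custom_id": custom_id,  # used to match results back to test examples
        "method": "POST",
        "url": url,
        "body": {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
        },
    }
    return json.dumps(request)


def write_batch_file(path, model, prompts):
    """Write one request per prompt; each custom_id encodes the prompt's index."""
    with open(path, "w", encoding="utf-8") as f:
        for i, prompt in enumerate(prompts):
            f.write(batch_request_line(f"example-{i}", model, prompt) + "\n")
```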

Configure inference in the script:

MODEL="Qwen/Qwen2.5-7B-Instruct"
DATASET_NAME="MTSamples"
TASK_NAME="treatment_plan_generation"
INFERENCE_OUTPUT_PATH=data/$DATASET_NAME/$TASK_NAME/$MODEL/test_output.jsonl
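Because a Hugging Face model id such as Qwen/Qwen2.5-7B-Instruct contains a slash, INFERENCE_OUTPUT_PATH nests one directory per id component. The hypothetical helpers below mirror that path and load the resulting JSONL file; the per-record fields are not documented here, so the reader makes no assumption about them.

```python
import json
from pathlib import Path


def inference_output_path(dataset_name, task_name, model):
    """Mirror INFERENCE_OUTPUT_PATH from the SLURM script."""
    return Path("data") / dataset_name / task_name / model / "test_output.jsonl"


def read_jsonl(path):
    """Load a JSON Lines file into a list of records, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


print(inference_output_path("MTSamples", "treatment_plan_generation",
                            "Qwen/Qwen2.5-7B-Instruct").as_posix())
```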

2. Evaluation with LLM Judges

After inference, the script automatically scores the outputs with an LLM judge (e.g., GPT-4.1). The judge model is selected with EVALUATION_MODEL:

EVALUATION_MODEL="gpt-4o-1120-batch"

The evaluation uses task-specific metrics defined in src/open_r1/tasks/:

For Disease Diagnosis:

  • Jaccard similarity and F1 scores computed from LLM-judged diagnosis matches
  • Equivalence-based matching accounting for synonyms, abbreviations, and specificity variation
  • Precision and recall to balance over-generation (hallucinated diagnoses) and under-generation (missed conditions)
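Once the judge has decided which predicted diagnoses are equivalent to which reference diagnoses, the set-overlap metrics follow directly. The sketch below assumes greedy one-to-one matching and replaces the LLM judge with a toy function that expands a few abbreviations and accepts a more specific prediction; neither is the repository's actual judge prompt or matching logic.

```python
SYNONYMS = {"mi": "myocardial infarction", "chf": "congestive heart failure"}


def toy_judge(pred, ref):
    """Hypothetical stand-in for the LLM judge: expands a few abbreviations
    and accepts a prediction that contains the reference term."""
    def norm(term):
        key = term.lower().strip()
        return SYNONYMS.get(key, key)

    p, r = norm(pred), norm(ref)
    return p == r or r in p


def set_overlap_metrics(predicted, reference, is_match):
    """Precision, recall, F1, and Jaccard from pairwise diagnosis matches.

    Each reference diagnosis can be matched at most once (greedy, in order).
    """
    unmatched_refs = list(reference)
    true_positives = 0
    for pred in predicted:
        for ref in unmatched_refs:
            if is_match(pred, ref):
                unmatched_refs.remove(ref)  # safe: we break right after
                true_positives += 1
                break
    fp = len(predicted) - true_positives  # over-generated diagnoses
    fn = len(unmatched_refs)              # missed diagnoses
    precision = true_positives / len(predicted) if predicted else 0.0
    recall = true_positives / len(reference) if reference else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    union = true_positives + fp + fn
    jaccard = true_positives / union if union else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "jaccard": jaccard}


predicted = ["MI", "bacterial pneumonia", "sepsis"]
reference = ["myocardial infarction", "pneumonia", "acute kidney injury"]
print(set_overlap_metrics(predicted, reference, toy_judge))
```

Here two of three predictions match, so precision, recall, and F1 are each 2/3, while Jaccard counts the unmatched prediction and the missed reference in its denominator and gives 2/4 = 0.5.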

For Treatment Plans:

  • Rubric-based scoring on three dimensions (1-5 scale):
    • Accuracy: Alignment with clinical guidelines and medical evidence
    • Completeness: Inclusion of dosing, duration, follow-up, and other salient details
    • Clarity: Logical organization and comprehensibility for the treating team
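To use rubric scores as a training reward, the three 1–5 judgments have to be collapsed into one number. The sketch below assumes equal weights and linear rescaling to [0, 1]; the README does not specify the weighting or normalization the repository actually uses.

```python
RUBRIC_DIMENSIONS = ("accuracy", "completeness", "clarity")


def rubric_reward(scores, dimensions=RUBRIC_DIMENSIONS, low=1, high=5):
    """Map per-dimension judge scores on a low..high scale to a reward in [0, 1].

    Equal weighting and linear rescaling are illustrative assumptions.
    """
    total = 0.0
    for dim in dimensions:
        s = scores[dim]
        if not low <= s <= high:
            raise ValueError(f"{dim} score {s} outside {low}-{high}")
        total += (s - low) / (high - low)
    return total / len(dimensions)


print(rubric_reward({"accuracy": 5, "completeness": 3, "clarity": 4}))  # 0.75
```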

Running Inference and Evaluation

# Submit the job
sbatch slurm/inference_and_evaluate.slurm

# Check job status
squeue -u $USER

# View outputs
cat logs/infer_and_eval-{job_id}.out
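When scripting submissions, the job id can be recovered from sbatch's standard "Submitted batch job <id>" message and turned into the log path above. These helpers are hypothetical; they only assume the log naming pattern shown in this README.

```python
import re
from pathlib import Path


def parse_job_id(sbatch_stdout):
    """Extract the job id from sbatch output such as 'Submitted batch job 123456'."""
    match = re.search(r"Submitted batch job (\d+)", sbatch_stdout)
    if match is None:
        raise ValueError(f"unexpected sbatch output: {sbatch_stdout!r}")
    return int(match.group(1))


def eval_log_path(job_id, log_dir="logs"):
    """Log file name pattern used for inference and evaluation jobs."""
    return Path(log_dir) / f"infer_and_eval-{job_id}.out"


print(eval_log_path(parse_job_id("Submitted batch job 4242\n")))
```

Alternatively, sbatch --parsable prints just the job id (plus the cluster name after a semicolon on federated setups), which avoids the regex.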

Medical Tasks

This repository focuses on three open-ended clinical text generation tasks for acute care:

1. Multiple Disease Diagnosis (MIMIC-III)

Generate comprehensive diagnosis lists from ICU patient records. The task requires:

  • Input: Demographics, presenting complaints, physical examination findings, prescriptions, and laboratory results
  • Output: Free-text list of diagnoses accounting for multiple concurrent conditions
  • Challenge: Clinical language variation including synonyms (e.g., "MI" vs. "myocardial infarction"), abbreviations, and varying specificity (e.g., "pneumonia" matching "bacterial pneumonia")
  • Reward: Equivalence-based rewards using an LLM judge (GPT-4.1) to evaluate diagnosis set overlap (Jaccard similarity, F1) while accounting for clinical semantics
  • Dataset: MIMIC-III critical care database with ICD-coded diagnoses

2. Treatment Plan Generation (MTSamples)

Generate detailed, actionable treatment plans from clinical transcriptions. The task requires:

  • Input: Structured patient information including medical history, current presentation, and comorbidities
  • Output: Free-form treatment recommendations tailored to patient circumstances
  • Challenge: Multi-dimensional quality assessment without single correct answers
  • Reward: Rubric-based rewards scoring on accuracy, completeness, and clarity
  • Dataset: MTSamples medical transcription repository spanning multiple specialties

3. Discharge Instructions Generation (DischargeMe)

Generate comprehensive discharge instructions from Electronic Health Records. The task requires:

  • Input: Patient Electronic Health Record including medical history, hospital course, and clinical findings
  • Output: Clear, actionable discharge instructions for post-hospital care
  • Challenge: Balancing completeness with readability while ensuring medical correctness
  • Reward: Rubric-based rewards scoring on completeness, correctness, and readability
  • Dataset: DischargeMe dataset with real-world EHRs and corresponding discharge instructions
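Rubric rewards depend on parsing the judge's reply reliably, since a malformed response during training should not crash the run. The sketch below assumes the judge is prompted to answer with a JSON object of integer scores for completeness, correctness, and readability, and that the scale is 1–5 as for treatment plans; the README does not state the discharge-instruction scale or the judge's output format.

```python
import json
import re

DISCHARGE_DIMENSIONS = ("completeness", "correctness", "readability")


def parse_judge_scores(text, dimensions=DISCHARGE_DIMENSIONS, low=1, high=5):
    """Pull a JSON object of integer rubric scores out of a judge response.

    Returns None when the response is unusable so the caller can assign a
    fallback reward instead of failing mid-training.
    """
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)  # first '{' to last '}'
    if match is None:
        return None
    try:
        data = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    scores = {}
    for dim in dimensions:
        value = data.get(dim)
        # Reject missing, non-integer (including bool), or out-of-range scores.
        if isinstance(value, bool) or not isinstance(value, int):
            return None
        if not low <= value <= high:
            return None
        scores[dim] = value
    return scores


reply = 'Scores: {"completeness": 4, "correctness": 5, "readability": 3}'
print(parse_judge_scores(reply))
```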

Key Differences from Original Open R1

This fork extends RLVR to open-ended clinical text generation with the following contributions:

  1. Clinically Grounded Rewards: Two novel reward designs for free-text clinical outputs:
    • Equivalence-based rewards for diagnosis lists that handle medical synonyms, abbreviations, and varying specificity
    • Rubric-based rewards for treatment plans and discharge instructions that score on multiple quality dimensions
  2. Open-Ended Medical Tasks: Multiple disease diagnosis, treatment planning, and discharge instructions generation—tasks requiring clinical reasoning without single correct answers
  3. MIMIC-III Integration: ICU patient records for diagnosis generation with ICD-coded ground truth
  4. MTSamples Support: Multi-specialty clinical transcriptions for treatment plan generation
  5. DischargeMe Integration: Electronic Health Records with discharge instructions for post-hospital care planning
  6. Clinical Validation: Qualitative review demonstrating more comprehensive assessments, accurate diagnostic capture, and fewer dangerous errors
  7. Compact Model Performance: 7–8B models matching or exceeding GPT-5, GPT-4o, and GPT-4.1 on acute care tasks

Acknowledgements

This project is built upon HuggingFace's Open R1, a fully open reproduction of DeepSeek-R1. We are grateful to:

  • The HuggingFace team for creating the Open R1 framework and GRPO implementation
  • The DeepSeek team for the original DeepSeekMath and R1 models
  • The vLLM team for high-performance inference tooling
  • The MIMIC team at Beth Israel Deaconess Medical Center for providing access to critical care data
  • The MTSamples repository for multi-specialty clinical transcriptions
  • The DischargeMe dataset creators for providing Electronic Health Records with discharge instructions
  • OpenAI for GPT-4 evaluation capabilities

Citation

If you use this work, please cite:

@InProceedings{pmlr-v297-wang26a,
  title     = {Open-Ended Clinical Text Generation for Acute Care: Applying Reinforcement Learning with Clinically Grounded Rewards},
  author    = {Wang, Minjia and Luo, Luyang and Kim, Sung Eun and Cao, Fang and Kim, David A and Rajpurkar, Pranav},
  booktitle = {Proceedings of the Conference on Health, Inference, and Learning},
  year      = {2026},
  volume    = {297},
  series    = {Proceedings of Machine Learning Research},
  month     = {28--30 Jun},
  publisher = {PMLR},
}

License

This project inherits the Apache-2.0 license from the original Open-R1 repository.
