Skip to content

Commit bb2925c

Browse files
authored
Merge pull request #11 from Sahandfer/feature/consistentMI_reranker
Switched the reranker for ConsistentMI to LiteLLM from HF
2 parents e71ea9b + 42f0a7c commit bb2925c

5 files changed

Lines changed: 151 additions & 78 deletions

File tree

docs/docs/components/clients/consistentmi.md

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ ConsistentMI simulates clients in motivational interviewing (MI) sessions with c
2424

2525
1. **Load Profile**: Reads the character JSON (personas, beliefs, acceptable plans, motivation topics) and initializes `stage` and `receptivity`.
2626
2. **Initialize Prompts**: Builds a system prompt that anchors the client’s behavior/goal and injects personas + beliefs for consistency.
27-
3. **Track Topic Engagement**: Matches the therapist’s latest utterance to a motivation topic, then uses the topic graph distance to update `engagement` and count repeated off-topic turns.
27+
3. **Track Topic Engagement**: Matches the therapist’s latest utterance to a motivation topic using a reranker-backed topic matcher, then uses the topic graph distance to update `engagement` and count repeated off-topic turns. If reranking is unavailable or returns no valid scores, ConsistentMI falls back to lexical matching.
2828
4. **Verify Motivation (Optional)**: If the therapist addresses the client’s core motivation, the client enters a short `Motivation` state for an acknowledging response.
2929
5. **Sample a Stage-Consistent Action**: An LLM predicts an action distribution conditioned on recent context and the current stage.
3030
6. **Select Grounding Detail**: For actions like `Inform/Downplay/Blame/Hesitate/Plan`, the client selects a relevant persona/belief/plan (only when the therapist asks a question) to ground the next reply.
@@ -58,16 +58,43 @@ response = client.generate_response(
5858
print(response)
5959
```
6060

61+
> ⚠️ **Hint:**
62+
>
63+
> - ConsistentMI use a local reranker served through vLLM's OpenAI-compatible `/rerank` endpoint.
64+
> - Set `LOCAL_BASE_URL` and `LOCAL_API_KEY` in `.env`; PatientHub reuses them for the reranker.
65+
> - Use `reranker_model_type=LOCAL`.
66+
> - Set `reranker_model_name` to the LiteLLM vLLM route, e.g. `hosted_vllm/BAAI/bge-reranker-v2-m3`.
67+
> - If the reranker server runs on the same machine, prefer `127.0.0.1` over `0.0.0.0` in `LOCAL_BASE_URL`.
68+
6169
## Configuration
6270

63-
| Option | Description | Default |
64-
| ------------------ | -------------------------------- | ---------------------------------------------- |
65-
| `prompt_path` | Path to prompt file | `data/prompts/client/consistentMI.yaml` |
66-
| `data_path` | Path to character file | `data/characters/ConsistentMI.json` |
67-
| `data_idx` | Character index | `0` |
68-
| `topics_path` | Topics from Wiki | `data/resources/ConsistentMI/topics.json` |
69-
| `topic_graph_path` | Correlation between topics | `data/resources/ConsistentMI/topic_graph.json` |
70-
| `model_retriever` | retrieve the most relevant topic | None |
71+
| Option | Description | Default |
72+
| --------------------- | ------------------------------------------------------- | ---------------------------------------------- |
73+
| `prompt_path` | Path to prompt file | `data/prompts/client/consistentMI.yaml` |
74+
| `data_path` | Path to character file | `data/characters/ConsistentMI.json` |
75+
| `data_idx` | Character index | `0` |
76+
| `topics_path` | Topics from Wiki | `data/resources/ConsistentMI/topics.json` |
77+
| `topic_graph_path` | Correlation between topics | `data/resources/ConsistentMI/topic_graph.json` |
78+
| `reranker_model_type` | Provider key for topic reranking | `LOCAL` |
79+
| `reranker_model_name` | LiteLLM model route for the reranker | `hosted_vllm/BAAI/bge-reranker-v2-m3` |
80+
81+
### Local Reranker Example
82+
83+
```yaml
84+
client:
85+
agent_name: consistentMI
86+
model_type: OPENAI
87+
model_name: gpt-4o
88+
reranker_model_type: LOCAL
89+
reranker_model_name: hosted_vllm/BAAI/bge-reranker-v2-m3
90+
```
91+
92+
With a local vLLM reranker server, your `.env` should contain:
93+
94+
```bash
95+
LOCAL_BASE_URL=http://127.0.0.1:7891/v1
96+
LOCAL_API_KEY=EMPTY
97+
```
7198

7299
## Character Data Format
73100

docs/docs/getting-started/configuration.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ For example,
2222
OPENAI_API_KEY=your_openai_key
2323
OPENAI_BASE_URL=https://api.openai.com
2424

25-
# For VLLM (n this case, model_type = VLLM)
26-
VLLM_BASE_URL=http://127.0.0.1
27-
VLLM_API_KEY=None
25+
# For local OpenAI-compatible servers (model_type = LOCAL)
26+
LOCAL_BASE_URL=http://127.0.0.1:8000/v1
27+
LOCAL_API_KEY=EMPTY
2828
```
2929

30+
`model_type` is used to select the environment-variable namespace. For example, `model_type=LOCAL` makes PatientHub read `LOCAL_BASE_URL` and `LOCAL_API_KEY`.
31+
3032
## Model Configuration
3133

3234
### Using OpenAI (Default)
@@ -65,9 +67,14 @@ config = {
6567
```yaml
6668
client:
6769
agent_name: consistentMI
68-
initial_stage: precontemplation # precontemplation, contemplation, preparation, action
70+
model_type: OPENAI
71+
model_name: gpt-4o
72+
reranker_model_type: LOCAL
73+
reranker_model_name: hosted_vllm/BAAI/bge-reranker-v2-m3
6974
```
7075
76+
`ConsistentMI` uses the main `model_type` / `model_name` pair for response generation and a separate `reranker_model_type` / `reranker_model_name` pair for topic matching. The reranker currently reuses `LOCAL_BASE_URL` and `LOCAL_API_KEY`.
77+
7178
#### SimPatient
7279

7380
```yaml

docs/docs/getting-started/installation.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,44 @@ LOCAL_API_KEY=EMPTY
7575

7676
Then set your config to use `model_type=LOCAL` and `model_name` to the model name exposed by your vLLM server.
7777

78+
### Local Reranker Models via vLLM
79+
80+
`ConsistentMI` can also use a local reranker served by vLLM's OpenAI-compatible `/rerank` endpoint.
81+
82+
1) Start a reranker model with vLLM:
83+
84+
```bash
85+
vllm serve BAAI/bge-reranker-v2-m3 --host 0.0.0.0 --port 7891
86+
```
87+
88+
2) Point `LOCAL_BASE_URL` at the reranker server:
89+
90+
```bash
91+
LOCAL_BASE_URL=http://127.0.0.1:7891/v1
92+
LOCAL_API_KEY=EMPTY
93+
```
94+
95+
3) Use the LiteLLM vLLM route in `ConsistentMI`:
96+
97+
```yaml
98+
client:
99+
agent_name: consistentMI
100+
reranker_model_type: LOCAL
101+
reranker_model_name: hosted_vllm/BAAI/bge-reranker-v2-m3
102+
```
103+
104+
:::tip Localhost vs 0.0.0.0
105+
Use `0.0.0.0` for the server listen address, but use `127.0.0.1` or the machine's real IP in `LOCAL_BASE_URL`.
106+
:::
107+
108+
:::tip Proxy settings
109+
If your shell exports `http_proxy` or `https_proxy`, local requests to the reranker can be sent to the proxy instead of your vLLM server. For local testing, either unset those variables or set:
110+
111+
```bash
112+
export NO_PROXY=127.0.0.1,localhost
113+
```
114+
:::
115+
78116
:::note vLLM fails to start
79117
it’s usually a CUDA/driver mismatch on the serving machine—check your NVIDIA driver/CUDA runtime and use a vLLM version compatible with your environment.
80118
:::

patienthub/clients/consistentMI.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class ConsistentMIClientConfig(APIModelConfig):
3737
prompt_path: str = "data/prompts/client/consistentMI.yaml"
3838
data_path: str = "data/characters/ConsistentMI.json"
3939
topics_path: str = "data/resources/ConsistentMI/topics.json"
40-
topic_graph_path: str = "data/resources/ConsistentMI/topic_graph.json"
40+
topic_graph_path: str = "data/resources/ConsistentMI/topic_graph.json"
41+
reranker_model_type: str = "LOCAL"
42+
reranker_model_name: str = "hosted_vllm/BAAI/bge-reranker-v2-m3"
4143
data_idx: int = 0
4244

4345

@@ -186,7 +188,7 @@ class TopicMatcher:
186188
def __init__(self, configs: Dict[str, Any]):
187189
self.topic_graph = load_json(configs.topic_graph_path)
188190
self.reranker = (
189-
get_reranker(configs.model_retriever) if configs.model_retriever else None
191+
get_reranker(configs)
190192
)
191193
self.all_topics = self.extract_all_topics()
192194
self.topic_passages: List[str] = []
@@ -230,6 +232,7 @@ def find_related_topics(self, query: str, top_k: int = 5) -> List[str]:
230232
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[
231233
:top_k
232234
]
235+
print(f"Related topics: {[self.all_topics[i] for i in top_indices]}")
233236
return [self.all_topics[i] for i in top_indices]
234237

235238
def score_passages(self, query: str) -> Optional[List[float]]:

patienthub/utils/models.py

Lines changed: 61 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dotenv import load_dotenv
55
from dataclasses import dataclass
66
from typing import Any, List, Optional, Dict
7-
from litellm import completion, supports_response_schema, completion_cost
7+
from litellm import completion, supports_response_schema, completion_cost, rerank
88

99
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
1010

@@ -91,82 +91,80 @@ def get(name, default=None):
9191

9292
@dataclass
9393
class Reranker:
94-
def __init__(self, tokenizer: Any, model: Any, device: Any):
95-
self.tokenizer = tokenizer
96-
self.model = model
97-
self.device = device
98-
99-
def score(
100-
self, query: str, passages: List[str], max_length: int = 512
101-
) -> Optional[List[float]]:
102-
"""Score (query, passage) pairs. Higher = more relevant."""
103-
if not passages:
94+
"""Reranker backed by LiteLLM's hosted_vllm provider."""
95+
96+
model_name: str
97+
api_base: Optional[str] = None
98+
api_key: Optional[str] = None
99+
100+
@staticmethod
101+
def read_field(obj: Any, name: str, default: Any = None) -> Any:
102+
if isinstance(obj, dict):
103+
return obj.get(name, default)
104+
return getattr(obj, name, default)
105+
106+
@classmethod
107+
def extract_scores(cls, response: Any, total_docs: int) -> Optional[List[float]]:
108+
scores = [0.0] * total_docs
109+
results = cls.read_field(response, "results", []) or []
110+
valid_count = 0
111+
112+
for item in results:
113+
index = cls.read_field(item, "index")
114+
relevance_score = cls.read_field(item, "relevance_score")
115+
if relevance_score is None:
116+
relevance_score = cls.read_field(item, "score")
117+
118+
try:
119+
index = int(index)
120+
relevance_score = float(relevance_score)
121+
except (TypeError, ValueError):
122+
continue
123+
124+
if 0 <= index < total_docs:
125+
scores[index] = relevance_score
126+
valid_count += 1
127+
128+
if valid_count == 0:
104129
return None
105130

106-
pairs = [(query, passage) for passage in passages]
131+
return scores
107132

108-
try:
109-
return self.compute_scores(pairs, max_length)
110-
except Exception:
133+
def score(self, query: str, passages: List[str]) -> Optional[List[float]]:
134+
"""Score passages through LiteLLM's rerank endpoint."""
135+
if not passages:
111136
return None
112137

113-
def compute_scores(self, pairs: List[tuple], max_length: int) -> List[float]:
114-
"""Compute relevance scores for query-passage pairs."""
115-
import torch
116-
117-
with torch.no_grad():
118-
inputs = self.tokenizer(
119-
pairs,
120-
padding=True,
121-
truncation=True,
122-
return_tensors="pt",
123-
max_length=max_length,
138+
try:
139+
response = rerank(
140+
model=self.model_name,
141+
query=query,
142+
documents=passages,
143+
top_n=len(passages),
144+
return_documents=False,
145+
api_base=self.api_base,
146+
api_key=self.api_key,
124147
)
125-
inputs = {k: v.to(self.device) for k, v in inputs.items()}
126-
outputs = self.model(**inputs, return_dict=True)
127-
logits = outputs.logits.view(-1).float()
128-
return torch.sigmoid(logits).tolist()
129-
130-
131-
def get_device(device_index: int):
132-
import torch
133-
134-
try:
135-
device_index = int(device_index)
136-
except Exception:
137-
device_index = 0
138-
139-
if torch.cuda.is_available() and device_index >= 0:
140-
return torch.device(f"cuda:{device_index}")
141-
return torch.device("cpu")
142-
143-
144-
def load_reranker_model(model_name: str, device: Any):
145-
"""Load tokenizer and model for reranking."""
146-
from transformers import AutoModelForSequenceClassification, AutoTokenizer
148+
except Exception:
149+
return None
147150

148-
tokenizer = AutoTokenizer.from_pretrained(model_name)
149-
model = AutoModelForSequenceClassification.from_pretrained(model_name)
150-
model.to(device)
151-
model.eval()
152-
return tokenizer, model
151+
return self.extract_scores(response, len(passages))
153152

154153

155154
def get_reranker(configs: Any) -> Optional[Reranker]:
156-
"""Get a Reranker instance from config, or None if unavailable."""
155+
"""Get a LOCAL reranker backed by LiteLLM's hosted_vllm provider."""
157156

158157
def get(name, default=None):
159158
return get_config_value(configs, name, default)
160159

161-
model_type = get("model_type")
162-
model_name = get("model_name")
160+
model_type = get("reranker_model_type")
161+
model_name = get("reranker_model_name")
163162

164-
if model_type not in ("huggingface", "local") or not model_name:
163+
if model_type != "LOCAL" or not model_name:
165164
return None
166165

167-
try:
168-
device = get_device(get("device", 0))
169-
tokenizer, model = load_reranker_model(model_name, device)
170-
return Reranker(tokenizer=tokenizer, model=model, device=device)
171-
except Exception:
172-
return None
166+
return Reranker(
167+
model_name=model_name,
168+
api_base=os.environ.get("LOCAL_BASE_URL"),
169+
api_key=os.environ.get("LOCAL_API_KEY"),
170+
)

0 commit comments

Comments
 (0)