Skip to content

Commit 6b3665a

Browse files
authored
[manager] fix address when binding to 0 (#67)
1 parent b617bd2 commit 6b3665a

File tree

4 files changed

+23
-22
lines changed

4 files changed

+23
-22
lines changed

src/lib.rs

+2 -2
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ impl Manager {
4141
py: Python<'_>,
4242
replica_id: String,
4343
lighthouse_addr: String,
44-
address: String,
44+
hostname: String,
4545
bind: String,
4646
store_addr: String,
4747
world_size: u64,
@@ -52,7 +52,7 @@ impl Manager {
5252
.block_on(manager::Manager::new(
5353
replica_id,
5454
lighthouse_addr,
55-
address,
55+
hostname,
5656
bind,
5757
store_addr,
5858
world_size,

src/manager.rs

+10 -14
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@ use std::sync::Arc;
1111
use std::time::Duration;
1212

1313
use anyhow::Result;
14-
use gethostname::gethostname;
1514
use tokio::sync::broadcast;
1615
use tokio::sync::Mutex;
1716
use tokio::task::JoinSet;
@@ -53,7 +52,7 @@ struct ManagerState {
5352
pub struct Manager {
5453
replica_id: String,
5554
lighthouse_addr: String,
56-
address: String,
55+
hostname: String,
5756
store_address: String,
5857
world_size: u64,
5958
state: Mutex<ManagerState>,
@@ -80,19 +79,20 @@ impl Manager {
8079
pub async fn new(
8180
replica_id: String,
8281
lighthouse_addr: String,
83-
address: String,
82+
hostname: String,
8483
bind: String,
8584
store_addr: String,
8685
world_size: u64,
8786
) -> Result<Arc<Self>> {
8887
let listener = tokio::net::TcpListener::bind(&bind).await?;
88+
let local_addr = listener.local_addr()?;
8989

9090
let (should_commit_tx, _) = broadcast::channel(16);
9191

9292
Ok(Arc::new(Self {
9393
replica_id: replica_id,
9494
lighthouse_addr: lighthouse_addr,
95-
address: address,
95+
hostname: hostname,
9696
store_address: store_addr,
9797
world_size: world_size,
9898
state: Mutex::new(ManagerState {
@@ -103,7 +103,7 @@ impl Manager {
103103
should_commit_count: HashSet::new(),
104104
should_commit_failures: HashSet::new(),
105105
}),
106-
local_addr: listener.local_addr()?,
106+
local_addr: local_addr,
107107
listener: Mutex::new(Some(listener)),
108108
}))
109109
}
@@ -122,11 +122,7 @@ impl Manager {
122122
}
123123

124124
pub fn address(&self) -> String {
125-
format!(
126-
"http://{}:{}",
127-
gethostname().into_string().unwrap(),
128-
self.local_addr.port()
129-
)
125+
format!("http://{}:{}", self.hostname, self.local_addr.port())
130126
}
131127

132128
async fn _run_grpc(self: Arc<Self>) -> Result<()> {
@@ -228,7 +224,7 @@ impl ManagerService for Arc<Manager> {
228224
room_id: room_id.clone(),
229225
requester: Some(QuorumMember {
230226
replica_id: self.replica_id.clone(),
231-
address: self.address.clone(),
227+
address: self.address(),
232228
store_address: self.store_address.clone(),
233229
step: req.step,
234230
world_size: self.world_size,
@@ -470,7 +466,7 @@ mod tests {
470466
let manager = Manager::new(
471467
"rep_id".to_string(),
472468
lighthouse.address(),
473-
"addr".to_string(),
469+
"localhost".to_string(),
474470
"[::]:0".to_string(),
475471
"store_addr".to_string(),
476472
1, // world size
@@ -493,7 +489,7 @@ mod tests {
493489
lighthouse_fut.abort();
494490

495491
assert_eq!(resp.quorum_id, 1);
496-
assert_eq!(resp.address, "addr".to_string());
492+
assert_eq!(resp.address, manager.address());
497493
assert_eq!(resp.store_address, "store_addr".to_string());
498494
assert_eq!(resp.max_step, 123);
499495
assert_eq!(resp.max_rank, Some(0));
@@ -525,7 +521,7 @@ mod tests {
525521
let manager = Manager::new(
526522
format!("rep_{}", replica_id),
527523
lighthouse_addr,
528-
"addr".to_string(),
524+
"localhost".to_string(),
529525
"[::]:0".to_string(),
530526
"store_addr".to_string(),
531527
1, // world size

torchft/manager.py

+3 -4
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ def __init__(
100100
lighthouse_addr: Optional[str] = None,
101101
replica_id: Optional[str] = None,
102102
port: Optional[int] = None,
103+
hostname: str = socket.gethostname(),
103104
) -> None:
104105
"""
105106
Args:
@@ -122,6 +123,7 @@ def __init__(
122123
store_port: TCPStore port for this replica group
123124
lighthouse_addr: if rank==0, the address of the lighthouse server
124125
replica_id: if rank==0, the replica_id for this group
126+
hostname: if rank==0, the hostname to advertise to the lighthouse server
125127
"""
126128
self._load_state_dict = load_state_dict
127129
self._state_dict = state_dict
@@ -159,12 +161,9 @@ def _manager_state_dict() -> Dict[str, T]:
159161
self._manager: Optional[_Manager] = None
160162

161163
if rank == 0:
162-
hostname = socket.gethostname()
163-
164164
if port is None:
165165
port = int(os.environ.get(MANAGER_PORT_ENV, 0))
166166

167-
addr = f"http://{hostname}:{port}"
168167
bind = f"[::]:{port}"
169168
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
170169

@@ -174,7 +173,7 @@ def _manager_state_dict() -> Dict[str, T]:
174173
self._manager = _Manager(
175174
replica_id=replica_id,
176175
lighthouse_addr=lighthouse_addr,
177-
address=addr,
176+
hostname=hostname,
178177
bind=bind,
179178
store_addr=f"{store_addr}:{store_port}",
180179
world_size=world_size,

torchft/torchft.pyi

+8 -2
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ class Manager:
2727
self,
2828
replica_id: str,
2929
lighthouse_addr: str,
30-
address: str,
30+
hostname: str,
3131
bind: str,
3232
store_addr: str,
3333
world_size: int,
@@ -36,6 +36,12 @@ class Manager:
3636
def shutdown(self) -> None: ...
3737

3838
class Lighthouse:
39-
def __init__(self, bind: str, min_replicas: int, join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None) -> None: ...
39+
def __init__(
40+
self,
41+
bind: str,
42+
min_replicas: int,
43+
join_timeout_ms: Optional[int] = None,
44+
quorum_tick_ms: Optional[int] = None,
45+
) -> None: ...
4046
def address(self) -> str: ...
4147
def shutdown(self) -> None: ...

0 commit comments

Comments
 (0)