55import time
66from collections .abc import Generator , Iterator
77from concurrent .futures import ProcessPoolExecutor
8+ from contextlib import contextmanager
89from pathlib import Path
910from typing import Any
1011from unittest .mock import AsyncMock , MagicMock , patch
@@ -148,7 +149,7 @@ def memory_plugin(self) -> MemoryPlugin:
148149
149150 @pytest .fixture
150151 def memory_cgroup_context (
151- self , cgroup_stat_context : MagicMock
152+ self , cgroup_stat_context : MagicMock , tmp_path : Path
152153 ) -> Generator [MagicMock , None , None ]:
153154 """CGROUP stat context with memory/io cgroup v2 path mocks and related patches."""
154155 ctx = cgroup_stat_context
@@ -170,8 +171,10 @@ def mock_get_cgroup_path(subsys: str, cid: str) -> MagicMock:
170171
171172 ctx .agent .get_cgroup_path = mock_get_cgroup_path
172173
174+ sandbox_file = tmp_path / "fake_netns"
175+ sandbox_file .touch ()
173176 mock_container_data = {
174- "NetworkSettings" : {"SandboxKey" : "/var/run/docker/netns/fake" },
177+ "NetworkSettings" : {"SandboxKey" : str ( sandbox_file ) },
175178 }
176179
177180 with (
@@ -234,6 +237,135 @@ async def test_sysfs_mode_uses_instance_docker_client(
234237 mock_docker_cls .assert_not_called ()
235238
236239
240+ class TestMemoryPluginNamespaceValidation (BaseDockerIntrinsicTest ):
241+ """Tests for namespace path pre-validation before netstat_ns call."""
242+
243+ @pytest .fixture
244+ def memory_plugin (self ) -> MemoryPlugin :
245+ plugin = MemoryPlugin .__new__ (MemoryPlugin )
246+ plugin .local_config = {"agent" : {"docker-mode" : "default" }}
247+ plugin ._docker = AsyncMock ()
248+ return plugin
249+
250+ @contextmanager
251+ def _make_cgroup_context (
252+ self ,
253+ cgroup_stat_context : MagicMock ,
254+ sandbox_key : str ,
255+ ) -> Generator [tuple [MagicMock , MagicMock ], None , None ]:
256+ """Build a CGROUP stat context with configurable sandbox_key.
257+
258+ Pass a real filesystem path for sandbox_key — an existing path
259+ triggers netstat_ns, a non-existent one skips it.
260+ """
261+ ctx = cgroup_stat_context
262+ ctx .agent .get_cgroup_version = MagicMock (return_value = "2" )
263+
264+ mem_path = MagicMock ()
265+ mem_stat = MagicMock ()
266+ mem_stat .read_text .return_value = "inactive_file 0\n "
267+ mem_path .__truediv__ = MagicMock (return_value = mem_stat )
268+ io_path = MagicMock ()
269+ io_stat = MagicMock ()
270+ io_stat .read_text .return_value = ""
271+ io_path .__truediv__ = MagicMock (return_value = io_stat )
272+
273+ def mock_get_cgroup_path (subsys : str , cid : str ) -> MagicMock :
274+ if subsys == "memory" :
275+ return mem_path
276+ return io_path
277+
278+ ctx .agent .get_cgroup_path = mock_get_cgroup_path
279+
280+ mock_container_data = {
281+ "NetworkSettings" : {"SandboxKey" : sandbox_key },
282+ }
283+
284+ with (
285+ patch (
286+ "ai.backend.agent.docker.intrinsic.DockerContainer" ,
287+ ) as mock_container_cls ,
288+ patch (
289+ "ai.backend.agent.docker.intrinsic.read_sysfs" ,
290+ return_value = 1048576 ,
291+ ),
292+ patch (
293+ "ai.backend.agent.docker.intrinsic.netstat_ns" ,
294+ new_callable = AsyncMock ,
295+ ) as mock_netstat ,
296+ patch (
297+ "ai.backend.agent.docker.intrinsic.current_loop" ,
298+ ) as mock_loop ,
299+ ):
300+ mock_netstat .return_value = {
301+ "eth0" : MagicMock (bytes_recv = 4096 , bytes_sent = 8192 ),
302+ }
303+ mock_container_instance = AsyncMock ()
304+ mock_container_instance .show .return_value = mock_container_data
305+ mock_container_cls .return_value = mock_container_instance
306+ mock_loop .return_value .run_in_executor = AsyncMock (return_value = 0 )
307+ yield ctx , mock_netstat
308+
309+ async def test_nonexistent_namespace_path_returns_zero_net_stats (
310+ self ,
311+ memory_plugin : MemoryPlugin ,
312+ cgroup_stat_context : MagicMock ,
313+ tmp_path : Path ,
314+ ) -> None :
315+ """When namespace path does not exist, net stats should be 0 but other stats collected."""
316+ gone_path = tmp_path / "nonexistent_netns"
317+ with self ._make_cgroup_context (
318+ cgroup_stat_context ,
319+ sandbox_key = str (gone_path ),
320+ ) as (ctx , mock_netstat ):
321+ results = await memory_plugin .gather_container_measures (ctx , ["cid_001" ])
322+ mock_netstat .assert_not_called ()
323+ # mem stats should be collected (read_sysfs returns 1048576)
324+ assert results [0 ].per_container ["cid_001" ].value == 1048576
325+ # net_rx and net_tx should be 0
326+ assert results [3 ].per_container ["cid_001" ].value == 0
327+ assert results [4 ].per_container ["cid_001" ].value == 0
328+
329+ async def test_empty_sandbox_key_returns_zero_net_stats (
330+ self ,
331+ memory_plugin : MemoryPlugin ,
332+ cgroup_stat_context : MagicMock ,
333+ ) -> None :
334+ """When sandbox_key is empty string, net stats should be 0 but other stats collected."""
335+ with self ._make_cgroup_context (
336+ cgroup_stat_context ,
337+ sandbox_key = "" ,
338+ ) as (ctx , mock_netstat ):
339+ results = await memory_plugin .gather_container_measures (ctx , ["cid_001" ])
340+ mock_netstat .assert_not_called ()
341+ # mem stats should be collected
342+ assert results [0 ].per_container ["cid_001" ].value == 1048576
343+ # net_rx and net_tx should be 0
344+ assert results [3 ].per_container ["cid_001" ].value == 0
345+ assert results [4 ].per_container ["cid_001" ].value == 0
346+
347+ async def test_valid_namespace_path_calls_netstat_ns (
348+ self ,
349+ memory_plugin : MemoryPlugin ,
350+ cgroup_stat_context : MagicMock ,
351+ tmp_path : Path ,
352+ ) -> None :
353+ """When namespace path exists, netstat_ns should be called and net stats collected."""
354+ valid_path = tmp_path / "valid_netns"
355+ valid_path .touch ()
356+ with self ._make_cgroup_context (
357+ cgroup_stat_context ,
358+ sandbox_key = str (valid_path ),
359+ ) as (ctx , mock_netstat ):
360+ results = await memory_plugin .gather_container_measures (ctx , ["cid_001" ])
361+ mock_netstat .assert_called ()
362+ # mem stats should be collected
363+ assert results [0 ].per_container ["cid_001" ].value == 1048576
364+ # net_rx and net_tx should have values from mock netstat_ns
365+ assert results [3 ].per_container ["cid_001" ].value == 4096
366+ assert results [4 ].per_container ["cid_001" ].value == 8192
367+
368+
237369@pytest .mark .skipif (sys .platform != "linux" , reason = "Network namespaces require Linux" )
238370class TestNetstatNsWork :
239371 """Tests for netstat_ns_work with real namespace switching."""
0 commit comments