diff --git a/redbot/core/bot.py b/redbot/core/bot.py index 359b65e3657..6dc4bd97b75 100644 --- a/redbot/core/bot.py +++ b/redbot/core/bot.py @@ -12,6 +12,7 @@ from collections import namedtuple, OrderedDict from datetime import datetime from importlib.machinery import ModuleSpec +import importlib.util from pathlib import Path from typing import ( Optional, @@ -1705,8 +1706,20 @@ async def load_extension(self, spec: ModuleSpec): if name in self.extensions: raise errors.PackageAlreadyLoaded(spec) - lib = spec.loader.load_module() + # Check if module already exists in sys.modules (after refresh by _cleanup_and_refresh_modules) + if spec.name in sys.modules: + # Use the refreshed module from sys.modules + lib = sys.modules[spec.name] + else: + # First-time load: use modern import approach + lib = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = lib + spec.loader.exec_module(lib) + if not hasattr(lib, "setup"): + # Remove module from sys.modules to prevent pollution + if lib.__name__ in sys.modules: + del sys.modules[lib.__name__] del lib raise discord.ClientException(f"extension {name} does not have a setup function") diff --git a/redbot/core/core_commands.py b/redbot/core/core_commands.py index 93cacdeb840..1c36caf33df 100644 --- a/redbot/core/core_commands.py +++ b/redbot/core/core_commands.py @@ -2,6 +2,7 @@ import contextlib import datetime import importlib +import importlib.util import itertools import keyword import logging @@ -185,7 +186,6 @@ async def _load(self, pkg_names: Iterable[str]) -> Dict[str, Union[List[str], Di async for spec, name in AsyncIter(pkg_specs, steps=10): try: - self._cleanup_and_refresh_modules(spec.name) await bot.load_extension(spec) except errors.PackageAlreadyLoaded: alreadyloaded_packages.append(name) @@ -251,7 +251,27 @@ def maybe_reload(new_name): except KeyError: pass else: - importlib._bootstrap._exec(lib.__spec__, lib) + # Create a new spec from the file to get updated source code + if hasattr(lib.__spec__, "origin") and lib.__spec__.origin: + try: + # Create fresh spec and reload source + new_spec = importlib.util.spec_from_file_location( + lib.__spec__.name, lib.__spec__.origin + ) + if new_spec and new_spec.loader: + # Update the module's spec to the fresh one + lib.__spec__ = new_spec + # Execute with the fresh spec to load updated source + new_spec.loader.exec_module(lib) + else: + # Fallback to original method if spec creation fails + importlib._bootstrap._exec(lib.__spec__, lib) + except Exception: + # Fallback to original method if anything fails + importlib._bootstrap._exec(lib.__spec__, lib) + else: + # Fallback to original method for non-file modules + importlib._bootstrap._exec(lib.__spec__, lib) # noinspection PyTypeChecker modules = itertools.accumulate(splitted, "{}.{}".format) @@ -264,7 +284,7 @@ def maybe_reload(new_name): if name == module_name or name.startswith(f"{module_name}.") } for child_name, lib in children.items(): - importlib._bootstrap._exec(lib.__spec__, lib) + maybe_reload(child_name) async def _unload(self, pkg_names: Iterable[str]) -> Dict[str, List[str]]: """ @@ -290,7 +310,36 @@ async def _unload(self, pkg_names: Iterable[str]) -> Dict[str, List[str]]: for name in pkg_names: if name in bot.extensions: + # Find the extension module and clear its .pyc cache before unloading + if name in sys.modules: + module = sys.modules[name] + if hasattr(module, "__file__") and module.__file__: + # Clear .pyc cache by removing __pycache__ directory + import os + import shutil + + pycache_dir = os.path.join(os.path.dirname(module.__file__), "__pycache__") + if os.path.exists(pycache_dir): + try: + shutil.rmtree(pycache_dir) + except (OSError, IOError): + # Ignore errors removing cache directory + pass + await bot.unload_extension(name) + + # Manually remove related modules from sys.modules to force fresh reload + modules_to_remove = [] + for module_name in sys.modules: + if module_name == name or module_name.startswith(f"{name}."): + modules_to_remove.append(module_name) + + for module_name in modules_to_remove: + del sys.modules[module_name] + + # Clear import caches to ensure fresh loading + importlib.invalidate_caches() + await bot.remove_loaded_package(name) unloaded_packages.append(name) else: @@ -314,9 +363,64 @@ async def _reload( dict Dictionary with keys as returned by `CoreLogic._load()` """ + # Handle case where pkg_names might be a single string instead of a sequence + if isinstance(pkg_names, str): + pkg_names = [pkg_names] + + # Store RPC handler names before unload to ensure they're re-registered + rpc_handlers_to_restore = {} + + for pkg_name in pkg_names: + if pkg_name in self.bot.extensions: + # Find all RPC handlers for this package + pkg_rpc_handlers = [] + for cog_name, methods in self.bot.rpc_handlers.items(): + for method in methods: + # Check if this method belongs to the package being reloaded + if hasattr(method, "__self__") and hasattr(method.__self__, "__module__"): + method_module = method.__self__.__module__ + if method_module == pkg_name or method_module.startswith( + f"{pkg_name}." + ): + pkg_rpc_handlers.append(method) + + rpc_handlers_to_restore[pkg_name] = pkg_rpc_handlers + await self._unload(pkg_names) - return await self._load(pkg_names) + result = await self._load(pkg_names) + + # Verify that RPC handlers were properly re-registered for reloaded packages + for pkg_name in pkg_names: + if pkg_name in result.get("loaded_packages", []): + # Force refresh of RPC method references in case they weren't updated + for cog_name, methods in list(self.bot.rpc_handlers.items()): + updated_methods = [] + for method in methods: + if hasattr(method, "__self__") and hasattr(method.__self__, "__module__"): + method_module = method.__self__.__module__ + if method_module == pkg_name or method_module.startswith( + f"{pkg_name}." + ): + # Get the fresh method reference from the reloaded cog + cog = method.__self__ + method_name = method.__name__ + if hasattr(cog, method_name): + fresh_method = getattr(cog, method_name) + # Re-register with fresh method reference + self.bot.rpc.remove_method(method) + self.bot.rpc.add_method(fresh_method) + updated_methods.append(fresh_method) + else: + updated_methods.append(method) + else: + updated_methods.append(method) + else: + updated_methods.append(method) + + self.bot.rpc_handlers[cog_name] = updated_methods + + return result async def _name(self, name: Optional[str] = None) -> str: """ @@ -5767,6 +5871,7 @@ async def rpc_load(self, request): if spec is None: raise LookupError("No such cog found.") + print(f"DEBUG: RPC loading cog {cog_name}, spec name: {spec.name}") self._cleanup_and_refresh_modules(spec.name) await self.bot.load_extension(spec) diff --git a/redbot/pytest/rpc.py b/redbot/pytest/rpc.py index db189e5de0e..bfeaef8952a 100644 --- a/redbot/pytest/rpc.py +++ b/redbot/pytest/rpc.py @@ -1,9 +1,10 @@ import pytest from redbot.core._rpc import RPC, RPCMixin +from redbot.core.core_commands import CoreLogic from unittest.mock import MagicMock -__all__ = ["rpc", "rpcmixin", "cog", "existing_func", "existing_multi_func"] +__all__ = ["rpc", "rpcmixin", "cog", "existing_func", "existing_multi_func", "core_logic"] @pytest.fixture() @@ -51,3 +52,11 @@ def existing_multi_func(rpc, cog): rpc.add_multi_method(*funcs) return funcs + + +@pytest.fixture(scope="function") +async def core_logic(red): + """Create a CoreLogic instance for testing RPC handlers.""" + # Ensure RPC system is initialized before creating CoreLogic + await red.rpc._pre_login() + return CoreLogic(red) diff --git a/tests/core/test_rpc.py b/tests/core/test_rpc.py index 12e3c6b5ee4..b270a447121 100644 --- a/tests/core/test_rpc.py +++ b/tests/core/test_rpc.py @@ -1,7 +1,23 @@ import pytest +import pytest_asyncio +import tempfile +import textwrap +import asyncio +import json +import aiohttp +from pathlib import Path from redbot.pytest.rpc import * from redbot.core._rpc import get_name +from redbot.core.core_commands import CoreLogic + + +@pytest_asyncio.fixture(scope="function") +async def core_logic(red): + """Create a CoreLogic instance for testing RPC handlers.""" + # Ensure RPC system is initialized before creating CoreLogic + await red.rpc._pre_login() + return CoreLogic(red) def test_get_name(cog): @@ -98,3 +114,495 @@ def test_rpcmixin_unregister(rpcmixin, cog): if cogname in rpcmixin.rpc_handlers: assert cog.cofunc not in rpcmixin.rpc_handlers[cogname] + + +@pytest.fixture() +def test_cog_module(): + """Helper fixture to generate temporary test cog modules.""" + + class TestCogHelper: + def __init__(self): + self.temp_dir = None + self.module_path = None + + def create_module(self, cog_name, handlers, tmpdir): + """Create a temporary cog module with specified handlers.""" + self.temp_dir = Path(str(tmpdir)) + + # Create cog package directory + cog_package_dir = self.temp_dir / cog_name + cog_package_dir.mkdir(exist_ok=True) + + # Create __init__.py with setup function that imports from the main module + init_file = cog_package_dir / "__init__.py" + + # Generate RPC handler registration calls for setup function + rpc_registrations_setup = [] + for handler_name in handlers.keys(): + rpc_registrations_setup.append(f" bot.register_rpc_handler(cog.{handler_name})") + + init_content = textwrap.dedent( + f""" +from .{cog_name} import {cog_name.title()} + +async def setup(bot): + cog = {cog_name.title()}(bot) + await bot.add_cog(cog) +{chr(10).join(rpc_registrations_setup)} +""" + ).strip() + init_file.write_text(init_content, encoding="utf-8") + + # Create main cog module file + self.module_path = cog_package_dir / f"{cog_name}.py" + + # Generate handler methods + handler_methods = [] + for handler_name, return_value in handlers.items(): + handler_methods.append( + f""" + async def {handler_name}(self): + return "{return_value}" +""" + ) + + # Generate RPC handler registration calls + rpc_registrations = [] + for handler_name in handlers.keys(): + rpc_registrations.append( + f" self.bot.register_rpc_handler(self.{handler_name})" + ) + + # Generate cog class code + cog_code = textwrap.dedent( + f""" +from redbot.core import commands + +class {cog_name.title()}(commands.Cog): + def __init__(self, bot): + self.bot = bot +{''.join(handler_methods)} +""" + ).strip() + + self.module_path.write_text(cog_code, encoding="utf-8") + return self.module_path + + def update_handlers(self, handlers): + """Update handler return values in the module file.""" + if not self.module_path or not self.module_path.exists(): + raise RuntimeError("Module not created yet") + + content = self.module_path.read_text(encoding="utf-8") + + # Update each handler's return value + for handler_name, new_return_value in handlers.items(): + # Find and replace the return statement for this handler + import re + + pattern = rf'(async def {handler_name}\(self\):\s*return )"[^"]*"' + replacement = rf'\1"{new_return_value}"' + content = re.sub(pattern, replacement, content) + + self.module_path.write_text(content, encoding="utf-8") + + return TestCogHelper() + + +@pytest.mark.asyncio +async def test_rpc_handler_updates_on_reload(red, core_logic, test_cog_module, tmpdir): + """Test that RPC handlers execute new code after reload.""" + # Create test cog module + cog_name = "testcog" + handlers = {"test_handler": "version_1"} + test_cog_module.create_module(cog_name, handlers, tmpdir) + + # Add temp directory to cog paths + await red._cog_mgr.add_path(Path(str(tmpdir))) + + try: + # Load the cog via RPC + await core_logic._load([cog_name]) + + # Verify cog is loaded + assert cog_name in red.extensions + + # Call RPC handler and verify initial behavior + handler_name = f"{cog_name.upper()}__TEST_HANDLER" + assert handler_name in red.rpc._rpc.methods + + # Capture the original handler reference before reload + original_handler = red.rpc._rpc.methods[handler_name].method + result = await original_handler() + assert result == "version_1" + + # Modify the module file to return different value + test_cog_module.update_handlers({"test_handler": "version_2"}) + + # Reload the cog via RPC + await core_logic._reload([cog_name]) + + # Verify cog is still loaded + assert cog_name in red.extensions + + # Capture the new handler reference after reload + new_handler = red.rpc._rpc.methods[handler_name].method + + # Verify the handler reference itself updated (not just the return value) + assert original_handler is not new_handler, "Handler reference should update after reload" + assert id(original_handler) != id( + new_handler + ), "Handler object identity should differ after reload" + + # Call RPC handler and verify it now executes new code + result = await new_handler() + assert result == "version_2", "RPC handler should execute new code after reload" + + finally: + # Clean up + if cog_name in red.extensions: + await core_logic._unload([cog_name]) + await red._cog_mgr.remove_path(Path(str(tmpdir)).resolve()) + + +@pytest.mark.asyncio +async def test_rpc_reload_flow_matches_unload_load(red, core_logic, test_cog_module, tmpdir): + """Test that _reload() matches _unload() + _load() behavior.""" + # Create test cog module + cog_name = "flowtest" + handlers = {"flow_handler": "initial"} + test_cog_module.create_module(cog_name, handlers, tmpdir) + + # Add temp directory to cog paths + await red._cog_mgr.add_path(Path(str(tmpdir))) + + try: + # Load the cog and verify initial behavior + await core_logic._load([cog_name]) + handler_name = f"{cog_name.upper()}__FLOW_HANDLER" + + # Capture original handler reference + original_handler = red.rpc._rpc.methods[handler_name].method + result = await original_handler() + assert result == "initial" + + # Modify the module file + test_cog_module.update_handlers({"flow_handler": "after_reload"}) + + # Test _reload() behavior + await core_logic._reload([cog_name]) + + # Capture new handler reference and verify it updated + new_handler = red.rpc._rpc.methods[handler_name].method + assert original_handler is not new_handler, "Handler reference should update after _reload" + + result = await new_handler() + assert result == "after_reload" + + # Unload the cog + await core_logic._unload([cog_name]) + assert cog_name not in red.extensions + assert handler_name not in red.rpc._rpc.methods + + # Modify the module file again + test_cog_module.update_handlers({"flow_handler": "after_manual_load"}) + + # Load again and verify it loads the latest version + await core_logic._load([cog_name]) + + # Capture final handler reference and verify it's different from reload handler + final_handler = red.rpc._rpc.methods[handler_name].method + assert ( + new_handler is not final_handler + ), "Handler reference should update after manual load" + + result = await final_handler() + assert result == "after_manual_load" + + finally: + # Clean up + if cog_name in red.extensions: + await core_logic._unload([cog_name]) + await red._cog_mgr.remove_path(Path(str(tmpdir)).resolve()) + + +@pytest.mark.asyncio +async def test_multiple_rpc_handlers_update_on_reload(red, core_logic, test_cog_module, tmpdir): + """Test that multiple RPC handlers all update on reload.""" + # Create test cog module with multiple handlers + cog_name = "multitest" + handlers = {"handler_one": "one_v1", "handler_two": "two_v1", "handler_three": "three_v1"} + test_cog_module.create_module(cog_name, handlers, tmpdir) + + # Add temp directory to cog paths + await red._cog_mgr.add_path(Path(str(tmpdir))) + + try: + # Load the cog + await core_logic._load([cog_name]) + + # Verify all handlers work with initial values + handler_names = [ + f"{cog_name.upper()}__HANDLER_ONE", + f"{cog_name.upper()}__HANDLER_TWO", + f"{cog_name.upper()}__HANDLER_THREE", + ] + + for handler_name in handler_names: + assert handler_name in red.rpc._rpc.methods + + # Capture original handler references before reload + original_handlers = [red.rpc._rpc.methods[name].method for name in handler_names] + + result_one = await original_handlers[0]() + result_two = await original_handlers[1]() + result_three = await original_handlers[2]() + + assert result_one == "one_v1" + assert result_two == "two_v1" + assert result_three == "three_v1" + + # Modify all handlers in the module file + new_handlers = { + "handler_one": "one_v2", + "handler_two": "two_v2", + "handler_three": "three_v2", + } + test_cog_module.update_handlers(new_handlers) + + # Reload the cog + await core_logic._reload([cog_name]) + + # Capture new handler references after reload + new_handlers = [red.rpc._rpc.methods[name].method for name in handler_names] + + # Verify all handler references updated (not just return values) + for i, (original, new) in enumerate(zip(original_handlers, new_handlers)): + assert original is not new, f"Handler {i+1} reference should update after reload" + assert id(original) != id( + new + ), f"Handler {i+1} object identity should differ after reload" + + # Verify all handlers now execute new code + result_one = await new_handlers[0]() + result_two = await new_handlers[1]() + result_three = await new_handlers[2]() + + assert result_one == "one_v2", "First handler should execute new code" + assert result_two == "two_v2", "Second handler should execute new code" + assert result_three == "three_v2", "Third handler should execute new code" + + finally: + # Clean up + if cog_name in red.extensions: + await core_logic._unload([cog_name]) + await red._cog_mgr.remove_path(Path(str(tmpdir)).resolve()) + + +@pytest.mark.asyncio +async def test_rpc_reload_flow_via_rpc_interface(red, core_logic, test_cog_module, tmpdir): + """Test RPC reload using the actual rpc_reload() method instead of _reload().""" + # Create test cog module + cog_name = "rpctest" + handlers = {"rpc_handler": "rpc_v1"} + test_cog_module.create_module(cog_name, handlers, tmpdir) + + # Add temp directory to cog paths + await red._cog_mgr.add_path(Path(str(tmpdir))) + + try: + # Load the cog initially + await core_logic._load([cog_name]) + + # Verify cog is loaded and handler works + assert cog_name in red.extensions + handler_name = f"{cog_name.upper()}__RPC_HANDLER" + assert handler_name in red.rpc._rpc.methods + + # Capture original handler reference before reload + original_handler = red.rpc._rpc.methods[handler_name].method + result = await original_handler() + assert result == "rpc_v1" + + # Modify the module file to return different value + test_cog_module.update_handlers({"rpc_handler": "rpc_v2"}) + + # Create a minimal request-like object with params + class MockRequest: + def __init__(self, params): + self.params = params + + request = MockRequest([cog_name]) + + # Reload using the actual rpc_reload method via RPC interface + reload_handler = red.rpc._rpc.methods["CORELOGIC__RELOAD"].method + await reload_handler([cog_name]) + + # Verify cog is still loaded + assert cog_name in red.extensions + + # Verify RPC handler name remains present + assert handler_name in red.rpc._rpc.methods + + # Capture new handler reference after rpc_reload + new_handler = red.rpc._rpc.methods[handler_name].method + + # Verify the handler reference itself updated via rpc_reload + assert ( + original_handler is not new_handler + ), "Handler reference should update after rpc_reload" + assert id(original_handler) != id( + new_handler + ), "Handler object identity should differ after rpc_reload" + + # Verify invoking the handler now returns updated value + result = await new_handler() + assert result == "rpc_v2", "RPC handler should execute new code after rpc_reload" + + finally: + # Clean up + if cog_name in red.extensions: + await core_logic._unload([cog_name]) + await red._cog_mgr.remove_path(Path(str(tmpdir)).resolve()) + + +@pytest.mark.asyncio +@pytest.mark.skip_ci # Skip in CI if network tests are restricted +async def test_rpc_reload_via_websocket_endpoint_smoke_test( + red, core_logic, test_cog_module, tmpdir +): + """Smoke test that validates RPC reload through actual WebSocket endpoint to mirror real usage.""" + import os + + # Skip test if running in CI environment or if explicitly disabled + if os.environ.get("CI") or os.environ.get("SKIP_NETWORK_TESTS"): + pytest.skip("Skipping network test in CI or when SKIP_NETWORK_TESTS is set") + + # Create test cog module + cog_name = "httptest" + handlers = {"http_handler": "http_v1"} + test_cog_module.create_module(cog_name, handlers, tmpdir) + + # Add temp directory to cog paths + await red._cog_mgr.add_path(Path(str(tmpdir))) + + # Start RPC server on ephemeral port (0 = random available port) + from aiohttp import web + + app = web.Application() + app.router.add_route("*", "/jsonrpc", red.rpc._rpc.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + + # Use ephemeral port for testing + site = web.TCPSite(runner, "localhost", 0) + await site.start() + + # Get the actual port assigned + server_port = site._server.sockets[0].getsockname()[1] + server_url = f"http://localhost:{server_port}/jsonrpc" + + try: + # Load the cog initially + await core_logic._load([cog_name]) + + # Verify cog is loaded + assert cog_name in red.extensions + handler_name = f"{cog_name.upper()}__HTTP_HANDLER" + assert handler_name in red.rpc._rpc.methods + + # Capture original handler reference before reload + original_handler = red.rpc._rpc.methods[handler_name].method + + async with aiohttp.ClientSession() as session: + # Test 1: Call the handler via WebSocket RPC to verify initial behavior + async with session.ws_connect(f"ws://localhost:{server_port}/jsonrpc") as ws: + # Send JSON-RPC request + await ws.send_json( + {"jsonrpc": "2.0", "method": handler_name, "params": [], "id": 1} + ) + + # Receive response + result = await ws.receive_json() + assert result["result"] == "http_v1" + assert "error" not in result + + # Modify the module file to return different value + test_cog_module.update_handlers({"http_handler": "http_v2"}) + + # Test 2: Call rpc_reload via WebSocket RPC + async with session.ws_connect(f"ws://localhost:{server_port}/jsonrpc") as ws: + # Send reload request + await ws.send_json( + { + "jsonrpc": "2.0", + "method": "CORELOGIC__RELOAD", + "params": [cog_name], + "id": 2, + } + ) + + # Receive response + result = await ws.receive_json() + assert ( + "error" not in result + ), f"RPC reload failed: {result.get('error', 'Unknown error')}" + + # Wait a moment for any async reload operations to complete + await asyncio.sleep(0.1) + + # Verify cog is still loaded after reload + assert cog_name in red.extensions + assert handler_name in red.rpc._rpc.methods + + # Capture new handler reference after reload + new_handler = red.rpc._rpc.methods[handler_name].method + + # Test 3: Call the handler again via WebSocket RPC to verify new behavior + async with session.ws_connect(f"ws://localhost:{server_port}/jsonrpc") as ws: + # Send request for updated handler + await ws.send_json( + {"jsonrpc": "2.0", "method": handler_name, "params": [], "id": 3} + ) + + # Receive response + result = await ws.receive_json() + assert ( + result["result"] == "http_v2" + ), "Handler should execute new code after WebSocket RPC reload" + assert "error" not in result + + # Test 4: Verify we can call the handler multiple times with consistent results + async with session.ws_connect(f"ws://localhost:{server_port}/jsonrpc") as ws: + for i in range(3): + # Send consistency test request + await ws.send_json( + {"jsonrpc": "2.0", "method": handler_name, "params": [], "id": 10 + i} + ) + + # Receive response + result = await ws.receive_json() + assert ( + result["result"] == "http_v2" + ), f"Handler should be consistent on call {i+1}" + + # Test 5: Verify error handling for non-existent methods + async with session.ws_connect(f"ws://localhost:{server_port}/jsonrpc") as ws: + # Send request for non-existent method + await ws.send_json( + {"jsonrpc": "2.0", "method": "NONEXISTENT__METHOD", "params": [], "id": 4} + ) + + # Receive response + result = await ws.receive_json() + assert "error" in result, "Should return error for non-existent method" + + finally: + # Clean up server + await runner.cleanup() + + # Clean up cog and paths + if cog_name in red.extensions: + await core_logic._unload([cog_name]) + await red._cog_mgr.remove_path(Path(str(tmpdir)).resolve())