Skip to content
15 changes: 14 additions & 1 deletion redbot/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import namedtuple, OrderedDict
from datetime import datetime
from importlib.machinery import ModuleSpec
import importlib.util
from pathlib import Path
from typing import (
Optional,
Expand Down Expand Up @@ -1705,8 +1706,20 @@ async def load_extension(self, spec: ModuleSpec):
if name in self.extensions:
raise errors.PackageAlreadyLoaded(spec)

lib = spec.loader.load_module()
# Check if module already exists in sys.modules (after refresh by _cleanup_and_refresh_modules)
if spec.name in sys.modules:
# Use the refreshed module from sys.modules
lib = sys.modules[spec.name]
else:
# First-time load: use modern import approach
lib = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = lib
spec.loader.exec_module(lib)

if not hasattr(lib, "setup"):
# Remove module from sys.modules to prevent pollution
if lib.__name__ in sys.modules:
del sys.modules[lib.__name__]
del lib
raise discord.ClientException(f"extension {name} does not have a setup function")

Expand Down
113 changes: 109 additions & 4 deletions redbot/core/core_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import datetime
import importlib
import importlib.util
import itertools
import keyword
import logging
Expand Down Expand Up @@ -185,7 +186,6 @@ async def _load(self, pkg_names: Iterable[str]) -> Dict[str, Union[List[str], Di

async for spec, name in AsyncIter(pkg_specs, steps=10):
try:
self._cleanup_and_refresh_modules(spec.name)
await bot.load_extension(spec)
except errors.PackageAlreadyLoaded:
alreadyloaded_packages.append(name)
Expand Down Expand Up @@ -251,7 +251,27 @@ def maybe_reload(new_name):
except KeyError:
pass
else:
importlib._bootstrap._exec(lib.__spec__, lib)
# Create a new spec from the file to get updated source code
if hasattr(lib.__spec__, "origin") and lib.__spec__.origin:
try:
# Create fresh spec and reload source
new_spec = importlib.util.spec_from_file_location(
lib.__spec__.name, lib.__spec__.origin
)
if new_spec and new_spec.loader:
# Update the module's spec to the fresh one
lib.__spec__ = new_spec
# Execute with the fresh spec to load updated source
new_spec.loader.exec_module(lib)
else:
# Fallback to original method if spec creation fails
importlib._bootstrap._exec(lib.__spec__, lib)
except Exception:
# Fallback to original method if anything fails
importlib._bootstrap._exec(lib.__spec__, lib)
else:
# Fallback to original method for non-file modules
importlib._bootstrap._exec(lib.__spec__, lib)

# noinspection PyTypeChecker
modules = itertools.accumulate(splitted, "{}.{}".format)
Expand All @@ -264,7 +284,7 @@ def maybe_reload(new_name):
if name == module_name or name.startswith(f"{module_name}.")
}
for child_name, lib in children.items():
importlib._bootstrap._exec(lib.__spec__, lib)
maybe_reload(child_name)

async def _unload(self, pkg_names: Iterable[str]) -> Dict[str, List[str]]:
"""
Expand All @@ -290,7 +310,36 @@ async def _unload(self, pkg_names: Iterable[str]) -> Dict[str, List[str]]:

for name in pkg_names:
if name in bot.extensions:
# Find the extension module and clear its .pyc cache before unloading
if name in sys.modules:
module = sys.modules[name]
if hasattr(module, "__file__") and module.__file__:
# Clear .pyc cache by removing __pycache__ directory
import os
import shutil

pycache_dir = os.path.join(os.path.dirname(module.__file__), "__pycache__")
if os.path.exists(pycache_dir):
try:
shutil.rmtree(pycache_dir)
except (OSError, IOError):
# Ignore errors removing cache directory
pass

await bot.unload_extension(name)

# Manually remove related modules from sys.modules to force fresh reload
modules_to_remove = []
for module_name in sys.modules:
if module_name == name or module_name.startswith(f"{name}."):
modules_to_remove.append(module_name)

for module_name in modules_to_remove:
del sys.modules[module_name]

# Clear import caches to ensure fresh loading
importlib.invalidate_caches()

await bot.remove_loaded_package(name)
unloaded_packages.append(name)
else:
Expand All @@ -314,9 +363,64 @@ async def _reload(
dict
Dictionary with keys as returned by `CoreLogic._load()`
"""
# Handle case where pkg_names might be a single string instead of a sequence
if isinstance(pkg_names, str):
pkg_names = [pkg_names]

# Store RPC handler names before unload to ensure they're re-registered
rpc_handlers_to_restore = {}

for pkg_name in pkg_names:
if pkg_name in self.bot.extensions:
# Find all RPC handlers for this package
pkg_rpc_handlers = []
for cog_name, methods in self.bot.rpc_handlers.items():
for method in methods:
# Check if this method belongs to the package being reloaded
if hasattr(method, "__self__") and hasattr(method.__self__, "__module__"):
method_module = method.__self__.__module__
if method_module == pkg_name or method_module.startswith(
f"{pkg_name}."
):
pkg_rpc_handlers.append(method)

rpc_handlers_to_restore[pkg_name] = pkg_rpc_handlers

await self._unload(pkg_names)

return await self._load(pkg_names)
result = await self._load(pkg_names)

# Verify that RPC handlers were properly re-registered for reloaded packages
for pkg_name in pkg_names:
if pkg_name in result.get("loaded_packages", []):
# Force refresh of RPC method references in case they weren't updated
for cog_name, methods in list(self.bot.rpc_handlers.items()):
updated_methods = []
for method in methods:
if hasattr(method, "__self__") and hasattr(method.__self__, "__module__"):
method_module = method.__self__.__module__
if method_module == pkg_name or method_module.startswith(
f"{pkg_name}."
):
# Get the fresh method reference from the reloaded cog
cog = method.__self__
method_name = method.__name__
if hasattr(cog, method_name):
fresh_method = getattr(cog, method_name)
# Re-register with fresh method reference
self.bot.rpc.remove_method(method)
self.bot.rpc.add_method(fresh_method)
updated_methods.append(fresh_method)
else:
updated_methods.append(method)
else:
updated_methods.append(method)
else:
updated_methods.append(method)

self.bot.rpc_handlers[cog_name] = updated_methods

return result

async def _name(self, name: Optional[str] = None) -> str:
"""
Expand Down Expand Up @@ -5767,6 +5871,7 @@ async def rpc_load(self, request):
if spec is None:
raise LookupError("No such cog found.")

print(f"DEBUG: RPC loading cog {cog_name}, spec name: {spec.name}")
self._cleanup_and_refresh_modules(spec.name)

await self.bot.load_extension(spec)
Expand Down
11 changes: 10 additions & 1 deletion redbot/pytest/rpc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
from redbot.core._rpc import RPC, RPCMixin
from redbot.core.core_commands import CoreLogic

from unittest.mock import MagicMock

__all__ = ["rpc", "rpcmixin", "cog", "existing_func", "existing_multi_func"]
__all__ = ["rpc", "rpcmixin", "cog", "existing_func", "existing_multi_func", "core_logic"]


@pytest.fixture()
Expand Down Expand Up @@ -51,3 +52,11 @@ def existing_multi_func(rpc, cog):
rpc.add_multi_method(*funcs)

return funcs


@pytest.fixture(scope="function")
async def core_logic(red):
"""Create a CoreLogic instance for testing RPC handlers."""
# Ensure RPC system is initialized before creating CoreLogic
await red.rpc._pre_login()
return CoreLogic(red)
Loading
Loading