Skip to content

Commit a8f0bfa

Browse files
PaliCfacebook-github-bot
authored andcommitted
Allow torch._C to be recognized a module in torch.package in pytorch/multipy (#91)
Summary: Pull Request resolved: #91 This pr addresses #82 and #44. A C extension module behaves a bit differently than a normal python package as it does not contain a `__path__` attribute. However, these modules still have information about their submodules. This PR also checks if a module is a C extension module and checks if the module we are looking for is in it's children. For example, if we are importing `torch._C._nn` we check if the parent `torch._C` is a C extension module if necessary, and then check if `torch._C._nn` is a proper child of `torch._C`. The corresponding PR in torch.package is pytorch/pytorch#80917 Reviewed By: d4l3k Differential Revision: D37630606 fbshipit-source-id: 59b5b7c291ccecef9f598a93c33719d348a339f5
1 parent 80fd3b3 commit a8f0bfa

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

multipy/package/package_importer_no_torch.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import builtins
88
import importlib
9+
import importlib.machinery
910
import inspect
1011
import io
1112
import linecache
@@ -378,18 +379,45 @@ def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
378379
def _do_find_and_load(self, name):
379380
path = None
380381
parent = name.rpartition(".")[0]
382+
module_name_no_parent = name.rpartition(".")[-1]
381383
if parent:
382384
if parent not in self.modules:
383385
self._gcd_import(parent)
384386
# Crazy side-effects!
385387
if name in self.modules:
386388
return self.modules[name]
387389
parent_module = self.modules[parent]
390+
388391
try:
389392
path = parent_module.__path__ # type: ignore[attr-defined] # noqa
393+
390394
except AttributeError:
391-
msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
392-
raise ModuleNotFoundError(msg, name=name) from None
395+
# when we attempt to import a package only containing pybinded files,
396+
# the parent directory isn't always a package as defined by python,
397+
# so we search if the package is actually there or not before calling the error.
398+
if isinstance(
399+
parent_module.__loader__,
400+
importlib.machinery.ExtensionFileLoader,
401+
):
402+
if name not in self.extern_modules:
403+
msg = (
404+
_ERR_MSG
405+
+ "; {!r} is a c extension module which was not externed. C extension modules \
406+
need to be externed by the PackageExporter in order to be used as we do not support interning them.}."
407+
).format(name, name)
408+
raise ModuleNotFoundError(msg, name=name) from None
409+
if not isinstance(
410+
parent_module.__dict__.get(module_name_no_parent),
411+
types.ModuleType,
412+
):
413+
msg = (
414+
_ERR_MSG
415+
+ "; {!r} is a c extension package which does not contain {!r}."
416+
).format(name, parent, name)
417+
raise ModuleNotFoundError(msg, name=name) from None
418+
else:
419+
msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
420+
raise ModuleNotFoundError(msg, name=name) from None
393421

394422
module = self._load_module(name, parent)
395423

multipy/test/package/test_dependency_api.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
# Owner(s): ["oncall: package/deploy"]
88

99
import importlib
10+
1011
from io import BytesIO
1112
from sys import version_info
1213
from textwrap import dedent
1314
from unittest import skipIf
1415

16+
import torch.nn
17+
1518
from multipy.package import (
1619
EmptyMatchError,
1720
Importer,
@@ -369,6 +372,29 @@ def test_repackage_mocked_module(self):
369372
with self.assertRaises(NotImplementedError):
370373
foo2.package_a.get_something()
371374

375+
def test_externing_c_extension(self):
376+
"""Externing c extensions modules should allow us to still access them especially those found in torch._C."""
377+
378+
buffer = BytesIO()
379+
# The C extension module in question is F.gelu which comes from torch._C._nn
380+
model = torch.nn.TransformerEncoderLayer(
381+
d_model=64,
382+
nhead=2,
383+
dim_feedforward=64,
384+
dropout=1.0,
385+
batch_first=True,
386+
activation="gelu",
387+
norm_first=True,
388+
)
389+
with PackageExporter(buffer) as e:
390+
e.extern("torch.**")
391+
e.intern("**")
392+
393+
e.save_pickle("model", "model.pkl", model)
394+
buffer.seek(0)
395+
imp = PackageImporter(buffer)
396+
imp.load_pickle("model", "model.pkl")
397+
372398

373399
class TestDependencyAPINoTorch(TestDependencyAPI):
374400
def __init__(self, *args, **kwargs):

0 commit comments

Comments
 (0)