Skip to content

Commit 7bef38f

Browse files
authored
Fix for massively parallel performance regression (#720)
* Fix for massively parallel performance regression
1 parent ebbed56 commit 7bef38f

File tree

1 file changed

+85
-89
lines changed

1 file changed

+85
-89
lines changed

pyop2/compilation.py

+85 -89
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def _check_hashes(x, y, datatype):
6363

6464

6565
def set_default_compiler(compiler):
66-
"""Set the PyOP2 default compiler, globally.
66+
"""Set the PyOP2 default compiler, globally over COMM_WORLD.
6767
6868
:arg compiler: String with name or path to compiler executable
6969
OR a subclass of the Compiler class
@@ -85,66 +85,73 @@ def set_default_compiler(compiler):
8585
)
8686

8787

88-
def sniff_compiler(exe):
88+
def sniff_compiler(exe, comm=mpi.COMM_WORLD):
8989
"""Obtain the correct compiler class by calling the compiler executable.
9090
9191
:arg exe: String with name or path to compiler executable
92+
:arg comm: Comm over which we want to determine the compiler type
9293
:returns: A compiler class
9394
"""
94-
try:
95-
output = subprocess.run(
96-
[exe, "--version"],
97-
stdout=subprocess.PIPE,
98-
stderr=subprocess.PIPE,
99-
check=True,
100-
encoding="utf-8"
101-
).stdout
102-
except (subprocess.CalledProcessError, UnicodeDecodeError):
103-
output = ""
104-
105-
# Find the name of the compiler family
106-
if output.startswith("gcc") or output.startswith("g++"):
107-
name = "GNU"
108-
elif output.startswith("clang"):
109-
name = "clang"
110-
elif output.startswith("Apple LLVM") or output.startswith("Apple clang"):
111-
name = "clang"
112-
elif output.startswith("icc"):
113-
name = "Intel"
114-
elif "Cray" in output.split("\n")[0]:
115-
# Cray is more awkward eg:
116-
# Cray clang version 11.0.4 (<some_hash>)
117-
# gcc (GCC) 9.3.0 20200312 (Cray Inc.)
118-
name = "Cray"
119-
else:
120-
name = "unknown"
121-
122-
# Set the compiler instance based on the platform (and architecture)
123-
if sys.platform.find("linux") == 0:
124-
if name == "Intel":
125-
compiler = LinuxIntelCompiler
126-
elif name == "GNU":
127-
compiler = LinuxGnuCompiler
128-
elif name == "clang":
129-
compiler = LinuxClangCompiler
130-
elif name == "Cray":
131-
compiler = LinuxCrayCompiler
95+
compiler = None
96+
if comm.rank == 0:
97+
# Note:
98+
# Sniffing compiler for very large numbers of MPI ranks is
99+
# expensive so we do this on one rank and broadcast
100+
try:
101+
output = subprocess.run(
102+
[exe, "--version"],
103+
stdout=subprocess.PIPE,
104+
stderr=subprocess.PIPE,
105+
check=True,
106+
encoding="utf-8"
107+
).stdout
108+
except (subprocess.CalledProcessError, UnicodeDecodeError):
109+
output = ""
110+
111+
# Find the name of the compiler family
112+
if output.startswith("gcc") or output.startswith("g++"):
113+
name = "GNU"
114+
elif output.startswith("clang"):
115+
name = "clang"
116+
elif output.startswith("Apple LLVM") or output.startswith("Apple clang"):
117+
name = "clang"
118+
elif output.startswith("icc"):
119+
name = "Intel"
120+
elif "Cray" in output.split("\n")[0]:
121+
# Cray is more awkward eg:
122+
# Cray clang version 11.0.4 (<some_hash>)
123+
# gcc (GCC) 9.3.0 20200312 (Cray Inc.)
124+
name = "Cray"
132125
else:
133-
compiler = AnonymousCompiler
134-
elif sys.platform.find("darwin") == 0:
135-
if name == "clang":
136-
machine = platform.uname().machine
137-
if machine == "arm64":
138-
compiler = MacClangARMCompiler
139-
elif machine == "x86_64":
140-
compiler = MacClangCompiler
141-
elif name == "GNU":
142-
compiler = MacGNUCompiler
126+
name = "unknown"
127+
128+
# Set the compiler instance based on the platform (and architecture)
129+
if sys.platform.find("linux") == 0:
130+
if name == "Intel":
131+
compiler = LinuxIntelCompiler
132+
elif name == "GNU":
133+
compiler = LinuxGnuCompiler
134+
elif name == "clang":
135+
compiler = LinuxClangCompiler
136+
elif name == "Cray":
137+
compiler = LinuxCrayCompiler
138+
else:
139+
compiler = AnonymousCompiler
140+
elif sys.platform.find("darwin") == 0:
141+
if name == "clang":
142+
machine = platform.uname().machine
143+
if machine == "arm64":
144+
compiler = MacClangARMCompiler
145+
elif machine == "x86_64":
146+
compiler = MacClangCompiler
147+
elif name == "GNU":
148+
compiler = MacGNUCompiler
149+
else:
150+
compiler = AnonymousCompiler
143151
else:
144152
compiler = AnonymousCompiler
145-
else:
146-
compiler = AnonymousCompiler
147-
return compiler
153+
154+
return comm.bcast(compiler, 0)
148155

149156

150157
class Compiler(ABC):
@@ -178,8 +185,8 @@ class Compiler(ABC):
178185
_debugflags = ()
179186

180187
def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None):
181-
# Get compiler version ASAP since it is used in __repr__
182-
self.sniff_compiler_version()
188+
# Initialise the compiler version to None ASAP since it is used in __repr__; the real version is sniffed once the communicators have been set up
189+
self.version = None
183190

184191
self._extra_compiler_flags = tuple(extra_compiler_flags)
185192
self._extra_linker_flags = tuple(extra_linker_flags)
@@ -190,6 +197,7 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co
190197
# Compilation communicators are reference counted on the PyOP2 comm
191198
self.pcomm = mpi.internal_comm(comm, self)
192199
self.comm = mpi.compilation_comm(self.pcomm, self)
200+
self.sniff_compiler_version()
193201

194202
def __repr__(self):
195203
return f"<{self._name} compiler, version {self.version or 'unknown'}>"
@@ -238,23 +246,28 @@ def sniff_compiler_version(self, cpp=False):
238246
:arg cpp: If set to True will use the C++ compiler rather than
239247
the C compiler to determine the version number.
240248
"""
249+
# Note:
250+
# Sniffing the compiler version for very large numbers of
251+
# MPI ranks is expensive, so we do this on rank 0 of the compilation comm and broadcast the result
241252
exe = self.cxx if cpp else self.cc
242-
self.version = None
243-
# `-dumpversion` is not sufficient to get the whole version string (for some compilers),
244-
# but other compilers do not implement `-dumpfullversion`!
245-
for dumpstring in ["-dumpfullversion", "-dumpversion"]:
246-
try:
247-
output = subprocess.run(
248-
[exe, dumpstring],
249-
stdout=subprocess.PIPE,
250-
stderr=subprocess.PIPE,
251-
check=True,
252-
encoding="utf-8"
253-
).stdout
254-
self.version = Version(output)
255-
break
256-
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
257-
continue
253+
version = None
254+
if self.comm.rank == 0:
255+
# `-dumpversion` is not sufficient to get the whole version string (for some compilers),
256+
# but other compilers do not implement `-dumpfullversion`!
257+
for dumpstring in ["-dumpfullversion", "-dumpversion"]:
258+
try:
259+
output = subprocess.run(
260+
[exe, dumpstring],
261+
stdout=subprocess.PIPE,
262+
stderr=subprocess.PIPE,
263+
check=True,
264+
encoding="utf-8"
265+
).stdout
266+
version = Version(output)
267+
break
268+
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
269+
continue
270+
self.version = self.comm.bcast(version, 0)
258271

259272
@property
260273
def bugfix_cflags(self):
@@ -448,23 +461,6 @@ class LinuxGnuCompiler(Compiler):
448461
_optflags = ("-march=native", "-O3", "-ffast-math")
449462
_debugflags = ("-O0", "-g")
450463

451-
def sniff_compiler_version(self, cpp=False):
452-
super(LinuxGnuCompiler, self).sniff_compiler_version()
453-
if self.version >= Version("7.0"):
454-
try:
455-
# gcc-7 series only spits out patch level on dumpfullversion.
456-
exe = self.cxx if cpp else self.cc
457-
output = subprocess.run(
458-
[exe, "-dumpfullversion"],
459-
stdout=subprocess.PIPE,
460-
stderr=subprocess.PIPE,
461-
check=True,
462-
encoding="utf-8"
463-
).stdout
464-
self.version = Version(output)
465-
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
466-
pass
467-
468464
@property
469465
def bugfix_cflags(self):
470466
"""Flags to work around bugs in compilers."""
@@ -596,7 +592,7 @@ def __init__(self, code, argtypes):
596592
exe = configuration["cxx"] or "mpicxx"
597593
else:
598594
exe = configuration["cc"] or "mpicc"
599-
compiler = sniff_compiler(exe)
595+
compiler = sniff_compiler(exe, comm)
600596
dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension)
601597

602598
if isinstance(jitmodule, GlobalKernel):

0 commit comments

Comments
 (0)