Skip to content

Commit 7fef48e

Browse files
authored
Merge pull request #58 from smups/dev
Update CI to use new release branch
2 parents 59113c0 + 940cb0c commit 7fef48e

5 files changed

Lines changed: 66 additions & 17 deletions

File tree

.github/workflows/CI.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ name: CI
33
on:
44
push:
55
branches:
6-
- main
7-
- master
6+
- release
87
tags:
98
- '*'
109
pull_request:
@@ -114,9 +113,9 @@ jobs:
114113
strategy:
115114
matrix:
116115
platform:
117-
- runner: macos-12
116+
- runner: macos-13
118117
target: x86_64
119-
- runner: macos-14
118+
- runner: macos-latest
120119
target: aarch64
121120
steps:
122121
- uses: actions/checkout@v4
@@ -161,7 +160,7 @@ jobs:
161160
pypi-publish:
162161
name: Release
163162
runs-on: ubuntu-latest
164-
if: "startsWith(github.ref_name, 'master')"
163+
if: "startsWith(github.ref_name, 'release')"
165164
needs: [linux, windows, macos, sdist]
166165
permissions:
167166
id-token: write

python/inflatox/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@
2525

2626
from . import consistency_conditions
2727
from . import background
28+
from .libinflx_rs import log_info, log_warn
2829

2930
__all__ = [
3031
"CompilationArtifact",
3132
"Compiler",
3233
"InflationModel",
3334
"InflationModelBuilder",
3435
"consistency_conditions",
36+
"log_info",
37+
"log_warn",
3538
"background",
3639
"__version__",
3740
]

python/inflatox/compiler.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,13 @@ def _generate_c_function(
405405
if not self.silent:
406406
print(f"Found {len(cse_list[0])} common subexpressions")
407407
for cse_symbol, cse_definition in cse_list[0]:
408-
out += f" const double {printer.doprint(cse_symbol)} = {printer.doprint(cse_definition)};\n"
409-
out += f" return {printer.doprint(cse_list[1])}{';\n}\n'}"
408+
out += f" const double {printer.doprint(cse_symbol)} = {printer.doprint(cse_definition)};"
409+
out += "\n"
410+
out += f" return {printer.doprint(cse_list[1])};"
411+
out += "\n}\n"
410412
else:
411-
out += f" return {printer.doprint(body)}{';\n}\n'}"
413+
out += f" return {printer.doprint(body)};"
414+
out += "\n}\n"
412415
return out + "\n"
413416

414417
def _generate_c_function_for_vector(
@@ -424,15 +427,17 @@ def _generate_c_function_for_vector(
424427
if not self.silent:
425428
print(f"Found {len(cse_list[0])} common subexpressions")
426429
for cse_symbol, cse_definition in cse_list[0]:
427-
out += f" const double {printer.doprint(cse_symbol)} = {printer.doprint(cse_definition)};\n"
430+
out += f" const double {printer.doprint(cse_symbol)} = {printer.doprint(cse_definition)};"
431+
out += "\n"
428432
for output_expr in cse_list[1]:
429433
ordered_output_expr.append(output_expr)
430434
else:
431435
ordered_output_expr = vector
432436

433437
# Assign each element of the output vector to a component of the vector
434438
for i, output_cmp in enumerate(ordered_output_expr):
435-
out += f" v_out[{i}] = {printer.doprint(output_cmp)};\n"
439+
out += f" v_out[{i}] = {printer.doprint(output_cmp)};"
440+
out += "\n"
436441
out += " return;\n}\n\n"
437442

438443
return out
@@ -446,7 +451,8 @@ def _generate_c_function_for_inner_prod(self, printer: CInflatoxPrinter) -> str:
446451
if self.cse:
447452
cse_list = sympy.cse(flattened_metric, symbols=self._new_cse_generator(), list=True)
448453
for cse_symbol, cse_definition in cse_list[0]:
449-
out += f" const double {printer.doprint(cse_symbol)} = {printer.doprint(cse_definition)};\n"
454+
out += f" const double {printer.doprint(cse_symbol)} = {printer.doprint(cse_definition)};"
455+
out += "\n"
450456
flattened_metric = cse_list[1]
451457

452458
# Write function for outer product
@@ -458,9 +464,11 @@ def _generate_c_function_for_inner_prod(self, printer: CInflatoxPrinter) -> str:
458464
symbol_str = printer.doprint(flattened_metric[n])
459465
if symbol_str == "0" or symbol_str == "0.0":
460466
continue
461-
out += f" const double g{i}{j} = {symbol_str};\n"
467+
out += f" const double g{i}{j} = {symbol_str};"
468+
out += "\n"
462469
return_expr += f" + (g{i}{j} * v1[{i}] * v2[{j}])"
463-
out += f" return {return_expr};\n}}\n\n"
470+
out += f" return {return_expr};"
471+
out += "\n}\n\n"
464472
return out
465473

466474
def _generate_c_file(self):

python/inflatox/symbolic.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import numpy as np
2121
import sympy
2222
from interruptingcow import timeout
23+
import os
24+
from . import libinflx_rs
2325
from joblib import Parallel, cpu_count, delayed
2426
from sympy.simplify import sqrtdenest
2527
from sympy.vector import Gradient
@@ -140,6 +142,9 @@ def new(
140142
This class will automatically derive names for the derivatives of the fields (these are
141143
used by other inflatox components when solving the equations of motion).
142144
145+
Simplifications are currently disabled on windows due to their time-out system relying on
146+
UNIX signals.
147+
143148
### Args
144149
- `fields` (`list[sympy.Symbol]`): list of sympy symbols that should be
145150
interpreted as fields (coordinates on the scalar manifold).
@@ -156,7 +161,9 @@ def new(
156161
- `assertions` (`bool`, *optional*): if False, expensive intermediate
157162
assertions will be disabled. Defaults to False.
158163
- `simplify` (`bool`, *optional*): When `True` `sympy`'s simplify method will be used.
159-
Defaults to `True`.
164+
Due to the time-out system for simplifications relying on UNIX signals, simplifications
165+
cannot be turned on on Windows. If you set `simplify` to `True` on Windows it will be
166+
ignored. Defaults to `True` (except on Windows platforms).
160167
- `simplify_timeout` (`float`, *optional*): time-out time in seconds for simplification
161168
steps.
162169
@@ -167,6 +174,21 @@ def new(
167174
if init_sympy_printing:
168175
sympy.init_printing()
169176

177+
if simplify == True and os.name == "nt":
178+
libinflx_rs.log_warn(
179+
"cannot use simplifications on Windows. Continuing without simplifications."
180+
)
181+
return cls(
182+
fields=fields,
183+
field_metric=field_metric,
184+
potential=potential,
185+
model_name=model_name if model_name is not None else "generic model",
186+
silent=silent,
187+
assertions=assertions,
188+
simplify=False,
189+
simplify_timeout=0.0,
190+
)
191+
170192
return cls(
171193
fields=fields,
172194
field_metric=field_metric,
@@ -208,7 +230,7 @@ def __init__(
208230

209231
def simplify_expr(self, expr: sympy.Expr) -> sympy.Expr:
210232
"""simplifies expression"""
211-
if not self.simplify:
233+
if not self.simplify or os.name == "nt":
212234
return expr
213235
try:
214236
with timeout(self.simplify_timeout, exception=SimplificationTimeOut):
@@ -220,7 +242,7 @@ def simplify_expr(self, expr: sympy.Expr) -> sympy.Expr:
220242

221243
def expand_and_factor_expr(self, expr: sympy.Expr) -> sympy.Expr:
222244
"""`sympy.expand` followed by `sympy.factor` with a time-out"""
223-
if not self.simplify:
245+
if not self.simplify or os.name == "nt":
224246
return expr
225247
try:
226248
with timeout(self.simplify_timeout, exception=SimplificationTimeOut):
@@ -232,7 +254,7 @@ def expand_and_factor_expr(self, expr: sympy.Expr) -> sympy.Expr:
232254

233255
def sqrt_and_denest_expr(self, expr: sympy.Expr) -> sympy.Expr:
234256
"""returns denested square root of expr (with timeout)"""
235-
if not self.simplify:
257+
if not self.simplify or os.name == "nt":
236258
return sympy.sqrt(expr)
237259
try:
238260
with timeout(self.simplify_timeout, exception=SimplificationTimeOut):
@@ -527,6 +549,8 @@ def calc_gradient_square(self) -> sympy.Expr:
527549
for a in range(dim):
528550
for b in range(dim):
529551
out += metric_inv[a, b] * gradient[a] * gradient[b]
552+
if os.name == "nt":
553+
return self.simplify_expr(out)
530554
try:
531555
with timeout(self.simplify_timeout, exception=SimplificationTimeOut):
532556
out = sympy.factor(sympy.expand(out))
@@ -599,6 +623,9 @@ def gramm_schmidt(
599623
xdoty = self.inner_prod(x, y)
600624
for a in range(dim):
601625
y[a] -= xdoty * x[a]
626+
# Disable simplification if we're on Windows
627+
if os.name == "nt":
628+
return [self.simplify_expr(yi) for yi in self.normalize(y)]
602629
try:
603630
with timeout(self.simplify_timeout, exception=SimplificationTimeOut):
604631
y = [sympy.factor(sympy.expand(yi)) for yi in y]

src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ fn libinflx_rs(_py: Python<'_>, pymod: &Bound<PyModule>) -> PyResult<()> {
7171
use anguelova::*;
7272
pymod.add_class::<InflatoxPyDyLib>()?;
7373
pymod.add_function(wrap_pyfunction!(open_inflx_dylib, pymod)?)?;
74+
pymod.add_function(wrap_pyfunction!(log_info, pymod)?)?;
75+
pymod.add_function(wrap_pyfunction!(log_warn, pymod)?)?;
7476

7577
pymod.add_function(wrap_pyfunction!(flag_quantum_dif_py, pymod)?)?;
7678
pymod.add_function(wrap_pyfunction!(consistency_only, pymod)?)?;
@@ -89,6 +91,16 @@ fn libinflx_rs(_py: Python<'_>, pymod: &Bound<PyModule>) -> PyResult<()> {
8991
Ok(())
9092
}
9193

94+
#[pyfunction]
95+
fn log_info<'py>(msg: Bound<'py, pyo3::types::PyString>) {
96+
eprintln!("{}{}", *BADGE_INFO, msg.to_string());
97+
}
98+
99+
#[pyfunction]
100+
fn log_warn<'py>(msg: Bound<'py, pyo3::types::PyString>) {
101+
eprintln!("{}{}", *BADGE_WARN, msg.to_string());
102+
}
103+
92104
#[pyclass]
93105
/// Python wrapper for `InflatoxDyLib`
94106
pub struct InflatoxPyDyLib(pub InflatoxDylib);

0 commit comments

Comments
 (0)