Skip to content

Exert tighter control over sizes of random DAGs #256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,24 +1331,36 @@ def test_random_dag_against_numpy(ctx_factory):
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

from testlib import RandomDAGContext, make_random_dag
from testlib import RandomDAGContext, make_random_dag, make_random_dag_rec
axis_len = 5
from warnings import filterwarnings, catch_warnings
with catch_warnings():
# We'd like to know if Numpy divides by zero.
filterwarnings("error")

for i in range(50):
print(i) # progress indicator for somewhat slow test
size = 20

for i in range(50):
seed = 120 + i

additional_generators = [
(500, lambda rdagc, size: make_random_dag_rec(rdagc, size=size))
]

rdagc_pt = RandomDAGContext(np.random.default_rng(seed=seed),
additional_generators=additional_generators,
axis_len=axis_len, use_numpy=False)
rdagc_np = RandomDAGContext(np.random.default_rng(seed=seed),
additional_generators=additional_generators,
axis_len=axis_len, use_numpy=True)

ref_result = make_random_dag(rdagc_np)
dag = make_random_dag(rdagc_pt)
ref_result, ref_size = make_random_dag(rdagc_np, size=size)
dag, dag_size = make_random_dag(rdagc_pt, size=size)

assert ref_size == dag_size

print(f"Step {i} {seed=} {dag_size=}")

from pytato.transform import materialize_with_mpms
dict_named_arys = pt.DictOfNamedArrays({"result": dag})
dict_named_arys = materialize_with_mpms(dict_named_arys)
Expand Down
126 changes: 81 additions & 45 deletions test/testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,12 @@ def assert_allclose_to_numpy(expr: Array, queue: cl.CommandQueue,
# {{{ random DAG generation

class RandomDAGContext:
def __init__(self, rng: np.random.Generator, axis_len: int, use_numpy: bool,
def __init__(
self, rng: np.random.Generator, axis_len: int, use_numpy: bool,
additional_generators: Optional[Sequence[
Tuple[int, Callable[[RandomDAGContext], Array]]]] = None) -> None:
Tuple[int, Callable[[RandomDAGContext, int],
Tuple[Array, int]]]]] = None
) -> None:
"""
:param additional_generators: A sequence of tuples
``(fake_probability, gen_func)``, where *fake_probability* is
Expand Down Expand Up @@ -132,68 +135,89 @@ def make_random_reshape(
operator.pow, "maximum", "minimum"]


def make_random_dag_inner(rdagc: RandomDAGContext) -> Any:
def make_random_binary_op(rdagc: RandomDAGContext,
        op1: Any, *, size: int) -> Tuple[Any, int]:
    """Combine *op1* with a freshly generated operand via a random binary
    operation.

    :param rdagc: the context supplying the random number generator
        (``rdagc.rng``) and the array namespace (``rdagc.np``).
    :param op1: the already-built left-hand operand.
    :param size: the size budget for this operation; the second operand is
        generated by :func:`make_random_dag_rec` with a budget of
        ``size - 1``.
    :returns: a tuple ``(result, added_size)``, where *result* is the
        squeezed outcome of the operation and *added_size* is the size
        reported for the second operand plus one for the operation itself.
        The size of *op1* is not included.
    """
    rng = rdagc.rng

    # The right-hand operand is generated first, so that the sequence of
    # draws from the random number generator is fixed.
    rhs, rhs_size = make_random_dag_rec(rdagc, size=size - 1)
    lhs = op1

    min_ndim = min(lhs.ndim, rhs.ndim)
    naxes = rng.integers(min_ndim, min_ndim+2)

    # Reshaping may introduce a few new 1-long axes, which tests broadcasting.
    lhs = lhs.reshape(*make_random_reshape(rdagc, lhs.shape, naxes))
    rhs = rhs.reshape(*make_random_reshape(rdagc, rhs.shape, naxes))

    # type ignore because rng.choice doesn't have broad enough type
    # annotation to represent choosing callables.
    chosen_op = rng.choice(_BINOPS)  # type: ignore[arg-type]

    if chosen_op is operator.pow:
        # NOTE(review): presumably keeps the base non-negative so that
        # powers stay well-defined -- confirm.
        lhs = abs(lhs)

    # Entries given by name are looked up in the array namespace; all others
    # are used as callables directly.
    if chosen_op in ["maximum", "minimum"]:
        apply_op = getattr(rdagc.np, chosen_op)
    else:
        apply_op = chosen_op

    combined = apply_op(lhs, rhs)

    # Squeeze because all axes need to be of rdagc.axis_len, and the reshapes
    # above may have inserted 1-long axes. Those need to go before returning.
    return rdagc.np.squeeze(combined), rhs_size + 1


def make_random_dag_choice(rdagc: RandomDAGContext, *, size: int) -> Tuple[Any, int]:
rng = rdagc.rng

max_prob_hardcoded = 1500
additional_prob = sum(
fake_prob
for fake_prob, _func in rdagc.additional_generators)

if size <= 1:
return make_random_constant(rdagc, naxes=rng.integers(1, 3)), 1

while True:
v = rng.integers(0, max_prob_hardcoded + additional_prob)

if v < 600:
return make_random_constant(rdagc, naxes=rng.integers(1, 3))
return make_random_constant(rdagc, naxes=rng.integers(1, 3)), 1

elif v < 1000:
op1 = make_random_dag(rdagc)
op2 = make_random_dag(rdagc)
m = min(op1.ndim, op2.ndim)
naxes = rng.integers(m, m+2)

# Introduce a few new 1-long axes to test broadcasting.
op1 = op1.reshape(*make_random_reshape(rdagc, op1.shape, naxes))
op2 = op2.reshape(*make_random_reshape(rdagc, op2.shape, naxes))

# type ignore because rng.choice doesn't have broad enough type
# annotation to represent choosing callables.
which_op = rng.choice(_BINOPS) # type: ignore[arg-type]

if which_op is operator.pow:
op1 = abs(op1)

# Squeeze because all axes need to be of rdagc.axis_len, and we've
# just inserted a few new 1-long axes. Those need to go before we
# return.
if which_op in ["maximum", "minimum"]:
return rdagc.np.squeeze(getattr(rdagc.np, which_op)(op1, op2))
else:
return rdagc.np.squeeze(which_op(op1, op2))
op1, op1_size = make_random_dag_rec(rdagc, size=size)

if op1_size > size:
continue

result, op2_and_root_size = make_random_binary_op(
rdagc, op1, size=size - op1_size)

return result, op1_size + op2_and_root_size

elif v < 1075:
op1 = make_random_dag(rdagc)
op2 = make_random_dag(rdagc)
op1, op1_size = make_random_dag_rec(rdagc, size=size)
op2, op2_size = make_random_dag_rec(rdagc, size=size-op1_size-1)
if op1.ndim <= 1 and op2.ndim <= 1:
continue

return op1 @ op2
return op1 @ op2, op1_size + 1 + op2_size

elif v < 1275:
if not rdagc.past_results:
continue
return rdagc.past_results[rng.integers(0, len(rdagc.past_results))]

return rdagc.past_results[rng.integers(0, len(rdagc.past_results))], 0

elif v < max_prob_hardcoded:
result = make_random_dag(rdagc)
result, res_size = make_random_dag_rec(rdagc, size=size-1)
return rdagc.np.transpose(
result,
tuple(rng.permuted(list(range(result.ndim)))))
tuple(rng.permuted(list(range(result.ndim))))), res_size + 1

else:
base_prob = max_prob_hardcoded
for fake_prob, gen_func in rdagc.additional_generators:
if base_prob <= v < base_prob + fake_prob:
return gen_func(rdagc)
return gen_func(rdagc, size)

base_prob += fake_prob

Expand All @@ -207,15 +231,9 @@ def make_random_dag_inner(rdagc: RandomDAGContext) -> Any:
# FIXME: include DictOfNamedArrays


def make_random_dag(rdagc: RandomDAGContext) -> Any:
"""Return a :class:`pytato.Array` or a :class:`numpy.ndarray`
(cf. :attr:`RandomDAGContext.use_numpy`) that is the result of a random
(cf. :attr:`RandomDAGContext.rng`) array computation. All axes
of the array are of length :attr:`RandomDAGContext.axis_len` (there is
at least one axis, but arbitrarily more may be present).
"""
def make_random_dag_rec(rdagc: RandomDAGContext, *, size: int) -> Tuple[Any, int]:
rng = rdagc.rng
result = make_random_dag_inner(rdagc)
result, actual_size = make_random_dag_choice(rdagc, size=size)

if result.ndim > 2:
v = rng.integers(0, 2)
Expand All @@ -225,21 +243,39 @@ def make_random_dag(rdagc: RandomDAGContext) -> Any:
subscript[rng.integers(0, result.ndim)] = int(
rng.integers(0, rdagc.axis_len))

return result[tuple(subscript)]
return result[tuple(subscript)], actual_size+1

elif v == 1:
# reduce away an axis

# FIXME do reductions other than sum?
return rdagc.np.sum(
result, axis=int(rng.integers(0, result.ndim)))
result, axis=int(rng.integers(0, result.ndim))), actual_size+1

else:
raise AssertionError()

rdagc.past_results.append(result)

return result
return result, actual_size


def make_random_dag(rdagc: RandomDAGContext, *, size: int) -> Tuple[Any, int]:
    """Build a random array computation whose reported size is at least
    *size*.

    The computation is a :class:`pytato.Array` or a :class:`numpy.ndarray`
    (cf. :attr:`RandomDAGContext.use_numpy`), generated using
    :attr:`RandomDAGContext.rng`. All axes of the array are of length
    :attr:`RandomDAGContext.axis_len` (there is at least one axis, but
    arbitrarily more may be present).

    :param size: the requested size of the computation. An initial expression
        is generated by :func:`make_random_dag_rec`; while its reported size
        is below *size*, it is extended via :func:`make_random_binary_op`.
    :returns: a tuple ``(result, actual_size)`` of the generated array and
        the accumulated size reported by the generators.
    """
    dag, total_size = make_random_dag_rec(rdagc, size=size)

    while True:
        if total_size >= size:
            return dag, total_size

        # Spend the remaining budget on a binary operation that combines
        # what has been built so far with a newly generated operand.
        remaining = size - total_size
        dag, added_size = make_random_binary_op(rdagc, dag, size=remaining)
        total_size += added_size

# }}}

Expand Down