diff --git a/cunumeric/array.py b/cunumeric/array.py index 0cdc5b402..05ffce505 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -24,7 +24,14 @@ from legate.core import Array -from .config import BinaryOpCode, UnaryOpCode, UnaryRedCode +from .config import ( + BinaryOpCode, + CuNumericOpCode, + FusedOpCode, + UnaryOpCode, + UnaryRedCode, +) +from .deferred import DeferredArray from .doc_utils import copy_docstring from .runtime import runtime from .utils import unimplemented @@ -464,7 +471,6 @@ def __ge__(self, rhs): ) # __getattribute__ - def _convert_key(self, key, stacklevel=2, first=True): # Convert any arrays stored in a key to a cuNumeric array if ( @@ -1953,7 +1959,7 @@ def perform_unary_reduction( ) return dst - # Return a new cuNumeric array for a binary operation + # Return a new legate array for a binary operation @classmethod def perform_binary_op( cls, @@ -2017,29 +2023,64 @@ def perform_binary_op( if out_dtype is None: out_dtype = cls.find_common_type(one, two) if check_types: + isDeferred = isinstance(one._thunk, DeferredArray) or isinstance( + two._thunk, DeferredArray + ) if one.dtype != two.dtype: common_type = cls.find_common_type(one, two) if one.dtype != common_type: - temp = ndarray( - shape=one.shape, - dtype=common_type, - stacklevel=(stacklevel + 1), - inputs=(one, two, where), - ) - temp._thunk.convert( - one._thunk, stacklevel=(stacklevel + 1) - ) + # remove convert ops + if isDeferred and one.shape == (): + temp = ndarray( + shape=one.shape, + dtype=common_type, + # buffer = one._thunk.array.astype(common_type), + stacklevel=(stacklevel + 1), + inputs=(one, two, where), + ) + temp._thunk = runtime.create_scalar( + one._thunk.array.astype(common_type), + common_type, + shape=one.shape, + wrap=True, + ) + else: + temp = ndarray( + shape=one.shape, + dtype=common_type, + stacklevel=(stacklevel + 1), + inputs=(one, two, where), + ) + temp._thunk.convert( + one._thunk, stacklevel=(stacklevel + 1) + ) one = temp if two.dtype != common_type: - temp = ndarray( - shape=two.shape, - dtype=common_type, - stacklevel=(stacklevel + 1), - inputs=(one, two, where), - ) - temp._thunk.convert( - two._thunk, stacklevel=(stacklevel + 1) - ) + # remove convert ops + if isDeferred and two.shape == (): + temp = ndarray( + shape=two.shape, + dtype=common_type, + # buffer = two._thunk.array.astype(common_type), + stacklevel=(stacklevel + 1), + inputs=(one, two, where), + ) + temp._thunk = runtime.create_scalar( + two._thunk.array.astype(common_type), + common_type, + shape=two.shape, + wrap=True, + ) + else: + temp = ndarray( + shape=two.shape, + dtype=common_type, + stacklevel=(stacklevel + 1), + inputs=(one, two, where), + ) + temp._thunk.convert( + two._thunk, stacklevel=(stacklevel + 1) + ) two = temp if out.dtype != out_dtype: temp = ndarray( diff --git a/cunumeric/config.py b/cunumeric/config.py index 33753eaa1..c064059eb 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -104,6 +104,12 @@ class CuNumericOpCode(IntEnum): UNARY_RED = _cunumeric.CUNUMERIC_UNARY_RED WHERE = _cunumeric.CUNUMERIC_WHERE WRITE = _cunumeric.CUNUMERIC_WRITE + FUSED_OP = _cunumeric.CUNUMERIC_FUSED_OP + + +@unique +class FusedOpCode(IntEnum): + FUSE = 1 # Match these to BinaryOpCode in binary_op_util.h @@ -197,3 +203,6 @@ class CuNumericRedopCode(IntEnum): class CuNumericTunable(IntEnum): NUM_GPUS = _cunumeric.CUNUMERIC_TUNABLE_NUM_GPUS MAX_EAGER_VOLUME = _cunumeric.CUNUMERIC_TUNABLE_MAX_EAGER_VOLUME + + +cunumeric_context.fused_id = CuNumericOpCode.FUSED_OP diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 0fb2230bb..ed797f6b7 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -254,6 +254,9 @@ def __numpy_array__(self, stacklevel=0): return np.empty(shape=self.shape, dtype=self.dtype) if self.scalar: + if not self.base._storage: + self.runtime.legate_runtime._launch_outstanding() + result = np.full( self.shape, self.get_scalar_array(stacklevel=(stacklevel + 1)), @@ -1568,7 +1571,6 @@ def binary_op( task.add_alignment(lhs, rhs1) task.add_alignment(lhs, rhs2) - task.execute() @profile diff --git a/examples/black_scholes.py b/examples/black_scholes.py index 4240d8860..02780653c 100644 --- a/examples/black_scholes.py +++ b/examples/black_scholes.py @@ -79,11 +79,17 @@ def run_black_scholes(N, D): N *= 1000 start = datetime.datetime.now() S, X, T, R, V = initialize(N, D) - call, put = black_scholes(S, X, T, R, V) - # Check the result for NaNs to synchronize before stopping timing - call_sum = np.sum(call) - put_sum = np.sum(put) - assert not math.isnan(call_sum) and not math.isnan(put_sum) + trials = 300 + ends = [None for i in range(trials)] + for i in range(trials): + call, put = black_scholes(S, X, T, R, V) + # Check the result for NaNs to synchronize before stopping timing + call_sum = np.sum(call) + put_sum = np.sum(put) + ends[i] = (call_sum, put_sum) + for i in range(trials): + call_sum, put_sum = ends[i] + assert not math.isnan(call_sum) and not math.isnan(put_sum) stop = datetime.datetime.now() delta = stop - start total = delta.total_seconds() * 1000.0 diff --git a/examples/stencil.py b/examples/stencil.py index 0db769db4..5552c2124 100644 --- a/examples/stencil.py +++ b/examples/stencil.py @@ -49,6 +49,7 @@ def run(grid, I, N): # noqa: E741 # delta = np.sum(np.absolute(work - center)) center[:] = work total = np.sum(center) + # return total return total / (N ** 2) diff --git a/examples/stencil_27.py b/examples/stencil_27.py new file mode 100644 index 000000000..28476d791 --- /dev/null +++ b/examples/stencil_27.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python + +# Copyright 2021 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import argparse +import datetime +import math + +from benchmark import run_benchmark + +import cunumeric as np + + +def initialize(N): + print("Initializing stencil grid...") + grid = np.zeros((N + 2, N + 2, N + 2)) + grid[:, :, 0] = -273.15 + grid[:, 0, :] = -273.15 + grid[0, :, :] = -273.15 + grid[:, :, -1] = 273.15 + grid[:, -1, :] = 273.15 + grid[-1, :, :] = 273.15 + + return grid + + +def run(grid, I, N): # noqa: E741 + print("Running Jacobi 27 stencil...") + + # one + g000 = grid[0:-2, 0:-2, 0:-2] + g001 = grid[0:-2, 0:-2, 1:-1] + g002 = grid[0:-2, 0:-2, 2:] + + g010 = grid[0:-2, 1:-1, 0:-2] + g011 = grid[0:-2, 1:-1, 1:-1] + g012 = grid[0:-2, 1:-1, 2:] + + g020 = grid[0:-2, 2:, 0:-2] + g021 = grid[0:-2, 2:, 1:-1] + g022 = grid[0:-2, 2:, 2:] + + # two + g100 = grid[1:-1, 0:-2, 0:-2] + g101 = grid[1:-1, 0:-2, 1:-1] + g102 = grid[1:-1, 0:-2, 2:] + + g110 = grid[1:-1, 1:-1, 0:-2] + g111 = grid[1:-1, 1:-1, 1:-1] + g112 = grid[1:-1, 1:-1, 2:] + + g120 = grid[1:-1, 2:, 0:-2] + g121 = grid[1:-1, 2:, 1:-1] + g122 = grid[1:-1, 2:, 2:] + + # three + g200 = grid[2:, 0:-2, 0:-2] + g201 = grid[2:, 0:-2, 1:-1] + g202 = grid[2:, 0:-2, 2:] + + g210 = grid[2:, 1:-1, 0:-2] + g211 = grid[2:, 1:-1, 1:-1] + g212 = grid[2:, 1:-1, 2:] + + g220 = grid[2:, 2:, 0:-2] + g221 = grid[2:, 2:, 1:-1] + g222 = grid[2:, 2:, 2:] + + for i in range(I): + g00 = g000 + g001 + g002 + g01 = g010 + g011 + g012 + g02 = g020 + g021 + g022 + g10 = g100 + g101 + g102 + g11 = g110 + g111 + g112 + g12 = g120 + g121 + g122 + g20 = g200 + g201 + g202 + g21 = g210 + g211 + g212 + g22 = g220 + g221 + g222 + + g0 = g00 + g01 + g02 + g1 = g10 + g11 + g12 + g2 = g20 + g21 + g22 + + res = g0 + g1 + g2 + work = 0.037 * res + g111[:] = work + total = np.sum(g111) + return total / (N ** 2) + + +def run_stencil(N, I, timing): # noqa: E741 + start = datetime.datetime.now() + grid = initialize(N) + average = run(grid, I, N) + # This will sync the timing because we will need to wait for the result + assert not math.isnan(average) + stop = datetime.datetime.now() + print("Average energy is %.8g" % average) + delta = stop - start + total = delta.total_seconds() * 1000.0 + if timing: + print("Elapsed Time: " + str(total) + " ms") + return total + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--iter", + type=int, + default=100, + dest="I", + help="number of iterations to run", + ) + parser.add_argument( + "-n", + "--num", + type=int, + default=100, + dest="N", + help="number of elements in one dimension", + ) + parser.add_argument( + "-t", + "--time", + dest="timing", + action="store_true", + help="perform timing", + ) + parser.add_argument( + "-b", + "--benchmark", + type=int, + default=1, + dest="benchmark", + help="number of times to benchmark this application (default 1 " + "- normal execution)", + ) + args = parser.parse_args() + run_benchmark( + run_stencil, args.benchmark, "Stencil", (args.N, args.I, args.timing) + ) diff --git a/install.py b/install.py index f9ebf43d7..eb7c21528 100755 --- a/install.py +++ b/install.py @@ -456,7 +456,7 @@ def driver(): "--clean", dest="clean_first", action=BooleanFlag, - default=True, + default=False, help="Clean before build.", ) parser.add_argument( diff --git a/src/cunumeric.mk b/src/cunumeric.mk index 668f9a0e8..6e323094c 100644 --- a/src/cunumeric.mk +++ b/src/cunumeric.mk @@ -17,6 +17,7 @@ # since we have to add the -fopenmp flag to CC_FLAGS for them GEN_CPU_SRC += cunumeric/ternary/where.cc \ cunumeric/binary/binary_op.cc \ + cunumeric/fused/fused_op.cc \ cunumeric/binary/binary_red.cc \ cunumeric/unary/scalar_unary_red.cc \ cunumeric/unary/unary_op.cc \ @@ -46,6 +47,7 @@ GEN_CPU_SRC += cunumeric/ternary/where.cc \ ifeq ($(strip $(USE_OPENMP)),1) GEN_CPU_SRC += cunumeric/ternary/where_omp.cc \ cunumeric/binary/binary_op_omp.cc \ + cunumeric/fused/fused_op_omp.cc \ cunumeric/binary/binary_red_omp.cc \ cunumeric/unary/unary_op_omp.cc \ cunumeric/unary/scalar_unary_red_omp.cc \ @@ -75,6 +77,7 @@ GEN_CPU_SRC += cunumeric/cunumeric.cc # This must always be the last file! GEN_GPU_SRC += cunumeric/ternary/where.cu \ cunumeric/binary/binary_op.cu \ + cunumeric/fused/fused_op.cu \ cunumeric/binary/binary_red.cu \ cunumeric/unary/scalar_unary_red.cu \ cunumeric/unary/unary_red.cu \ diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index 41cc6baee..5f78c5631 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -46,6 +46,7 @@ enum CuNumericOpCode { CUNUMERIC_UNARY_RED, CUNUMERIC_WHERE, CUNUMERIC_WRITE, + CUNUMERIC_FUSED_OP, }; // Match these to CuNumericRedopCode in cunumeric/config.py diff --git a/src/cunumeric/fused/fused_op.cc b/src/cunumeric/fused/fused_op.cc new file mode 100644 index 000000000..658870f6f --- /dev/null +++ b/src/cunumeric/fused/fused_op.cc @@ -0,0 +1,124 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "core/runtime/runtime.h" +#include "core/runtime/context.h" +#include "cunumeric/fused/fused_op.h" +#include "legion.h" +#include +#include +#include + +// namespace legate { +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +/*static*/ void FusedOpTask::cpu_variant(TaskContext& context) +{ + int nOps = context.fusionMetadata.nOps; + auto opIDs = context.fusionMetadata.opIDs; + auto offsets = context.fusionMetadata.offsets; + for (int i = 0; i < nOps; i++) { + std::vector regions; + // context.runtime_->execute_task(context.context_, leaf_launcher); + + // create new context + const Legion::Task* task = (Legion::Task*)context.task_; + + // pack inputs + std::vector inputs; + auto inputStarts = context.fusionMetadata.inputStarts; + auto outputStarts = context.fusionMetadata.outputStarts; + auto offsetStarts = context.fusionMetadata.offsetStarts; + auto reductionStarts = context.fusionMetadata.reductionStarts; + unsigned nInputs = (inputStarts[i + 1] - inputStarts[i]); // want to pack this as a 32 bit uint + for (unsigned j = 0; j < nInputs; j++) { + int offsetStart = offsetStarts[i]; + int inputStart = inputStarts[i]; + int bufferID = offsets[offsetStart + j] - 1; + Store& input = context.inputs()[inputStart + bufferID]; + inputs.push_back(std::move(input)); + } + + // pack outputs + std::vector outputs; + unsigned nOutputs = + (outputStarts[i + 1] - outputStarts[i]); // want to pack this as a 32 bit uint + for (unsigned j = 0; j < nOutputs; j++) { + int offsetStart = offsetStarts[i]; + int outputStart = outputStarts[i]; + int bufferID = offsets[offsetStart + nInputs + j]; + bufferID = (-bufferID) - 1; + Store& output = context.outputs()[outputStart + bufferID]; + outputs.push_back(std::move(output)); + } + + // pack reductions + std::vector reductions; + int32_t nReductions = (reductionStarts[i + 1] - reductionStarts[i]); + for (unsigned j = 0; j < nReductions; j++) { + int offsetStart = offsetStarts[i]; + int reductionStart = reductionStarts[i]; + int bufferID = offsets[offsetStart + nInputs + nOutputs + j]; + // all buffer ids are 1 -indexed + // negative id is an output, while a positive id is an output + if (bufferID < 0) { + bufferID = (-bufferID) - 1; + Store& reduction = context.reductions()[reductionStart + bufferID]; + reductions.push_back(std::move(reduction)); + } + } + + // pack scalars + std::vector scalars; + auto scalarStarts = context.fusionMetadata.scalarStarts; + int32_t nScalars = (scalarStarts[i + 1] - scalarStarts[i]); + for (unsigned j = 0; j < nScalars; j++) { + scalars.push_back(std::move(context.scalars()[scalarStarts[i] + j])); + } + + TaskContext context3( + task, (const std::vector)regions); // inputs, outputs, scalars); + context3.inputs_ = std::move(inputs); + context3.outputs_ = std::move(outputs); + context3.scalars_ = std::move(scalars); + + // launch + auto descp = Core::cpuDescriptors.find(opIDs[i]); + auto desc = descp->second; + desc(context3); + for (unsigned j = 0; j < nOutputs; j++) { + int offsetStart = offsetStarts[i]; + int outputStart = outputStarts[i]; + int bufferID = offsets[offsetStart + nInputs + j]; + bufferID = (-bufferID) - 1; + context.outputs_[outputStart + bufferID] = std::move(context3.outputs_[j]); + } + + context3.pack_return_values(); + context.pack_return_values(); + } +} + +namespace // unnamed +{ +static void __attribute__((constructor)) register_tasks(void) { FusedOpTask::register_variants(); } +} // namespace + +//} // namespace numpy +} // namespace cunumeric diff --git a/src/cunumeric/fused/fused_op.cu b/src/cunumeric/fused/fused_op.cu new file mode 100644 index 000000000..8aa9a36ab --- /dev/null +++ b/src/cunumeric/fused/fused_op.cu @@ -0,0 +1,117 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/fused/fused_op.h" +#include "cunumeric/cuda_help.h" + +//namespace legate { +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +/*static*/ void FusedOpTask::gpu_variant(TaskContext& context){ + + int nOps = context.fusionMetadata.nOps; + auto opIDs = context.fusionMetadata.opIDs; + auto offsets = context.fusionMetadata.offsets; + for (int i=0; i regions; + //create new context + const Legion::Task* task = (Legion::Task*) context.task_; + + //pack inputs + std::vector inputs; + auto inputStarts = context.fusionMetadata.inputStarts; + auto outputStarts = context.fusionMetadata.outputStarts; + auto offsetStarts = context.fusionMetadata.offsetStarts; + auto reductionStarts = context.fusionMetadata.reductionStarts; + unsigned nInputs = (inputStarts[i+1]-inputStarts[i]); //want to pack this as a 32 bit uint + for (unsigned j = 0; j outputs; + unsigned nOutputs = (outputStarts[i+1]-outputStarts[i]); //want to pack this as a 32 bit uint + for (unsigned j = 0; j reductions; + int32_t nReductions = (reductionStarts[i+1]-reductionStarts[i]); + for (unsigned j = 0; j scalars; + for (unsigned j = 0; j) regions);// inputs, outputs, scalars); + context3.inputs_ = std::move(inputs); + context3.outputs_ = std::move(outputs); + context3.reductions_ = std::move(reductions); + context3.scalars_ = std::move(scalars); + + //launch + auto descp = Core::gpuDescriptors.find(opIDs[i]); + + auto desc = descp->second; + desc(context3); + for (unsigned j = 0; j { + public: + static const int TASK_ID = CUNUMERIC_FUSED_OP; + + public: + static void cpu_variant(legate::TaskContext& context); +#ifdef LEGATE_USE_OPENMP + static void omp_variant(legate::TaskContext& context); +#endif +#ifdef LEGATE_USE_CUDA + static void gpu_variant(legate::TaskContext& context); +#endif +}; + +} // namespace cunumeric diff --git a/src/cunumeric/fused/fused_op_omp.cc b/src/cunumeric/fused/fused_op_omp.cc new file mode 100644 index 000000000..22e13997a --- /dev/null +++ b/src/cunumeric/fused/fused_op_omp.cc @@ -0,0 +1,108 @@ +/* Copyright 2021 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/fused/fused_op.h" + +namespace cunumeric { + +using namespace Legion; +using namespace legate; + +/*static*/ void FusedOpTask::omp_variant(TaskContext& context) +{ + int nOps = context.fusionMetadata.nOps; + auto opIDs = context.fusionMetadata.opIDs; + auto offsets = context.fusionMetadata.offsets; + for (int i = 0; i < nOps; i++) { + std::vector regions; + // create new context + const Legion::Task* task = (Legion::Task*)context.task_; + + // pack inputs + std::vector inputs; + auto inputStarts = context.fusionMetadata.inputStarts; + auto outputStarts = context.fusionMetadata.outputStarts; + auto offsetStarts = context.fusionMetadata.offsetStarts; + auto reductionStarts = context.fusionMetadata.reductionStarts; + unsigned nInputs = (inputStarts[i + 1] - inputStarts[i]); // want to pack this as a 32 bit uint + for (unsigned j = 0; j < nInputs; j++) { + int offsetStart = offsetStarts[i]; + int inputStart = inputStarts[i]; + int bufferID = offsets[offsetStart + j] - 1; + Store& input = context.inputs()[inputStart + bufferID]; + inputs.push_back(std::move(input)); + } + + // pack outputs + std::vector outputs; + unsigned nOutputs = + (outputStarts[i + 1] - outputStarts[i]); // want to pack this as a 32 bit uint + for (unsigned j = 0; j < nOutputs; j++) { + int offsetStart = offsetStarts[i]; + int outputStart = outputStarts[i]; + int bufferID = offsets[offsetStart + nInputs + j]; + bufferID = (-bufferID) - 1; + Store& output = context.outputs()[outputStart + bufferID]; + outputs.push_back(std::move(output)); + } + + // pack reductions + std::vector reductions; + int32_t nReductions = (reductionStarts[i + 1] - reductionStarts[i]); + for (unsigned j = 0; j < nReductions; j++) { + int offsetStart = offsetStarts[i]; + int reductionStart = reductionStarts[i]; + int bufferID = offsets[offsetStart + nInputs + nOutputs + j]; + // all buffer ids are 1 -indexed + // negative id is an output, while a positive id is an output + if (bufferID < 0) { + bufferID = (-bufferID) - 1; + Store& reduction = context.reductions()[reductionStart + bufferID]; + reductions.push_back(std::move(reduction)); + } + } + + // pack scalars + auto scalarStarts = context.fusionMetadata.scalarStarts; + int32_t nScalars = (scalarStarts[i + 1] - scalarStarts[i]); + std::vector scalars; + for (unsigned j = 0; j < nScalars; j++) { + scalars.push_back(std::move(context.scalars()[scalarStarts[i] + j])); + } + + TaskContext context3( + task, (const std::vector)regions); // inputs, outputs, scalars); + context3.inputs_ = std::move(inputs); + context3.outputs_ = std::move(outputs); + context3.reductions_ = std::move(reductions); + context3.scalars_ = std::move(scalars); + + // launch + auto descp = Core::ompDescriptors.find(opIDs[i]); + + auto desc = descp->second; + desc(context3); + for (unsigned j = 0; j < nOutputs; j++) { + int offsetStart = offsetStarts[i]; + int outputStart = outputStarts[i]; + int bufferID = offsets[offsetStart + nInputs + j]; + bufferID = (-bufferID) - 1; + context.outputs_[outputStart + bufferID] = std::move(context3.outputs_[j]); + } + } +} + +} // namespace cunumeric diff --git a/tests/tensordot.py b/tests/tensordot.py index bfebfb51b..1b137eb75 100644 --- a/tests/tensordot.py +++ b/tests/tensordot.py @@ -19,13 +19,14 @@ def test(ty): + rtol = 2e-03 if ty == np.float16 else 1e-05 a = num.random.rand(3, 5, 4).astype(ty) b = num.random.rand(4, 5, 3).astype(ty) cn = np.tensordot(a, b, axes=1) c = num.tensordot(a, b, axes=1) - assert np.allclose(cn, c) + assert np.allclose(cn, c, rtol=rtol) a = num.random.rand(3, 5, 4).astype(ty) b = num.random.rand(5, 4, 3).astype(ty) @@ -33,7 +34,7 @@ def test(ty): cn = np.tensordot(a, b) c = num.tensordot(a, b) - assert np.allclose(cn, c) + assert np.allclose(cn, c, rtol=rtol) a = num.arange(60.0).reshape((3, 4, 5)).astype(ty) b = num.arange(24.0).reshape((4, 3, 2)).astype(ty) @@ -41,7 +42,7 @@ def test(ty): cn = np.tensordot(a, b, axes=([1, 0], [0, 1])) c = num.tensordot(a, b, axes=([1, 0], [0, 1])) - assert np.allclose(cn, c) + assert np.allclose(cn, c, rtol=rtol) a = num.random.rand(5, 4).astype(ty) b = num.random.rand(4, 5).astype(ty) @@ -49,7 +50,7 @@ def test(ty): cn = np.tensordot(a, b, axes=1) c = num.tensordot(a, b, axes=1) - assert np.allclose(cn, c) + assert np.allclose(cn, c, rtol=rtol) a = num.random.rand(5, 4).astype(ty) b = num.random.rand(5, 4).astype(ty) @@ -57,7 +58,7 @@ def test(ty): cn = np.tensordot(a, b) c = num.tensordot(a, b) - assert np.allclose(cn, c) + assert np.allclose(cn, c, rtol=rtol) if __name__ == "__main__":