Skip to content

Commit f27ec5e

Browse files
authored
More helpful error message in vjp transform + concatenate bug (#543)
* more helpful message in vjp transform * fix concatenate on mismatched dims * typo * typo
1 parent f30e633 commit f27ec5e

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

mlx/ops.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -669,26 +669,27 @@ array concatenate(
669669
int axis,
670670
StreamOrDevice s /* = {} */) {
671671
if (arrays.size() == 0) {
672-
throw std::invalid_argument("No arrays provided for concatenation");
672+
throw std::invalid_argument(
673+
"[concatenate] No arrays provided for concatenation");
673674
}
674675

675676
// Normalize the given axis
676677
auto ax = axis < 0 ? axis + arrays[0].ndim() : axis;
677678
if (ax < 0 || ax >= arrays[0].ndim()) {
678679
std::ostringstream msg;
679-
msg << "Invalid axis (" << axis << ") passed to concatenate"
680+
msg << "[concatenate] Invalid axis (" << axis << ") passed to concatenate"
680681
<< " for array with shape " << arrays[0].shape() << ".";
681682
throw std::invalid_argument(msg.str());
682683
}
683684

684685
auto throw_invalid_shapes = [&]() {
685686
std::ostringstream msg;
686-
msg << "All the input array dimensions must match exactly except"
687-
<< " for the concatenation axis. However, the provided shapes are ";
687+
msg << "[concatenate] All the input array dimensions must match exactly "
688+
<< "except for the concatenation axis. However, the provided shapes are ";
688689
for (auto& a : arrays) {
689690
msg << a.shape() << ", ";
690691
}
691-
msg << "and the concatenation axis is " << axis;
692+
msg << "and the concatenation axis is " << axis << ".";
692693
throw std::invalid_argument(msg.str());
693694
};
694695

@@ -697,6 +698,13 @@ array concatenate(
697698
// Make the output shape and validate that all arrays have the same shape
698699
// except for the concatenation axis.
699700
for (auto& a : arrays) {
701+
if (a.ndim() != shape.size()) {
702+
std::ostringstream msg;
703+
msg << "[concatenate] All the input arrays must have the same number of "
704+
<< "dimensions. However, got arrays with dimensions " << shape.size()
705+
<< " and " << a.ndim() << ".";
706+
throw std::invalid_argument(msg.str());
707+
}
700708
for (int i = 0; i < a.ndim(); i++) {
701709
if (i == ax) {
702710
continue;

mlx/transforms.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,21 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
337337
}
338338
}
339339
if (cotan_index >= cotans.size()) {
340-
throw std::invalid_argument(
341-
"[vjp] Number of outputs with gradient does not match number of cotangents.");
340+
std::ostringstream msg;
341+
msg << "[vjp] Number of outputs to compute gradients for ("
342+
<< outputs.size() << ") does not match number of cotangents ("
343+
<< cotans.size() << ").";
344+
throw std::invalid_argument(msg.str());
342345
}
343346
if (out.shape() != cotans[cotan_index].shape()) {
344-
throw std::invalid_argument(
345-
"[vjp] Output shape does not match shape of cotangent.");
347+
std::ostringstream msg;
348+
msg << "[vjp] Output shape " << out.shape()
349+
<< " does not match cotangent shape " << cotans[cotan_index].shape()
350+
<< ".";
351+
if (outputs.size() == 1 && out.size() == 1) {
352+
msg << " If you are using grad your function must return a scalar.";
353+
}
354+
throw std::invalid_argument(msg.str());
346355
}
347356
output_cotan_pairs.emplace_back(i, cotan_index++);
348357
}

python/tests/test_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,11 @@ def test_concatenate(self):
13451345
self.assertEqual(list(c_npy.shape), list(c_mlx.shape))
13461346
self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6))
13471347

1348+
with self.assertRaises(ValueError):
1349+
a = mx.array([[1, 2], [1, 2], [1, 2]])
1350+
b = mx.array([1, 2])
1351+
mx.concatenate([a, b], axis=0)
1352+
13481353
def test_pad(self):
13491354
pad_width_and_values = [
13501355
([(1, 1), (1, 1), (1, 1)], 0),

0 commit comments

Comments
 (0)