Skip to content

Commit ffff671

Browse files
Authored: Update pre-commit hooks (#984)
1 parent: 12d4507 · commit: ffff671

File tree

10 files changed

+37
-44
lines changed

10 files changed

+37
-44
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
11
repos:
22
- repo: https://github.com/pre-commit/mirrors-clang-format
3-
rev: v17.0.6
3+
rev: v18.1.3
44
hooks:
55
- id: clang-format
66
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
77
- repo: https://github.com/psf/black-pre-commit-mirror
8-
rev: 24.2.0
8+
rev: 24.3.0
99
hooks:
1010
- id: black
1111
- repo: https://github.com/pycqa/isort

benchmarks/cpp/time_utils.h

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,10 @@
1717
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
1818
<< std::endl;
1919

20-
#define TIMEM(MSG, FUNC, ...) \
21-
std::cout << "Timing " \
22-
<< "(" << MSG << ") " << #FUNC << " ... " << std::flush \
23-
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
24-
<< std::endl;
20+
#define TIMEM(MSG, FUNC, ...) \
21+
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
22+
<< std::flush << std::setprecision(5) \
23+
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
2524

2625
template <typename F, typename... Args>
2726
double time_fn(F fn, Args&&... args) {

mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -394,7 +394,7 @@ struct Conv2DWeightBlockLoader {
394394
const constant ImplicitGemmConv2DParams* gemm_params_,
395395
uint simd_group_id [[simdgroup_index_in_threadgroup]],
396396
uint simd_lane_id [[thread_index_in_simdgroup]])
397-
: src_ld(params_->wt_strides[0]),
397+
: src_ld(params_ -> wt_strides[0]),
398398
thread_idx(simd_group_id * 32 + simd_lane_id),
399399
bi(thread_idx / TCOLS),
400400
bj(vec_size * (thread_idx % TCOLS)),

mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -244,7 +244,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
244244
const constant ImplicitGemmConv2DParams* gemm_params_,
245245
uint simd_group_id [[simdgroup_index_in_threadgroup]],
246246
uint simd_lane_id [[thread_index_in_simdgroup]])
247-
: src_ld(params_->wt_strides[0]),
247+
: src_ld(params_ -> wt_strides[0]),
248248
thread_idx(simd_group_id * 32 + simd_lane_id),
249249
bi(thread_idx / TCOLS),
250250
bj(vec_size * (thread_idx % TCOLS)),

mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,7 @@ struct Conv2DWeightBlockLoaderGeneral {
220220
const short base_ww_,
221221
uint simd_group_id [[simdgroup_index_in_threadgroup]],
222222
uint simd_lane_id [[thread_index_in_simdgroup]])
223-
: src_ld(params_->wt_strides[0]),
223+
: src_ld(params_ -> wt_strides[0]),
224224
thread_idx(simd_group_id * 32 + simd_lane_id),
225225
bi(thread_idx / TCOLS),
226226
bj(vec_size * (thread_idx % TCOLS)),

mlx/backend/metal/matmul.cpp

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -197,8 +197,8 @@ inline auto collapse_batches(const array& a, const array& b) {
197197
std::vector<int> B_bshape{b.shape().begin(), b.shape().end() - 2};
198198
if (A_bshape != B_bshape) {
199199
std::ostringstream msg;
200-
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: "
201-
<< "A " << a.shape() << ", B " << b.shape() << ".";
200+
msg << "[matmul] Got matrices with incorrectly broadcasted shapes: " << "A "
201+
<< a.shape() << ", B " << b.shape() << ".";
202202
throw std::runtime_error(msg.str());
203203
}
204204

@@ -227,9 +227,8 @@ inline auto collapse_batches(const array& a, const array& b, const array& c) {
227227
std::vector<int> C_bshape{c.shape().begin(), c.shape().end() - 2};
228228
if (A_bshape != B_bshape || A_bshape != C_bshape) {
229229
std::ostringstream msg;
230-
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: "
231-
<< "A " << a.shape() << ", B " << b.shape() << ", B " << c.shape()
232-
<< ".";
230+
msg << "[addmm] Got matrices with incorrectly broadcasted shapes: " << "A "
231+
<< a.shape() << ", B " << b.shape() << ", B " << c.shape() << ".";
233232
throw std::runtime_error(msg.str());
234233
}
235234

@@ -332,8 +331,8 @@ void steel_matmul(
332331
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
333332
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
334333
<< "_wm" << wm << "_wn" << wn << "_MN_"
335-
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
336-
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
334+
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
335+
<< ((K % bk == 0) ? "t" : "n") << "aligned";
337336

338337
// Encode and dispatch gemm kernel
339338
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -422,8 +421,8 @@ void steel_matmul(
422421
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
423422
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
424423
<< "_wm" << wm << "_wn" << wn << "_MN_"
425-
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
426-
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
424+
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
425+
<< ((K % bk == 0) ? "t" : "n") << "aligned";
427426

428427
// Encode and dispatch kernel
429428
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -903,8 +902,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
903902
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
904903
<< type_to_name(C_split) << "_bm" << bm << "_bn" << bn << "_bk" << bk
905904
<< "_wm" << wm << "_wn" << wn << "_MN_"
906-
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
907-
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned";
905+
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
906+
<< ((K % bk == 0) ? "t" : "n") << "aligned";
908907

909908
// Encode and dispatch gemm kernel
910909
auto& compute_encoder = d.get_command_encoder(s.index);
@@ -992,8 +991,8 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
992991
<< (transpose_b ? 't' : 'n') << "_" << type_to_name(a) << "_"
993992
<< type_to_name(out) << "_bm" << bm << "_bn" << bn << "_bk" << bk
994993
<< "_wm" << wm << "_wn" << wn << "_MN_"
995-
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned"
996-
<< "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"
994+
<< ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" << "_K_"
995+
<< ((K % bk == 0) ? "t" : "n") << "aligned"
997996
<< ((alpha_ == 1. && beta_ == 1.) ? "_add" : "_axpby");
998997

999998
// Encode and dispatch kernel

mlx/io/load.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -63,8 +63,7 @@ void save(std::shared_ptr<io::Writer> out_stream, array a) {
6363
std::string fortran_order = a.flags().col_contiguous ? "True" : "False";
6464
std::ostringstream header;
6565
header << "{'descr': '" << dtype_to_array_protocol(a.dtype()) << "',"
66-
<< " 'fortran_order': " << fortran_order << ","
67-
<< " 'shape': (";
66+
<< " 'fortran_order': " << fortran_order << "," << " 'shape': (";
6867
for (auto i : a.shape()) {
6968
header << i << ", ";
7069
}

mlx/ops.cpp

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -932,15 +932,15 @@ array pad(
932932
if (low_pad_size[i] < 0) {
933933
std::ostringstream msg;
934934
msg << "Invalid low padding size (" << low_pad_size[i]
935-
<< ") passed to pad"
936-
<< " for axis " << i << ". Padding sizes must be non-negative";
935+
<< ") passed to pad" << " for axis " << i
936+
<< ". Padding sizes must be non-negative";
937937
throw std::invalid_argument(msg.str());
938938
}
939939
if (high_pad_size[i] < 0) {
940940
std::ostringstream msg;
941941
msg << "Invalid high padding size (" << high_pad_size[i]
942-
<< ") passed to pad"
943-
<< " for axis " << i << ". Padding sizes must be non-negative";
942+
<< ") passed to pad" << " for axis " << i
943+
<< ". Padding sizes must be non-negative";
944944
throw std::invalid_argument(msg.str());
945945
}
946946

@@ -2508,8 +2508,8 @@ array take_along_axis(
25082508
StreamOrDevice s /* = {} */) {
25092509
if (axis + a.ndim() < 0 || axis >= static_cast<int>(a.ndim())) {
25102510
std::ostringstream msg;
2511-
msg << "[take_along_axis] Received invalid axis "
2512-
<< " for array with " << a.ndim() << " dimensions.";
2511+
msg << "[take_along_axis] Received invalid axis " << " for array with "
2512+
<< a.ndim() << " dimensions.";
25132513
throw std::invalid_argument(msg.str());
25142514
}
25152515

@@ -2904,15 +2904,15 @@ inline std::vector<int> conv_out_shape(
29042904

29052905
if (pads_lo[i - 1] < 0 || pads_hi[i - 1] < 0) {
29062906
std::ostringstream msg;
2907-
msg << "[conv] Padding sizes must be non-negative."
2908-
<< " Got padding " << pads_lo << " | " << pads_hi << ".";
2907+
msg << "[conv] Padding sizes must be non-negative." << " Got padding "
2908+
<< pads_lo << " | " << pads_hi << ".";
29092909
throw std::invalid_argument(msg.str());
29102910
}
29112911

29122912
if (strides[i - 1] <= 0) {
29132913
std::ostringstream msg;
2914-
msg << "[conv] Stride sizes must be positive."
2915-
<< " Got strides " << strides << ".";
2914+
msg << "[conv] Stride sizes must be positive." << " Got strides "
2915+
<< strides << ".";
29162916
throw std::invalid_argument(msg.str());
29172917
}
29182918

@@ -2948,8 +2948,7 @@ inline void run_conv_checks(const array& in, const array& wt, int n_dim) {
29482948
if (in.ndim() != n_dim + 2) {
29492949
std::ostringstream msg;
29502950
msg << "[conv] Invalid input array with " << in.ndim() << " dimensions for "
2951-
<< n_dim << "D convolution."
2952-
<< " Expected an array with " << n_dim + 2
2951+
<< n_dim << "D convolution." << " Expected an array with " << n_dim + 2
29532952
<< " dimensions following the format [N, ..., C_in].";
29542953
throw std::invalid_argument(msg.str());
29552954
}
@@ -3236,8 +3235,7 @@ std::tuple<array, array, array> quantize(
32363235
std::ostringstream msg;
32373236
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
32383237
<< "the quantization group size " << group_size
3239-
<< ". However the provided "
3240-
<< " matrix has shape " << w.shape();
3238+
<< ". However the provided " << " matrix has shape " << w.shape();
32413239
throw std::invalid_argument(msg.str());
32423240
}
32433241

mlx/transforms.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,8 @@ class Synchronizer : public Primitive {
2424
public:
2525
explicit Synchronizer(Stream stream) : Primitive(stream){};
2626

27-
void eval_cpu(const std::vector<array>&, std::vector<array>&) override{};
28-
void eval_gpu(const std::vector<array>&, std::vector<array>&) override{};
27+
void eval_cpu(const std::vector<array>&, std::vector<array>&) override {};
28+
void eval_gpu(const std::vector<array>&, std::vector<array>&) override {};
2929

3030
DEFINE_PRINT(Synchronize);
3131
};

tests/custom_vjp_tests.cpp

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,7 @@ TEST_CASE("test simple custom vjp") {
1818
fn,
1919
[&](const std::vector<array>&,
2020
const std::vector<array>&,
21-
const std::vector<array>&) {
22-
return std::vector<array>{one, one};
23-
});
21+
const std::vector<array>&) { return std::vector<array>{one, one}; });
2422

2523
auto [z, g] = vjp(fn, {x, y}, {one, one});
2624
CHECK_EQ(z[0].item<float>(), 6.0f);

0 commit comments

Comments (0)