Skip to content

Commit 020a241

Browse files
authored
Speed up RangesMatrix (#241)
* skip shape check on .full * faster len and getitem<int> * extract_ranges
1 parent 04f32aa commit 020a241

3 files changed

Lines changed: 23 additions & 14 deletions

File tree

include/Ranges.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ class Ranges {
6868
// Support for working with RangesMatrix, which is basically just a list of Ranges
6969
template <typename T>
7070
vector<Ranges<T>> extract_ranges(const bp::object & ival_list) {
71-
vector<Ranges<T>> v(bp::len(ival_list));
72-
for (int i=0; i<bp::len(ival_list); i++)
71+
const int N = bp::len(ival_list);
72+
vector<Ranges<T>> v(N);
73+
for (int i=0; i<N ; i++)
7374
v[i] = bp::extract<Ranges<T>>(ival_list[i])();
7475
return v;
7576
}

python/proj/ranges.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __repr__(self):
3535
return 'RangesMatrix(' + ','.join(map(str, self.shape)) + ')'
3636

3737
def __len__(self):
38-
return self.shape[0]
38+
return len(self.ranges)
3939

4040
def copy(self):
4141
return RangesMatrix([x.copy() for x in self.ranges],
@@ -44,7 +44,7 @@ def copy(self):
4444
def zeros_like(self):
4545
return RangesMatrix([x.zeros_like() for x in self.ranges],
4646
child_shape=self.shape[1:])
47-
47+
4848
def ones_like(self):
4949
return RangesMatrix([x.ones_like() for x in self.ranges],
5050
child_shape=self.shape[1:])
@@ -53,7 +53,7 @@ def buffer(self, buff):
5353
[x.buffer(buff) for x in self.ranges]
5454
## just to make this work like Ranges.buffer()
5555
return self
56-
56+
5757
def buffered(self, buff):
5858
out = self.copy()
5959
[x.buffer(buff) for x in out.ranges]
@@ -93,6 +93,10 @@ def __getitem__(self, index):
9393
new_rm = rm[..., :]
9494
9595
"""
96+
# Short-circuit return if this is a simple integer index.
97+
if isinstance(index, (int, np.int32, np.int64)):
98+
return self.ranges[index]
99+
96100
if not isinstance(index, tuple):
97101
index = (index,)
98102

@@ -128,7 +132,7 @@ def __add__(self, x):
128132
elif self.shape[0] == x.shape[0]:
129133
return self.__class__([r + d for r, d in zip(self.ranges, x)], skip_shape_check=True)
130134
return self.__class__([r + x for r in self.ranges], skip_shape_check=True)
131-
135+
132136
def __mul__(self, x):
133137
if isinstance(x, Ranges):
134138
return self.__class__([d * x for d in self.ranges], skip_shape_check=True)
@@ -247,7 +251,7 @@ def full(shape, fill_value):
247251
return r
248252
return RangesMatrix([RangesMatrix.full(shape[1:], fill_value)
249253
for i in range(shape[0])],
250-
child_shape=shape[1:])
254+
child_shape=shape[1:], skip_shape_check=True)
251255

252256
@classmethod
253257
def zeros(cls, shape):

src/Projection.cxx

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,16 +1889,20 @@ vector<vector<vector<RangesInt32>>> derive_ranges(
18891889
ivals.push_back(vector<vector<RangesInt32>>(1, extract_ranges<int32_t>(thread_intervals)));
18901890
} else if(bp::extract<RangesInt32>(thread_intervals[0][0]).check()) {
18911891
// It's a per-thread RangesMatrix (nthread,ndet,nranges). Promote to single bunch
1892-
vector<vector<RangesInt32>> bunch;
1893-
for (int i=0; i<bp::len(thread_intervals); i++)
1894-
bunch.push_back(extract_ranges<int32_t>(thread_intervals[i]));
1892+
int N = bp::len(thread_intervals);
1893+
vector<vector<RangesInt32>> bunch(N);
1894+
for (int i=0; i<N; i++)
1895+
bunch[i] = extract_ranges<int32_t>(thread_intervals[i]);
18951896
ivals.push_back(bunch);
18961897
} else if(bp::extract<RangesInt32>(thread_intervals[0][0][0]).check()) {
18971898
// It's a full multi-bunch (nbunch,nthread,ndet,nranges) thing.
1898-
for (int i=0; i<bp::len(thread_intervals); i++) {
1899-
vector<vector<RangesInt32>> bunch;
1900-
for (int j=0; j<bp::len(thread_intervals[i]); j++)
1901-
bunch.push_back(extract_ranges<int32_t>(thread_intervals[i][j]));
1899+
const int N = bp::len(thread_intervals);
1900+
for (int i=0; i<N; i++) {
1901+
auto ti_i = thread_intervals[i];
1902+
int M = bp::len(ti_i);
1903+
vector<vector<RangesInt32>> bunch(M);
1904+
for (int j=0; j<M; j++)
1905+
bunch[j] = extract_ranges<int32_t>(ti_i[j]);
19021906
ivals.push_back(bunch);
19031907
}
19041908
} else {

0 commit comments

Comments
 (0)