Skip to content

Commit e4c824d

Browse files
OctoberChangWei-Cheng Chang
andauthored
Fix train/predict bugs in PairwiseANN (#271)
Co-authored-by: Wei-Cheng Chang <chanweic@amazon.com>
1 parent c884e71 commit e4c824d

File tree

5 files changed

+247
-156
lines changed

5 files changed

+247
-156
lines changed

pecos/ann/pairwise/model.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,24 @@ class PredParams(pecos.BaseParams):
4848
"""Prediction Parameters of PairwiseANN class
4949
5050
Attributes:
51-
topk (int): maximum number of candidates (sorted by distances, nearest first) return by the searcher per query
51+
batch_size (int): maximum number of (input, label) pairs te be inference on for the Searchers
52+
only_topk (int): maximum number of candidates (sorted by distances, nearest first) return by kNN
5253
"""
5354

54-
topk: int = 10
55+
batch_size: int = 1024
56+
only_topk: int = 10
5557

5658
class Searchers(object):
57-
def __init__(self, model, max_batch_size=256, max_only_topk=10, num_searcher=1):
59+
def __init__(self, model, pred_params, num_searcher=1):
5860
self.searchers_ptr = model.fn_dict["searchers_create"](
5961
model.model_ptr,
6062
num_searcher,
6163
)
6264
self.destruct_fn = model.fn_dict["searchers_destruct"]
6365

6466
# searchers also hold the memory of returned np.ndarray
65-
self.max_batch_size = max_batch_size
66-
self.max_only_topk = max_only_topk
67-
max_nnz = max_batch_size * max_only_topk
67+
self.pred_params = pred_params
68+
max_nnz = pred_params.batch_size * pred_params.only_topk
6869
self.Imat = np.zeros(max_nnz, dtype=np.uint32)
6970
self.Mmat = np.zeros(max_nnz, dtype=np.uint32)
7071
self.Dmat = np.zeros(max_nnz, dtype=np.float32)
@@ -214,11 +215,18 @@ def save(self, model_folder):
214215
c_model_dir = f"{model_folder}/c_model"
215216
self.fn_dict["save"](self.model_ptr, c_char_p(c_model_dir.encode("utf-8")))
216217

217-
def searchers_create(self, max_batch_size=256, max_only_topk=10, num_searcher=1):
218+
def get_pred_params(self):
219+
"""Return a deep copy of prediction parameters
220+
221+
Returns:
222+
copied_pred_params (dict): Prediction parameters.
223+
"""
224+
return copy.deepcopy(self.pred_params)
225+
226+
def searchers_create(self, pred_params=None, num_searcher=1):
218227
"""create searchers that pre-allocate intermediate variables (e.g., topk_queue)
219228
Args:
220-
max_batch_size (int): the maximum batch size for the input/label pairs to be inference
221-
max_only_topk (int): the maximum only topk for the kNN to return
229+
pred_params (Pairwise.PredParams, optional): instance of pecos.ann.pairwise.Pairwise.PredParams
222230
num_searcher: number of searcher for multi-thread inference
223231
Returns:
224232
PairwiseANN.Searchers: the pre-allocated PairwiseANN.Searchers (class object)
@@ -227,31 +235,25 @@ def searchers_create(self, max_batch_size=256, max_only_topk=10, num_searcher=1)
227235
raise ValueError("self.model_ptr must exist before using searchers_create()")
228236
if num_searcher <= 0:
229237
raise ValueError("num_searcher={} <= 0 is NOT valid".format(num_searcher))
230-
return PairwiseANN.Searchers(self, max_batch_size, max_only_topk, num_searcher)
231-
232-
def get_pred_params(self):
233-
"""Return a deep copy of prediction parameters
234-
235-
Returns:
236-
copied_pred_params (dict): Prediction parameters.
237-
"""
238-
return copy.deepcopy(self.pred_params)
238+
pred_params = self.get_pred_params() if pred_params is None else pred_params
239+
return PairwiseANN.Searchers(self, pred_params, num_searcher)
239240

240-
def predict(self, input_feat, label_keys, searchers, pred_params=None, is_same_input=False):
241+
def predict(self, input_feat, label_keys, searchers, is_same_input=False):
241242
"""predict with multi-thread. The searchers are required to be provided.
242243
Args:
243244
input_feat (numpy.array or smat.csr_matrix): input feature matrix (first key) to find kNN.
244-
if is_same_input == False, the shape should be (batch_size, feat_dim)
245-
if is_same_input == True, the shape should be (1, feat_dim)
246-
label_keys (numpy.array): the label keys (second key) to find kNN. The shape should be (batch_size, ).
247-
searchers (c_void_p): pointer to C/C++ vector<pecos::ann::PairwiseANN:Searcher>. Created by PairwiseANN.searchers_create().
248-
pred_params (Pairwise.PredParams, optional): instance of pecos.ann.pairwise.Pairwise.PredParams.
245+
if is_same_input == False, the shape should be (batch_size, feat_dim).
246+
if is_same_input == True, the shape should be (1, feat_dim).
247+
label_keys (numpy.array): the label keys (second key) to find kNN.
248+
The shape should be (batch_size, ).
249+
searchers (c_void_p): pointer to C/C++ vector<pecos::ann::PairwiseANN:Searcher>.
250+
Created by PairwiseANN.searchers_create().
249251
is_same_input (bool): whether to use the same first row of X to do prediction.
250252
For real-time inference with same input query, set is_same_input = True.
251253
For batch prediction with varying input querues, set is_same_input = False.
252254
Returns:
253255
Imat (np.array): returned kNN input key indices. Shape of (batch_size, topk)
254-
Mmat (np.array): returned kNN masking array. 1/0 mean value is or is not presented. Shape of (batch_size, topk)
256+
Mmat (np.array): returned kNN masking array. 1/0 mean value IS/ISNOT presented. Shape of (batch_size, topk)
255257
Dmat (np.array): returned kNN distance array. Shape of (batch_size, topk)
256258
Vmat (np.array): returned kNN value array. Shape of (batch_size, topk)
257259
"""
@@ -273,19 +275,16 @@ def predict(self, input_feat, label_keys, searchers, pred_params=None, is_same_i
273275
if not is_same_input and input_feat_py.rows != label_keys.shape[0]:
274276
raise ValueError(f"input_feat_py.rows != label_keys.shape[0]")
275277

276-
batch_size = label_keys.shape[0]
277-
pred_params = self.get_pred_params() if pred_params is None else pred_params
278-
only_topk = pred_params.topk
279-
cur_nnz = batch_size * only_topk
280-
if batch_size > searchers.max_batch_size:
281-
raise ValueError(f"cur_batch_size > searchers.max_batch_size")
282-
if only_topk > searchers.max_only_topk:
283-
raise ValueError(f"cur_only_topk > searchers.max_only_topk")
278+
cur_bsz = label_keys.shape[0]
279+
if cur_bsz > searchers.pred_params.batch_size:
280+
raise ValueError(f"cur_batch_size > searchers.batch_size!")
281+
only_topk = searchers.pred_params.only_topk
282+
cur_nnz = cur_bsz * only_topk
284283

285284
searchers.reset(cur_nnz)
286285
self.fn_dict["predict"](
287286
searchers.ctypes(),
288-
batch_size,
287+
cur_bsz,
289288
only_topk,
290289
input_feat_py,
291290
label_keys.ctypes.data_as(POINTER(c_uint32)),
@@ -295,8 +294,8 @@ def predict(self, input_feat, label_keys, searchers, pred_params=None, is_same_i
295294
searchers.Vmat.ctypes.data_as(POINTER(c_float)),
296295
c_bool(is_same_input),
297296
)
298-
Imat = searchers.Imat[:cur_nnz].reshape(batch_size, only_topk)
299-
Mmat = searchers.Mmat[:cur_nnz].reshape(batch_size, only_topk)
300-
Dmat = searchers.Dmat[:cur_nnz].reshape(batch_size, only_topk)
301-
Vmat = searchers.Vmat[:cur_nnz].reshape(batch_size, only_topk)
297+
Imat = searchers.Imat[:cur_nnz].reshape(cur_bsz, only_topk)
298+
Mmat = searchers.Mmat[:cur_nnz].reshape(cur_bsz, only_topk)
299+
Dmat = searchers.Dmat[:cur_nnz].reshape(cur_bsz, only_topk)
300+
Vmat = searchers.Vmat[:cur_nnz].reshape(cur_bsz, only_topk)
302301
return Imat, Mmat, Dmat, Vmat

pecos/core/ann/pairwise.hpp

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ namespace ann {
8181
mmap_s.fput_one<index_type>(X.cols);
8282
mmap_s.fput_one<mem_index_type>(nnz);
8383
mmap_s.fput_multiple<value_type>(X.val, nnz);
84-
}
84+
}
8585

8686
template<class MAT_T>
8787
void load_mat(MAT_T &X, mmap_util::MmapStore& mmap_s) {
@@ -99,7 +99,7 @@ namespace ann {
9999
X.cols = mmap_s.fget_one<index_type>();
100100
auto nnz = mmap_s.fget_one<mem_index_type>();
101101
X.val = mmap_s.fget_multiple<value_type>(nnz);
102-
}
102+
}
103103

104104
template <typename T1, typename T2>
105105
struct KeyValPair {
@@ -111,16 +111,16 @@ namespace ann {
111111
bool operator<(const KeyValPair<T1, T2>& other) const { return input_key_dist < other.input_key_dist; }
112112
bool operator>(const KeyValPair<T1, T2>& other) const { return input_key_dist > other.input_key_dist; }
113113
};
114-
115-
// PairwiseANN Interface
114+
115+
// PairwiseANN Interface
116116
template<class FeatVec_T, class MAT_T>
117117
struct PairwiseANN {
118118
typedef FeatVec_T feat_vec_t;
119119
typedef MAT_T mat_t;
120120
typedef pecos::ann::KeyValPair<index_type, value_type> pair_t;
121121
typedef pecos::ann::heap_t<pair_t, std::less<pair_t>> max_heap_t;
122122

123-
struct Searcher {
123+
struct Searcher {
124124
typedef PairwiseANN<feat_vec_t, mat_t> pairwise_ann_t;
125125

126126
const pairwise_ann_t* pairwise_ann;
@@ -132,8 +132,8 @@ namespace ann {
132132

133133
max_heap_t& predict_single(const feat_vec_t& query_vec, const index_type label_key, index_type topk) {
134134
return pairwise_ann->predict_single(query_vec, label_key, topk, *this);
135-
}
136-
};
135+
}
136+
};
137137

138138
Searcher create_searcher() const {
139139
return Searcher(this);
@@ -143,7 +143,7 @@ namespace ann {
143143
index_type num_input_keys; // N
144144
index_type num_label_keys; // L
145145
index_type feat_dim; // d
146-
146+
147147
// matrices
148148
pecos::csc_t Y_csc; // shape of [N, L]
149149
mat_t X_trn; // shape of [N, d]
@@ -152,7 +152,14 @@ namespace ann {
152152
pecos::mmap_util::MmapStore mmap_store;
153153

154154
// destructor
155-
~PairwiseANN() {}
155+
~PairwiseANN() {
156+
// If mmap_store is not open for read, then the memory of Y/X is owned by this class
157+
// Thus, we need to explicitly free the underlying memory of Y/X during destructor
158+
if (!mmap_store.is_open_for_read()) {
159+
this->Y_csc.free_underlying_memory();
160+
this->X_trn.free_underlying_memory();
161+
}
162+
}
156163

157164
static nlohmann::json load_config(const std::string& filepath) {
158165
std::ifstream loadfile(filepath);
@@ -215,7 +222,7 @@ namespace ann {
215222
save_mat(X_trn, mmap_s);
216223
mmap_s.close();
217224
}
218-
225+
219226
void load(const std::string& model_dir, bool lazy_load = false) {
220227
auto config = load_config(model_dir + "/config.json");
221228
std::string version = config.find("version") != config.end() ? config["version"] : "not found";
@@ -248,9 +255,11 @@ namespace ann {
248255
this->num_input_keys = Y_csc.rows;
249256
this->num_label_keys = Y_csc.cols;
250257
this->feat_dim = X_trn.cols;
251-
// matrices
252-
this->Y_csc = Y_csc;
253-
this->X_trn = X_trn;
258+
// Deepcopy the memory of X/Y.
259+
// Otherwise, after Python API of PairwiseANN.train(),
260+
// the input matrices pX/pY can be modified or deleted.
261+
this->Y_csc = Y_csc.deep_copy();
262+
this->X_trn = X_trn.deep_copy();
254263
}
255264

256265
max_heap_t& predict_single(

pecos/core/libpecos.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ extern "C" {
543543
void c_pairwise_ann_predict ## SUFFIX( \
544544
void* searchers_ptr, \
545545
uint32_t batch_size, \
546-
uint32_t topk, \
546+
uint32_t only_topk, \
547547
const PY_MAT* pQ, \
548548
uint32_t* label_keys, \
549549
uint32_t* ret_Imat, \
@@ -559,9 +559,9 @@ extern "C" {
559559
int tid = omp_get_thread_num(); \
560560
auto input_key = (is_same_input ? 0 : bidx); \
561561
auto label_key = label_keys[bidx]; \
562-
auto& ret_pairs = searchers[tid].predict_single(Q_tst.get_row(input_key), label_key, topk); \
562+
auto& ret_pairs = searchers[tid].predict_single(Q_tst.get_row(input_key), label_key, only_topk); \
563563
for (uint32_t k=0; k < ret_pairs.size(); k++) { \
564-
uint64_t offset = static_cast<uint64_t>(bidx) * topk; \
564+
uint64_t offset = static_cast<uint64_t>(bidx) * only_topk; \
565565
ret_Imat[offset + k] = ret_pairs[k].input_key_idx; \
566566
ret_Dmat[offset + k] = ret_pairs[k].input_key_dist; \
567567
ret_Vmat[offset + k] = ret_pairs[k].input_label_val; \

pecos/core/utils/matrix.hpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ namespace pecos {
106106
if(touched_indices[i] < len) {
107107
touched_indices[write_pos] = touched_indices[i];
108108
write_pos += 1;
109-
}
109+
}
110110
}
111111
nr_touch = write_pos;
112112
}
@@ -540,6 +540,34 @@ namespace pecos {
540540
mem_index_type get_nnz() const {
541541
return static_cast<mem_index_type>(rows) * static_cast<mem_index_type>(cols);
542542
}
543+
544+
// Frees the underlying memory of the matrix (i.e., col_ptr, row_idx, and val arrays)
545+
// Every function in the inference code that returns a matrix has allocated memory, and
546+
// therefore one should call this function to free that memory.
547+
void free_underlying_memory() {
548+
if (val) {
549+
delete[] val;
550+
val = nullptr;
551+
}
552+
}
553+
554+
// Creates a deep copy of this matrix
555+
// This allocates memory, so one should call free_underlying_memory on the copy when
556+
// one is finished using it.
557+
drm_t deep_copy() const {
558+
mem_index_type nnz = get_nnz();
559+
drm_t res;
560+
res.allocate(rows, cols, nnz);
561+
std::memcpy(res.val, val, sizeof(value_type) * nnz);
562+
return res;
563+
}
564+
565+
void allocate(index_type rows, index_type cols, mem_index_type nnz) {
566+
this->rows = rows;
567+
this->cols = cols;
568+
val = new value_type[nnz];
569+
}
570+
543571
};
544572

545573
struct dcm_t { // Dense Column Majored Matrix

0 commit comments

Comments
 (0)