Commit 0c3e0ad

[Embedding] Refactor ShrinkPolicy class and shrink functions in Embedding. (#861)
Signed-off-by: lixy9474 <[email protected]>
1 parent f178454 commit 0c3e0ad

14 files changed: +176 -169 lines

Diff for: tensorflow/core/framework/embedding/dram_leveldb_storage.h

+3 -9

@@ -156,15 +156,9 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
     return Status::OK();
   }

-  Status Shrink(int64 value_len) override {
-    dram_->Shrink(value_len);
-    leveldb_->Shrink(value_len);
-    return Status::OK();
-  }
-
-  Status Shrink(int64 global_step, int64 steps_to_live) override {
-    dram_->Shrink(global_step, steps_to_live);
-    leveldb_->Shrink(global_step, steps_to_live);
+  Status Shrink(const ShrinkArgs& shrink_args) override {
+    dram_->Shrink(shrink_args);
+    leveldb_->Shrink(shrink_args);
     return Status::OK();
   }
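
The same collapse of two Shrink overloads into a single Shrink(const ShrinkArgs&) repeats in every storage backend touched by this commit. As a rough standalone illustration of the pattern (the Tier and MultiTierStorageSketch types below are invented stand-ins, not the repo's classes), each multi-tier store now forwards one argument bundle to every tier instead of duplicating a value_len overload and a global_step/steps_to_live overload per tier:

// Sketch only: simplified types standing in for dram_/leveldb_/ssd_ tiers.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct ShrinkArgs {
  int64_t global_step = 0;
  int64_t value_len = 0;
};

struct Tier {
  const char* name;
  void Shrink(const ShrinkArgs& args) {
    std::cout << name << ": shrink at step " << args.global_step
              << ", value_len " << args.value_len << "\n";
  }
};

class MultiTierStorageSketch {
 public:
  explicit MultiTierStorageSketch(std::vector<Tier> tiers)
      : tiers_(std::move(tiers)) {}

  // One override, one argument bundle, forwarded to every tier.
  void Shrink(const ShrinkArgs& args) {
    for (auto& tier : tiers_) tier.Shrink(args);
  }

 private:
  std::vector<Tier> tiers_;
};

int main() {
  MultiTierStorageSketch storage({{"dram"}, {"leveldb"}});
  storage.Shrink(ShrinkArgs{/*global_step=*/1000, /*value_len=*/64});
  return 0;
}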

Diff for: tensorflow/core/framework/embedding/dram_pmem_storage.h

+3 -9

@@ -154,15 +154,9 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
     return Status::OK();
   }

-  Status Shrink(int64 value_len) override {
-    dram_->Shrink(value_len);
-    pmem_->Shrink(value_len);
-    return Status::OK();
-  }
-
-  Status Shrink(int64 global_step, int64 steps_to_live) override {
-    dram_->Shrink(global_step, steps_to_live);
-    pmem_->Shrink(global_step, steps_to_live);
+  Status Shrink(const ShrinkArgs& shrink_args) override {
+    dram_->Shrink(shrink_args);
+    pmem_->Shrink(shrink_args);
     return Status::OK();
   }

Diff for: tensorflow/core/framework/embedding/dram_ssd_storage.h

+3 -9

@@ -147,15 +147,9 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
     return Status::OK();
   }

-  Status Shrink(int64 value_len) override {
-    dram_->Shrink(value_len);
-    ssd_hash_->Shrink(value_len);
-    return Status::OK();
-  }
-
-  Status Shrink(int64 global_step, int64 steps_to_live) override {
-    dram_->Shrink(global_step, steps_to_live);
-    ssd_hash_->Shrink(global_step, steps_to_live);
+  Status Shrink(const ShrinkArgs& shrink_args) override {
+    dram_->Shrink(shrink_args);
+    ssd_hash_->Shrink(shrink_args);
     return Status::OK();
   }

Diff for: tensorflow/core/framework/embedding/embedding_var.h

+4 -11

@@ -373,10 +373,6 @@ class EmbeddingVar : public ResourceBase {
     return emb_config_.steps_to_live;
   }

-  float GetL2WeightThreshold() {
-    return emb_config_.l2_weight_threshold;
-  }
-
   bool IsMultiLevel() {
     return storage_->IsMultiLevel();
   }
@@ -553,13 +549,10 @@ class EmbeddingVar : public ResourceBase {
     return storage_;
   }

-  Status Shrink() {
-    return storage_->Shrink(value_len_);
-  }
-
-  Status Shrink(int64 gs) {
-    if (emb_config_.steps_to_live > 0) {
-      return storage_->Shrink(gs, emb_config_.steps_to_live);
+  Status Shrink(embedding::ShrinkArgs& shrink_args) {
+    if (emb_config_.is_primary()) {
+      shrink_args.value_len = value_len_;
+      return storage_->Shrink(shrink_args);
     } else {
       return Status::OK();
     }
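
A hedged sketch of the new call flow (self-contained, simplified types; StorageSketch and EmbeddingVarSketch are invented for illustration, not the repo's classes): the caller only fills in global_step, the primary variable adds its own value_len before delegating, and non-primary slot variables return OK without touching storage:

#include <cstdint>
#include <iostream>

struct ShrinkArgs { int64_t global_step = 0; int64_t value_len = 0; };
struct Status { static Status OK() { return Status(); } };

struct StorageSketch {
  Status Shrink(const ShrinkArgs& args) {
    std::cout << "storage shrink: step=" << args.global_step
              << " value_len=" << args.value_len << "\n";
    return Status::OK();
  }
};

class EmbeddingVarSketch {
 public:
  EmbeddingVarSketch(bool is_primary, int64_t value_len)
      : is_primary_(is_primary), value_len_(value_len) {}

  // Mirrors the refactored EmbeddingVar::Shrink above.
  Status Shrink(ShrinkArgs& shrink_args) {
    if (is_primary_) {
      shrink_args.value_len = value_len_;   // variable supplies its own value_len
      return storage_.Shrink(shrink_args);
    } else {
      return Status::OK();                  // slot variables do nothing
    }
  }

 private:
  bool is_primary_;
  int64_t value_len_;
  StorageSketch storage_;
};

int main() {
  ShrinkArgs args;
  args.global_step = 2000;                  // e.g. the training global step
  EmbeddingVarSketch primary(/*is_primary=*/true, /*value_len=*/128);
  EmbeddingVarSketch slot(/*is_primary=*/false, /*value_len=*/128);
  primary.Shrink(args);                     // delegates to storage
  slot.Shrink(args);                        // no-op
  return 0;
}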

Diff for: tensorflow/core/framework/embedding/globalstep_shrink_policy.h

+26 -19

@@ -26,37 +26,44 @@ namespace embedding {
 template<typename K, typename V>
 class GlobalStepShrinkPolicy : public ShrinkPolicy<K, V> {
  public:
-  GlobalStepShrinkPolicy(
-      KVInterface<K, V>* kv,
-      Allocator* alloc,
-      int slot_num)
-      : ShrinkPolicy<K, V>(kv, alloc, slot_num) {}
+  GlobalStepShrinkPolicy(int64 steps_to_live,
+                         Allocator* alloc,
+                         KVInterface<K, V>* kv)
+      : steps_to_live_(steps_to_live),
+        kv_(kv),
+        ShrinkPolicy<K, V>(alloc) {}

   TF_DISALLOW_COPY_AND_ASSIGN(GlobalStepShrinkPolicy);

-  void Shrink(int64 global_step, int64 steps_to_live) {
-    ShrinkPolicy<K, V>::ReleaseDeleteValues();
-    ShrinkPolicy<K, V>::GetSnapshot();
-    FilterToDelete(global_step, steps_to_live);
+  void Shrink(const ShrinkArgs& shrink_args) override {
+    ShrinkPolicy<K, V>::ReleaseValuePtrs();
+    std::vector<K> key_list;
+    std::vector<ValuePtr<V>*> value_list;
+    kv_->GetSnapshot(&key_list, &value_list);
+    FilterToDelete(shrink_args.global_step,
+                   key_list, value_list);
   }

  private:
-  void FilterToDelete(int64 global_step, int64 steps_to_live) {
-    for (int64 i = 0; i < ShrinkPolicy<K, V>::key_list_.size(); ++i) {
-      int64 version = ShrinkPolicy<K, V>::value_list_[i]->GetStep();
+  void FilterToDelete(int64 global_step,
+                      const std::vector<K>& key_list,
+                      const std::vector<ValuePtr<V>*>& value_list) {
+    for (int64 i = 0; i < key_list.size(); ++i) {
+      int64 version = value_list[i]->GetStep();
       if (version == -1) {
-        ShrinkPolicy<K, V>::value_list_[i]->SetStep(global_step);
+        value_list[i]->SetStep(global_step);
       } else {
-        if (global_step - version > steps_to_live) {
-          ShrinkPolicy<K, V>::kv_->Remove(ShrinkPolicy<K, V>::key_list_[i]);
-          ShrinkPolicy<K, V>::to_delete_.emplace_back(
-              ShrinkPolicy<K, V>::value_list_[i]);
+        if (global_step - version > steps_to_live_) {
+          kv_->Remove(key_list[i]);
+          ShrinkPolicy<K, V>::EmplacePointer(value_list[i]);
         }
       }
     }
-    ShrinkPolicy<K, V>::key_list_.clear();
-    ShrinkPolicy<K, V>::value_list_.clear();
   }
+
+ private:
+  int64 steps_to_live_;
+  KVInterface<K, V>* kv_;
 };
 } // embedding
 } // tensorflow
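
The eviction rule itself is unchanged by the refactor and can be read off FilterToDelete: a value whose step is still -1 is stamped with the current global step, otherwise it is removed once global_step - version exceeds steps_to_live. A standalone sketch of that rule (a plain std::map plays the role of KVInterface plus ValuePtr::GetStep/SetStep here, purely for illustration):

#include <cstdint>
#include <iostream>
#include <map>

// Keeps or evicts keys by age; last_step == -1 means "never stamped yet".
void FilterByGlobalStep(std::map<int64_t, int64_t>& last_step_of_key,
                        int64_t global_step, int64_t steps_to_live) {
  for (auto it = last_step_of_key.begin(); it != last_step_of_key.end();) {
    if (it->second == -1) {
      it->second = global_step;          // first visit: stamp and keep
      ++it;
    } else if (global_step - it->second > steps_to_live) {
      it = last_step_of_key.erase(it);   // too old: evict
    } else {
      ++it;                              // recent enough: keep
    }
  }
}

int main() {
  std::map<int64_t, int64_t> keys = {{1, -1}, {2, 500}, {3, 1990}};
  FilterByGlobalStep(keys, /*global_step=*/2000, /*steps_to_live=*/100);
  // Key 2 (age 1500) is evicted; key 1 is stamped with 2000; key 3 survives.
  for (const auto& kv : keys)
    std::cout << "key " << kv.first << " kept, step " << kv.second << "\n";
  return 0;
}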

Diff for: tensorflow/core/framework/embedding/hbm_dram_ssd_storage.h

+4 -11

@@ -387,17 +387,10 @@ class HbmDramSsdStorage : public MultiTierStorage<K, V> {
     LOG(FATAL)<<"HbmDramSsdStorage dosen't support GetSnaoshot.";
   }

-  Status Shrink(int64 value_len) override {
-    hbm_->Shrink(value_len);
-    dram_->Shrink(value_len);
-    ssd_->Shrink(value_len);
-    return Status::OK();
-  }
-
-  Status Shrink(int64 global_step, int64 steps_to_live) override {
-    hbm_->Shrink(global_step, steps_to_live);
-    dram_->Shrink(global_step, steps_to_live);
-    ssd_->Shrink(global_step, steps_to_live);
+  Status Shrink(const ShrinkArgs& shrink_args) override {
+    hbm_->Shrink(shrink_args);
+    dram_->Shrink(shrink_args);
+    ssd_->Shrink(shrink_args);
     return Status::OK();
   }

Diff for: tensorflow/core/framework/embedding/hbm_dram_storage.h

+3 -9

@@ -352,15 +352,9 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
     return temp_hbm_key_list.size() + temp_dram_key_list.size();
   }

-  Status Shrink(int64 value_len) override {
-    hbm_->Shrink(value_len);
-    dram_->Shrink(value_len);
-    return Status::OK();
-  }
-
-  Status Shrink(int64 global_step, int64 steps_to_live) override {
-    hbm_->Shrink(global_step, steps_to_live);
-    dram_->Shrink(global_step, steps_to_live);
+  Status Shrink(const ShrinkArgs& shrink_args) override {
+    hbm_->Shrink(shrink_args);
+    dram_->Shrink(shrink_args);
     return Status::OK();
   }

Diff for: tensorflow/core/framework/embedding/l2weight_shrink_policy.h

+27 -22

@@ -27,46 +27,51 @@ template<typename K, typename V>
 class L2WeightShrinkPolicy : public ShrinkPolicy<K, V> {
  public:
   L2WeightShrinkPolicy(float l2_weight_threshold,
-      int64 primary_index, int64 primary_offset,
-      KVInterface<K, V>* kv, Allocator* alloc,
-      int slot_num)
-      : l2_weight_threshold_(l2_weight_threshold),
-        primary_index_(primary_index), primary_offset_(primary_offset),
-        ShrinkPolicy<K, V>(kv, alloc, slot_num) {}
+                       int64 index,
+                       int64 offset,
+                       Allocator* alloc,
+                       KVInterface<K, V>* kv)
+      : index_(index),
+        offset_(offset),
+        kv_(kv),
+        l2_weight_threshold_(l2_weight_threshold),
+        ShrinkPolicy<K, V>(alloc) {}

   TF_DISALLOW_COPY_AND_ASSIGN(L2WeightShrinkPolicy);

-  void Shrink(int64 value_len) {
-    ShrinkPolicy<K, V>::ReleaseDeleteValues();
-    ShrinkPolicy<K, V>::GetSnapshot();
-    FilterToDelete(value_len);
+  void Shrink(const ShrinkArgs& shrink_args) override {
+    ShrinkPolicy<K, V>::ReleaseValuePtrs();
+    std::vector<K> key_list;
+    std::vector<ValuePtr<V>*> value_list;
+    kv_->GetSnapshot(&key_list, &value_list);
+    FilterToDelete(shrink_args.value_len,
+                   key_list, value_list);
   }

- private:
-  void FilterToDelete(int64 value_len) {
-    for (int64 i = 0; i < ShrinkPolicy<K, V>::key_list_.size(); ++i) {
-      V* val = ShrinkPolicy<K, V>::value_list_[i]->GetValue(
-          primary_index_, primary_offset_);
+ private:
+  void FilterToDelete(int64 value_len,
+                      const std::vector<K>& key_list,
+                      const std::vector<ValuePtr<V>*>& value_list) {
+    for (int64 i = 0; i < key_list.size(); ++i) {
+      V* val = value_list[i]->GetValue(index_, offset_);
       if (val != nullptr) {
         V l2_weight = (V)0.0;
         for (int64 j = 0; j < value_len; j++) {
           l2_weight += val[j] * val[j];
         }
         l2_weight *= (V)0.5;
         if (l2_weight < (V)l2_weight_threshold_) {
-          ShrinkPolicy<K, V>::kv_->Remove(ShrinkPolicy<K, V>::key_list_[i]);
-          ShrinkPolicy<K, V>::to_delete_.emplace_back(
-              ShrinkPolicy<K, V>::value_list_[i]);
+          kv_->Remove(key_list[i]);
+          ShrinkPolicy<K, V>::EmplacePointer(value_list[i]);
         }
       }
     }
-    ShrinkPolicy<K, V>::key_list_.clear();
-    ShrinkPolicy<K, V>::value_list_.clear();
   }

  private:
-  int64 primary_index_; // Shrink only handle primary slot
-  int64 primary_offset_;
+  int64 index_;
+  int64 offset_;
+  KVInterface<K, V>* kv_;
   float l2_weight_threshold_;
 };
 } // embedding
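
The shrink criterion is also unchanged: a row is dropped when half the sum of its squared weights, computed over the value_len floats fetched via GetValue(index_, offset_), falls below l2_weight_threshold. A minimal standalone check of that arithmetic (ShouldEvict is an illustrative helper, not a function in the repo):

#include <iostream>
#include <vector>

// Returns true when 0.5 * sum(v[j]^2) < l2_weight_threshold.
bool ShouldEvict(const std::vector<float>& row, float l2_weight_threshold) {
  float l2_weight = 0.0f;
  for (float v : row) l2_weight += v * v;
  l2_weight *= 0.5f;                          // same 0.5 factor as the policy
  return l2_weight < l2_weight_threshold;
}

int main() {
  std::vector<float> near_zero(64, 0.001f);   // 0.5 * 64 * 1e-6 = 3.2e-5
  std::vector<float> trained(64, 0.1f);       // 0.5 * 64 * 0.01  = 0.32
  std::cout << ShouldEvict(near_zero, 0.01f) << "\n";  // 1: evicted
  std::cout << ShouldEvict(trained, 0.01f) << "\n";    // 0: kept
  return 0;
}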

Diff for: tensorflow/core/framework/embedding/shrink_policy.h

+37 -22

@@ -23,40 +23,55 @@ namespace tensorflow {
 template<typename V>
 class ValuePtr;

+class Allocator;
+
 namespace embedding {
+struct ShrinkArgs {
+  ShrinkArgs(): global_step(0), value_len(0) {}
+
+  ShrinkArgs(int64 global_step,
+             int64 value_len)
+      : global_step(global_step),
+        value_len(value_len) {}
+  int64 global_step;
+  int64 value_len;
+};
+
 template<typename K, typename V>
 class ShrinkPolicy {
  public:
-  ShrinkPolicy(KVInterface<K, V>* kv, Allocator* alloc, int slot_num)
-      : kv_(kv), alloc_(alloc),
-        slot_num_(slot_num), shrink_count_(0) {}
+  ShrinkPolicy(Allocator* alloc): alloc_(alloc) {}
+  virtual ~ShrinkPolicy() {}

   TF_DISALLOW_COPY_AND_ASSIGN(ShrinkPolicy);

-  inline Status GetSnapshot() {
-    shrink_count_ = (shrink_count_ + 1) % slot_num_;
-    return kv_->GetSnapshot(&key_list_, &value_list_);
+  virtual void Shrink(const ShrinkArgs& shrink_args) = 0;
+
+ protected:
+  void EmplacePointer(ValuePtr<V>* value_ptr) {
+    to_delete_.emplace_back(value_ptr);
   }
-
-  void ReleaseDeleteValues() {
-    if (shrink_count_ == 0) {
-      for (auto it : to_delete_) {
-        it->Destroy(alloc_);
-        delete it;
-      }
-      to_delete_.clear();
+
+  void ReleaseValuePtrs() {
+    for (auto it : to_delete_) {
+      it->Destroy(alloc_);
+      delete it;
     }
+    to_delete_.clear();
   }
-
- protected:
-  std::vector<K> key_list_;
-  std::vector<ValuePtr<V>*> value_list_;
+ protected:
   std::vector<ValuePtr<V>*> to_delete_;
-
-  KVInterface<K, V>* kv_;
+ private:
   Allocator* alloc_;
-  int slot_num_;
-  int shrink_count_;
+};
+
+template<typename K, typename V>
+class NonShrinkPolicy: public ShrinkPolicy<K, V> {
+ public:
+  NonShrinkPolicy(): ShrinkPolicy<K, V>(nullptr) {}
+  TF_DISALLOW_COPY_AND_ASSIGN(NonShrinkPolicy);
+
+  void Shrink(const ShrinkArgs& shrink_args) {}
 };
 } // embedding
 } // tensorflow
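
After this change the base class only owns the deferred-free list: concrete policies fetch their own snapshot from the KV store they hold, decide what to evict, and hand doomed pointers to EmplacePointer, while ReleaseValuePtrs frees whatever the previous pass marked. A toy standalone sketch of that contract (ShrinkPolicyBase and StalePolicy below are simplified stand-ins; the real classes go through ValuePtr::Destroy and an Allocator rather than plain delete):

#include <cstdint>
#include <iostream>
#include <vector>

struct ShrinkArgs { int64_t global_step = 0; int64_t value_len = 0; };

template <typename V>
class ShrinkPolicyBase {
 public:
  virtual ~ShrinkPolicyBase() { ReleaseValuePtrs(); }
  virtual void Shrink(const ShrinkArgs& shrink_args) = 0;

 protected:
  void EmplacePointer(V* p) { to_delete_.push_back(p); }

  void ReleaseValuePtrs() {
    for (V* p : to_delete_) delete p;   // real code: it->Destroy(alloc_); delete it;
    to_delete_.clear();
  }

 private:
  std::vector<V*> to_delete_;
};

// Toy policy: marks every tracked value older than the current step.
class StalePolicy : public ShrinkPolicyBase<int64_t> {
 public:
  void Shrink(const ShrinkArgs& shrink_args) override {
    ReleaseValuePtrs();                 // free what the previous pass marked
    std::vector<int64_t> last_steps = {10, 90, shrink_args.global_step};
    int marked = 0;
    for (int64_t step : last_steps) {
      if (step < shrink_args.global_step) {
        EmplacePointer(new int64_t(step));
        ++marked;
      }
    }
    std::cout << "marked " << marked << " stale values for deferred release\n";
  }
};

int main() {
  StalePolicy policy;
  policy.Shrink(ShrinkArgs{/*global_step=*/100, /*value_len=*/0});
  policy.Shrink(ShrinkArgs{/*global_step=*/200, /*value_len=*/0});
  return 0;  // destructor releases the last pass's marks
}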
