Skip to content

Commit 4cd9ed8

Browse files
authored
[Embedding] Refactor the code of Save Op for EmbeddingVariable. (#900)
Signed-off-by: lixy9474 <[email protected]>
1 parent 2065fc0 commit 4cd9ed8

31 files changed

+1191
-1554
lines changed

Diff for: tensorflow/core/framework/embedding/config.proto

+10
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,13 @@ enum EmbeddingVariableType {
4646
IMMUTABLE = 0;
4747
MUTABLE = 1;
4848
}
49+
50+
enum ValuePtrStatus {
51+
OK = 0;
52+
IS_DELETED = 1;
53+
}
54+
55+
enum ValuePosition {
56+
IN_DRAM = 0;
57+
NOT_IN_DRAM = 1;
58+
}

Diff for: tensorflow/core/framework/embedding/dram_leveldb_storage.h

+45-41
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,6 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
111111
return false;
112112
}
113113

114-
void iterator_mutex_lock() override {
115-
leveldb_->get_mutex()->lock();
116-
}
117-
118-
void iterator_mutex_unlock() override {
119-
leveldb_->get_mutex()->unlock();
120-
}
121-
122114
int64 Size() const override {
123115
int64 total_size = dram_->Size();
124116
total_size += leveldb_->Size();
@@ -145,46 +137,58 @@ class DramLevelDBStore : public MultiTierStorage<K, V> {
145137
return -1;
146138
}
147139

148-
Status GetSnapshot(std::vector<K>* key_list,
149-
std::vector<ValuePtr<V>*>* value_ptr_list) override {
150-
{
151-
mutex_lock l(*(dram_->get_mutex()));
152-
TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list));
140+
Status Save(
141+
const string& tensor_name,
142+
const string& prefix,
143+
BundleWriter* writer,
144+
const EmbeddingConfig& emb_config,
145+
ShrinkArgs& shrink_args,
146+
int64 value_len,
147+
V* default_value) override {
148+
std::vector<K> key_list, tmp_leveldb_key_list;
149+
std::vector<ValuePtr<V>*> value_ptr_list, tmp_leveldb_value_list;
150+
TF_CHECK_OK(dram_->GetSnapshot(&key_list, &value_ptr_list));
151+
152+
TF_CHECK_OK(leveldb_->GetSnapshot(
153+
&tmp_leveldb_key_list, &tmp_leveldb_value_list));
154+
155+
for (int64 i = 0; i < tmp_leveldb_value_list.size(); i++) {
156+
tmp_leveldb_value_list[i]->SetPtr((V*)ValuePosition::NOT_IN_DRAM);
157+
tmp_leveldb_value_list[i]->SetInitialized(emb_config.primary_emb_index);
153158
}
154-
{
155-
mutex_lock l(*(leveldb_->get_mutex()));
156-
TF_CHECK_OK(leveldb_->GetSnapshot(key_list, value_ptr_list));
159+
160+
std::vector<K> leveldb_key_list;
161+
for (int64 i = 0; i < tmp_leveldb_key_list.size(); i++) {
162+
Status s = dram_->Contains(tmp_leveldb_key_list[i]);
163+
if (!s.ok()) {
164+
key_list.emplace_back(tmp_leveldb_key_list[i]);
165+
leveldb_key_list.emplace_back(tmp_leveldb_key_list[i]);
166+
value_ptr_list.emplace_back(tmp_leveldb_value_list[i]);
167+
}
157168
}
158-
return Status::OK();
159-
}
160169

161-
Status Shrink(const ShrinkArgs& shrink_args) override {
162-
dram_->Shrink(shrink_args);
163-
leveldb_->Shrink(shrink_args);
164-
return Status::OK();
165-
}
170+
ValueIterator<V>* value_iter =
171+
leveldb_->GetValueIterator(
172+
leveldb_key_list, emb_config.emb_index, value_len);
166173

167-
int64 GetSnapshot(std::vector<K>* key_list,
168-
std::vector<V* >* value_list,
169-
std::vector<int64>* version_list,
170-
std::vector<int64>* freq_list,
171-
const EmbeddingConfig& emb_config,
172-
FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
173-
embedding::Iterator** it) override {
174-
{
175-
mutex_lock l(*(dram_->get_mutex()));
176-
std::vector<ValuePtr<V>*> value_ptr_list;
177-
std::vector<K> key_list_tmp;
178-
TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list));
179-
MultiTierStorage<K, V>::SetListsForCheckpoint(
180-
key_list_tmp, value_ptr_list, emb_config,
181-
key_list, value_list, version_list, freq_list);
182-
}
183174
{
184175
mutex_lock l(*(leveldb_->get_mutex()));
185-
*it = leveldb_->GetIterator();
176+
TF_CHECK_OK((Storage<K, V>::SaveToCheckpoint(
177+
tensor_name, writer,
178+
emb_config,
179+
value_len, default_value,
180+
key_list,
181+
value_ptr_list,
182+
value_iter)));
186183
}
187-
return key_list->size();
184+
185+
for (auto it: tmp_leveldb_value_list) {
186+
delete it;
187+
}
188+
189+
delete value_iter;
190+
191+
return Status::OK();
188192
}
189193

190194
Status Eviction(K* evict_ids, int64 evict_size) override {

Diff for: tensorflow/core/framework/embedding/dram_pmem_storage.h

+32-50
Original file line numberDiff line numberDiff line change
@@ -150,59 +150,41 @@ class DramPmemStorage : public MultiTierStorage<K, V> {
150150
return -1;
151151
}
152152

153-
Status GetSnapshot(std::vector<K>* key_list,
154-
std::vector<ValuePtr<V>* >* value_ptr_list) override {
155-
{
156-
mutex_lock l(*(dram_->get_mutex()));
157-
TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list));
158-
}
159-
{
160-
mutex_lock l(*(pmem_->get_mutex()));
161-
TF_CHECK_OK(pmem_->GetSnapshot(key_list, value_ptr_list));
153+
Status Save(
154+
const string& tensor_name,
155+
const string& prefix,
156+
BundleWriter* writer,
157+
const EmbeddingConfig& emb_config,
158+
ShrinkArgs& shrink_args,
159+
int64 value_len,
160+
V* default_value) override {
161+
std::vector<K> key_list, tmp_pmem_key_list;
162+
std::vector<ValuePtr<V>*> value_ptr_list, tmp_pmem_value_list;
163+
164+
TF_CHECK_OK(dram_->GetSnapshot(&key_list, &value_ptr_list));
165+
dram_->Shrink(key_list, value_ptr_list, shrink_args, value_len);
166+
167+
TF_CHECK_OK(pmem_->GetSnapshot(&tmp_pmem_key_list,
168+
&tmp_pmem_value_list));
169+
pmem_->Shrink(tmp_pmem_key_list, tmp_pmem_value_list,
170+
shrink_args, value_len);
171+
172+
for (int64 i = 0; i < tmp_pmem_key_list.size(); i++) {
173+
Status s = dram_->Contains(tmp_pmem_key_list[i]);
174+
if (!s.ok()) {
175+
key_list.emplace_back(tmp_pmem_key_list[i]);
176+
value_ptr_list.emplace_back(tmp_pmem_value_list[i]);
177+
}
162178
}
163-
return Status::OK();
164-
}
165179

166-
Status Shrink(const ShrinkArgs& shrink_args) override {
167-
dram_->Shrink(shrink_args);
168-
pmem_->Shrink(shrink_args);
169-
return Status::OK();
170-
}
180+
TF_CHECK_OK((Storage<K, V>::SaveToCheckpoint(
181+
tensor_name, writer,
182+
emb_config,
183+
value_len, default_value,
184+
key_list,
185+
value_ptr_list)));
171186

172-
void iterator_mutex_lock() override {
173-
return;
174-
}
175-
176-
void iterator_mutex_unlock() override {
177-
return;
178-
}
179-
180-
int64 GetSnapshot(std::vector<K>* key_list,
181-
std::vector<V* >* value_list,
182-
std::vector<int64>* version_list,
183-
std::vector<int64>* freq_list,
184-
const EmbeddingConfig& emb_config,
185-
FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
186-
embedding::Iterator** it) override {
187-
{
188-
mutex_lock l(*(dram_->get_mutex()));
189-
std::vector<ValuePtr<V>*> value_ptr_list;
190-
std::vector<K> key_list_tmp;
191-
TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list));
192-
MultiTierStorage<K, V>::SetListsForCheckpoint(
193-
key_list_tmp, value_ptr_list, emb_config,
194-
key_list, value_list, version_list, freq_list);
195-
}
196-
{
197-
mutex_lock l(*(pmem_->get_mutex()));
198-
std::vector<ValuePtr<V>*> value_ptr_list;
199-
std::vector<K> key_list_tmp;
200-
TF_CHECK_OK(pmem_->GetSnapshot(&key_list_tmp, &value_ptr_list));
201-
MultiTierStorage<K, V>::SetListsForCheckpoint(
202-
key_list_tmp, value_ptr_list, emb_config,
203-
key_list, value_list, version_list, freq_list);
204-
}
205-
return key_list->size();
187+
return Status::OK();
206188
}
207189

208190
Status Eviction(K* evict_ids, int64 evict_size) override {

Diff for: tensorflow/core/framework/embedding/dram_ssd_storage.h

+13-62
Original file line numberDiff line numberDiff line change
@@ -144,70 +144,21 @@ class DramSsdHashStorage : public MultiTierStorage<K, V> {
144144
return true;
145145
}
146146

147-
Status GetSnapshot(std::vector<K>* key_list,
148-
std::vector<ValuePtr<V>*>* value_ptr_list) override {
149-
{
150-
mutex_lock l(*(dram_->get_mutex()));
151-
TF_CHECK_OK(dram_->GetSnapshot(key_list, value_ptr_list));
152-
}
153-
{
154-
mutex_lock l(*(ssd_hash_->get_mutex()));
155-
TF_CHECK_OK(ssd_hash_->GetSnapshot(key_list, value_ptr_list));
156-
}
157-
return Status::OK();
158-
}
159-
160-
Status Shrink(const ShrinkArgs& shrink_args) override {
161-
dram_->Shrink(shrink_args);
162-
ssd_hash_->Shrink(shrink_args);
163-
return Status::OK();
164-
}
165-
166-
int64 GetSnapshot(std::vector<K>* key_list,
167-
std::vector<V* >* value_list,
168-
std::vector<int64>* version_list,
169-
std::vector<int64>* freq_list,
147+
Status Save(
148+
const string& tensor_name,
149+
const string& prefix,
150+
BundleWriter* writer,
170151
const EmbeddingConfig& emb_config,
171-
FilterPolicy<K, V, EmbeddingVar<K, V>>* filter,
172-
embedding::Iterator** it) override {
173-
{
174-
mutex_lock l(*(dram_->get_mutex()));
175-
std::vector<ValuePtr<V>*> value_ptr_list;
176-
std::vector<K> key_list_tmp;
177-
TF_CHECK_OK(dram_->GetSnapshot(&key_list_tmp, &value_ptr_list));
178-
MultiTierStorage<K, V>::SetListsForCheckpoint(
179-
key_list_tmp, value_ptr_list, emb_config,
180-
key_list, value_list, version_list, freq_list);
181-
}
182-
{
183-
mutex_lock l(*(ssd_hash_->get_mutex()));
184-
*it = ssd_hash_->GetIterator();
185-
}
186-
return key_list->size();
187-
}
152+
ShrinkArgs& shrink_args,
153+
int64 value_len,
154+
V* default_value) override {
155+
dram_->Save(tensor_name, prefix, writer, emb_config,
156+
shrink_args, value_len, default_value);
188157

189-
int64 GetSnapshotWithoutFetchPersistentEmb(
190-
std::vector<K>* key_list,
191-
std::vector<V*>* value_list,
192-
std::vector<int64>* version_list,
193-
std::vector<int64>* freq_list,
194-
const EmbeddingConfig& emb_config,
195-
SsdRecordDescriptor<K>* ssd_rec_desc) override {
196-
{
197-
mutex_lock l(*(dram_->get_mutex()));
198-
std::vector<ValuePtr<V>*> value_ptr_list;
199-
std::vector<K> temp_key_list;
200-
TF_CHECK_OK(dram_->GetSnapshot(&temp_key_list, &value_ptr_list));
201-
MultiTierStorage<K, V>::SetListsForCheckpoint(
202-
temp_key_list, value_ptr_list, emb_config,
203-
key_list, value_list, version_list,
204-
freq_list);
205-
}
206-
{
207-
mutex_lock l(*(ssd_hash_->get_mutex()));
208-
ssd_hash_->SetSsdRecordDescriptor(ssd_rec_desc);
209-
}
210-
return key_list->size() + ssd_rec_desc->key_list.size();
158+
ssd_hash_->Save(tensor_name, prefix, writer, emb_config,
159+
shrink_args, value_len, default_value);
160+
161+
return Status::OK();
211162
}
212163

213164
Status RestoreSSD(int64 emb_index, int64 emb_slot_num, int64 value_len,

Diff for: tensorflow/core/framework/embedding/embedding_config.h

+10
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,16 @@ struct EmbeddingConfig {
101101
return emb_index == primary_emb_index;
102102
}
103103

104+
bool is_save_freq() const {
105+
return filter_freq != 0 ||
106+
record_freq ||
107+
normal_fix_flag == 1;
108+
}
109+
110+
bool is_save_version() const {
111+
return steps_to_live != 0 || record_version;
112+
}
113+
104114
int64 total_num(int alloc_len) {
105115
return block_num *
106116
(1 + (1 - normal_fix_flag) * slot_num) *

Diff for: tensorflow/core/framework/embedding/embedding_var.h

+8-24
Original file line numberDiff line numberDiff line change
@@ -582,30 +582,14 @@ class EmbeddingVar : public ResourceBase {
582582
emb_config_, device, reader, this, filter_);
583583
}
584584

585-
int64 GetSnapshot(std::vector<K>* key_list,
586-
std::vector<V* >* value_list,
587-
std::vector<int64>* version_list,
588-
std::vector<int64>* freq_list,
589-
embedding::Iterator** it = nullptr) {
590-
// for Interface Compatible
591-
// TODO Multi-tiered Embedding should use iterator in 'GetSnapshot' caller
592-
embedding::Iterator* _it = nullptr;
593-
it = (it == nullptr) ? &_it : it;
594-
return storage_->GetSnapshot(
595-
key_list, value_list, version_list,
596-
freq_list, emb_config_, filter_, it);
597-
}
598-
599-
int64 GetSnapshotWithoutFetchPersistentEmb(
600-
std::vector<K>* key_list,
601-
std::vector<V*>* value_list,
602-
std::vector<int64>* version_list,
603-
std::vector<int64>* freq_list,
604-
SsdRecordDescriptor<K>* ssd_rec_desc) {
605-
return storage_->
606-
GetSnapshotWithoutFetchPersistentEmb(
607-
key_list, value_list, version_list,
608-
freq_list, emb_config_, ssd_rec_desc);
585+
Status Save(const string& tensor_name,
586+
const string& prefix,
587+
BundleWriter* writer,
588+
embedding::ShrinkArgs& shrink_args) {
589+
return storage_->Save(tensor_name, prefix,
590+
writer, emb_config_,
591+
shrink_args, value_len_,
592+
default_value_);
609593
}
610594

611595
mutex* mu() {

0 commit comments

Comments
 (0)