Skip to content

Commit e919596

Browse files
committed
[EmbeddingVar] Add KvResourceCleanUpOp.
Signed-off-by: chenbangduo.cbd <[email protected]>
1 parent cf16856 commit e919596

File tree

10 files changed

+141
-4
lines changed

10 files changed

+141
-4
lines changed

tensorflow/core/framework/embedding/embedding_var.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,12 @@ class EmbeddingVar : public ResourceBase {
733733
return filter_;
734734
}
735735

736+
void CleanUp() {
737+
if (emb_config_.is_primary() && emb_config_.primary_emb_index == 0) {
738+
storage_->CleanUp();
739+
}
740+
}
741+
736742
protected:
737743
~EmbeddingVar() override {
738744
// When dynamic dimension embedding is used,

tensorflow/core/framework/embedding/feature_descriptor_impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class NonFreqDescriptor: public BaseFreqDescriptor {
7474
public:
7575
// This descriptor cannot report a frequency, so any call is reported at
// FATAL severity.
int64 GetFreq(void* value_ptr) override {
  LOG(FATAL) << "Can not get freq from NonFreqCounter.";
  // Gives the non-void function an explicit return value on every path.
  return 0;
}
7879

7980
BaseFreqDescriptor* Clone() override {

tensorflow/core/framework/embedding/multi_tier_storage.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@ class MultiTierStorage : public Storage<K, V> {
248248
FeatureDescriptor<V>* hbm_feat_desc,
249249
FeatureDescriptor<V>* dram_feat_desc);
250250
#endif  // GOOGLE_CUDA
251+
252+
void CleanUp() {
253+
LOG(FATAL) << "Function [CleanUp] of MultiTierStorage is not implemented.";
254+
}
255+
251256
private:
252257
// Default eviction hook: does nothing and reports success.
//
// evict_ids:  ids selected for eviction (unused by this default).
// evict_size: number of entries in evict_ids (unused by this default).
virtual Status EvictionWithDelayedDestroy(K* evict_ids, int64 evict_size) {
  // A value must be returned explicitly: flowing off the end of a function
  // that returns Status is undefined behaviour.
  return Status::OK();
}
253258

tensorflow/core/framework/embedding/single_tier_storage.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,18 @@ class SingleTierStorage : public Storage<K, V> {
307307
false/*to_dram*/, is_incr, restore_buff);
308308
return s;
309309
}
310+
311+
void CleanUp() override {
312+
std::vector<K> key_list;
313+
std::vector<void*> value_ptr_list;
314+
kv_->GetSnapshot(&key_list, &value_ptr_list);
315+
316+
int list_size = key_list.size();
317+
for (int i = 0; i < list_size; i++) {
318+
kv_->Remove(key_list[i]);
319+
feat_desc_->Deallocate(value_ptr_list[i]);
320+
}
321+
}
310322

311323
protected:
312324
virtual void Shrink(std::vector<K>& key_list,
@@ -453,6 +465,10 @@ class HbmStorage : public SingleTierStorage<K, V> {
453465
GPUHashTable<K, V>* HashTable() override {
454466
return SingleTierStorage<K, V>::kv_->HashTable();
455467
}
468+
469+
// CleanUp is not implemented for HbmStorage; the override only logs at FATAL
// severity and does not touch any stored entries.
void CleanUp() override {
  LOG(FATAL)
      << "Function [CleanUp] of HbmStorage is not implemented.";
}
456472
protected:
457473
Status RestoreFeatures(int64 key_num, int bucket_num, int64 partition_id,
458474
int64 partition_num, int64 value_len, bool is_filter,
@@ -495,6 +511,11 @@ class HbmStorageWithCpuKv: public SingleTierStorage<K, V> {
495511
Status TryInsert(K key, void* value_ptr) {
496512
return SingleTierStorage<K, V>::kv_->Insert(key, value_ptr);
497513
}
514+
515+
// CleanUp is not implemented for HbmStorageWithCpuKv; the override only logs
// at FATAL severity. The message spells the class name exactly as declared
// so the log line can be matched to the source.
void CleanUp() override {
  LOG(FATAL)
      << "Function [CleanUp] of HbmStorageWithCpuKv is not implemented.";
}
518+
498519
public:
499520
friend class HbmDramStorage<K, V>;
500521
friend class HbmDramSsdStorage<K, V>;
@@ -521,6 +542,10 @@ class PmemMemkindStorage : public SingleTierStorage<K, V> {
521542
}
522543
~PmemMemkindStorage() override {}
523544

545+
// CleanUp is not implemented for PmemMemkindStorage; the override only logs
// at FATAL severity and does not touch any stored entries.
void CleanUp() override {
  LOG(FATAL)
      << "Function [CleanUp] of PmemMemkindStorage is not implemented.";
}
548+
524549
TF_DISALLOW_COPY_AND_ASSIGN(PmemMemkindStorage);
525550
};
526551

@@ -537,6 +562,10 @@ class PmemLibpmemStorage : public SingleTierStorage<K, V> {
537562
return SingleTierStorage<K, V>::kv_->Commit(keys, value_ptr);
538563
}
539564

565+
// CleanUp is not implemented for PmemLibpmemStorage; the override only logs
// at FATAL severity and does not touch any stored entries.
void CleanUp() override {
  LOG(FATAL)
      << "Function [CleanUp] of PmemLibpmemStorage is not implemented.";
}
568+
540569
TF_DISALLOW_COPY_AND_ASSIGN(PmemLibpmemStorage);
541570

542571
protected:
@@ -577,6 +606,11 @@ class LevelDBStore : public SingleTierStorage<K, V> {
577606
key_list, emb_index, value_len,
578607
leveldb_kv, SingleTierStorage<K, V>::feat_desc_);
579608
}
609+
610+
// CleanUp is not implemented for LevelDBStore; the override only logs at
// FATAL severity. The message spells the class name exactly as declared so
// the log line can be matched to the source.
void CleanUp() override {
  LOG(FATAL)
      << "Function [CleanUp] of LevelDBStore is not implemented.";
}
613+
580614
public:
581615
friend class DramLevelDBStore<K, V>;
582616
};
@@ -646,6 +680,10 @@ class SsdHashStorage : public SingleTierStorage<K, V> {
646680
reinterpret_cast<SSDHashKV<K, V>*>(SingleTierStorage<K, V>::kv_);
647681
ssd_kv->SetSsdRecordDescriptor(ssd_rec_desc);
648682
}
683+
684+
// CleanUp is not implemented for SsdHashStorage; the override only logs at
// FATAL severity and does not touch any stored entries.
void CleanUp() override {
  LOG(FATAL)
      << "Function [CleanUp] of SsdHashStorage is not implemented.";
}
649687
public:
650688
friend class DramSsdHashStorage<K, V>;
651689
#if GOOGLE_CUDA

tensorflow/core/framework/embedding/storage.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ class Storage {
210210
return Status::OK();
211211
}
212212

213+
virtual void CleanUp() = 0;
214+
213215
protected:
214216
virtual Status RestoreSSD(int64 emb_index, int64 emb_slot_num,
215217
int64 value_len,

tensorflow/core/framework/variable.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ message VariableDef {
7676
bool is_embedding_var = 91;
7777

7878
string initialize_op_for_restore = 92;
79+
80+
string clean_up_op_name = 93;
7981
}
8082

8183
message SaveSliceInfoDef {

tensorflow/core/kernels/kv_variable_ops.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,5 +557,56 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS_GPU)
557557

558558
#undef REGISTER_KERNELS_ALL
559559
#undef REGISTER_KERNELS
560+
561+
template <typename TKey, typename TValue>
562+
class KvResourceCleanUpOp : public OpKernel {
563+
public:
564+
explicit KvResourceCleanUpOp(OpKernelConstruction* ctx)
565+
:OpKernel(ctx) {
566+
TF_CHECK_OK(ReadBoolFromEnvVar("ENABLE_EV_CLEAN_UP", false, &enable_));
567+
}
568+
569+
void Compute(OpKernelContext* ctx) {
570+
if (!enable_) {
571+
return;
572+
}
573+
574+
EmbeddingVar<TKey, TValue>* ev = nullptr;
575+
Status s = LookupResource(ctx, HandleFromInput(ctx, 0), &ev);
576+
577+
if (s.ok()) {
578+
ev->Unref();
579+
ev->CleanUp();
580+
}
581+
}
582+
583+
private:
584+
bool enable_;
585+
};
586+
587+
// Registers KvResourceCleanUpOp for one (device, key type, value type)
// combination.
#define REGISTER_KERNELS(dev, ktype, vtype)                    \
  REGISTER_KERNEL_BUILDER(Name("KvResourceCleanUp")            \
                              .Device(DEVICE_##dev)            \
                              .TypeConstraint<ktype>("Tkeys")  \
                              .TypeConstraint<vtype>("dtype"), \
                          KvResourceCleanUpOp<ktype, vtype>);

// Expands the registration above for both supported key types.
#define REGISTER_KERNELS_ALL(dev, type) \
  REGISTER_KERNELS(dev, int32, type)    \
  REGISTER_KERNELS(dev, int64, type)

// CPU kernels for every type produced by TF_CALL_FLOAT_TYPES.
#define REGISTER_KERNELS_CPU(type) REGISTER_KERNELS_ALL(CPU, type)
TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_CPU)
#undef REGISTER_KERNELS_CPU

// GPU kernels are registered only in CUDA builds.
#if GOOGLE_CUDA
#define REGISTER_KERNELS_GPU(type) REGISTER_KERNELS_ALL(GPU, type)
TF_CALL_FLOAT_TYPES(REGISTER_KERNELS_GPU)
#undef REGISTER_KERNELS_GPU
#endif  // End of macro GOOGLE_CUDA

#undef REGISTER_KERNELS_ALL
#undef REGISTER_KERNELS
610+
560611
} // namespace tensorflow
561612

tensorflow/core/kernels/kv_variable_restore_ops.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,8 @@ class KvResourceImportV3Op: public AsyncOpKernel {
376376

377377
// EV should not be initialized at this time.
378378
if (ev->IsInitialized()) {
379-
LOG(ERROR) << "Import parameter for EV (" << name_string
380-
<< ") failed, this EV has already been initialized.";
379+
LOG(WARNING) << "EV (" << name_string
380+
<< ") has already been initialized.";
381381
}
382382

383383
auto do_compute = [this, context, file_name_string, ev,

tensorflow/core/ops/kv_variable_ops.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,4 +893,13 @@ REGISTER_OP("KvResourceLookupResource")
893893
})
894894
.Doc(R"doc()doc");
895895

896+
// Op definition for KvResourceCleanUp: takes a resource handle and declares
// no outputs, so the shape function has nothing to infer.
REGISTER_OP("KvResourceCleanUp")
    .Input("resource_handle: resource")
    .Attr("Tkeys: {int64, int32}")
    .Attr("dtype: type = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) { return Status::OK(); })
    .Doc(R"doc()doc");
904+
896905
} // namespace tensorflow

tensorflow/python/ops/kv_variable_ops.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,17 @@ def _init_from_args(self,
382382
dtype=self._dtype))
383383
if initial_value is not None:
384384
with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
385-
with ops.control_dependencies(None if self._is_primary else [self._primary.initializer]):
385+
if self._is_primary:
386+
self._clean_up_op = \
387+
gen_kv_variable_ops.kv_resource_clean_up(self._handle, \
388+
Tkeys=self._invalid_key_type, dtype=self._dtype)
389+
else:
390+
self._clean_up_op = self._primary.clean_up_op
391+
392+
control_dep = [self._clean_up_op]
393+
if not self._is_primary:
394+
control_dep.append(self._primary.initializer)
395+
with ops.control_dependencies(control_dep):
386396
self._init_op = gen_kv_variable_ops.initialize_kv_variable_v2_op(
387397
self._handle,
388398
self._primary._handle,
@@ -450,7 +460,10 @@ def export(self):
450460

451461

452462
def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):
453-
with ops.control_dependencies(None if self._is_primary else [self._primary._init_op_for_restore]):
463+
control_dep = [self._clean_up_op]
464+
if not self._is_primary:
465+
control_dep.append(self._primary._init_op_for_restore)
466+
with ops.control_dependencies(control_dep):
454467
self._initializer_for_restore = gen_kv_variable_ops.initialize_kv_variable_v2_op(
455468
self._handle,
456469
self._primary._handle,
@@ -494,6 +507,11 @@ def create_init_op_for_restore(self, name, initial_value, invalid_key, rank):
494507

495508
def need_counts(self):
496509
return (self._record_freq or (self._filter_freq > 0) or self._is_multi_tier)
510+
511+
@property
512+
def clean_up_op(self):
513+
return self._clean_up_op
514+
497515
@property
498516
def gather_op(self):
499517
return self._gather_op
@@ -585,6 +603,9 @@ def _init_from_proto(self, variable_def, import_scope=None):
585603
self._primary_handle = g.as_graph_element(
586604
ops.prepend_name_scope(
587605
primary_name, import_scope=import_scope))
606+
607+
self._clean_up_op = g.as_graph_element(ops.prepend_name_scope(
608+
variable_def.clean_up_op_name, import_scope=import_scope))
588609
self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
589610
self._invalid_key = -1
590611
self._steps_to_live = init_op.get_attr("steps_to_live")
@@ -913,6 +934,8 @@ def to_proto(self, export_scope=None):
913934
self._save_slice_info.to_proto(export_scope=export_scope))
914935
var_def.initialize_op_for_restore = ops.strip_name_scope(
915936
self._init_op_for_restore.name, export_scope)
937+
var_def.clean_up_op_name = \
938+
ops.strip_name_scope(self._clean_up_op.name, export_scope)
916939
return var_def
917940
else:
918941
return None

0 commit comments

Comments
 (0)