Add feature evict for dram_kv_embedding_cache. #4187
Conversation
@TroyGarden has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
for different types of eviction policy, do we share the same metadata header structure?
one more question about the eviction policy: can I configure a max feature count or max memory threshold while enabling l2 norm based eviction?
// Pause the eviction process; returns true if a task is in progress, false otherwise.
// During the pause phase, check whether eviction has completed.
bool pause() {
  std::lock_guard<std::mutex> lock(mutex_);
is this lock required?
We want feature eviction to be thread-safe regardless of how it is called. At present, not all interfaces are guaranteed to be safe: for example, trigger may be called by one thread while pause is called by another, so locking is required.
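The argument above can be made concrete with a small sketch: trigger and pause serialize on the same mutex, so concurrent calls from different threads observe a consistent state. Names here (`EvictCoordinator`, `running_`) are illustrative, not the PR's actual code.

```cpp
#include <mutex>

class EvictCoordinator {
 public:
  // Start an eviction round; returns false if one is already in progress.
  bool trigger() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (running_) {
      return false;
    }
    running_ = true;
    return true;
  }

  // Pause eviction; returns true if a task was in progress, false otherwise,
  // matching the comment on pause() in the snippet above.
  bool pause() {
    std::lock_guard<std::mutex> lock(mutex_);
    bool was_running = running_;
    running_ = false;
    return was_running;
  }

 private:
  std::mutex mutex_;
  bool running_ = false;
};
```

Without the lock, a trigger/pause race could leave the coordinator believing a task is running after it was paused.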
decay_rate_(decay_rate),
threshold_(threshold) {}
so the decay rate and threshold still need to be manually set?
Yeah, there's no good way to set this automatically.
CounterBasedEvict(folly::CPUThreadPoolExecutor* executor,
                  SynchronizedShardedMap<int64_t, float*>& kv_store,
                  float decay_rate,
                  int threshold)
can we have different threshold for different feature offset?
OK
@@ -67,7 +87,11 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
  }

  void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-    return impl_->set(indices, weights, count);
+    impl_->feature_evict_pause();
+    impl_->set(indices, weights, count);
to improve performance, we'll enable async set for dram backend similar to ssd here:
FBGEMM/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp, line 141 at c512951:
if (enable_async_update_) {
so we cannot directly start eviction right after set. If we want to do this, we should call wait_util_filling_work_done inside the eviction.
OK
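The "wait until filling work is done" idea above can be sketched as a drain barrier: eviction blocks until every in-flight async set() has finished. This tracker is hypothetical; the real wait_util_filling_work_done in the codebase may be implemented differently.

```cpp
#include <condition_variable>
#include <mutex>

class FillTracker {
 public:
  // Called when an async set() is enqueued.
  void filling_started() {
    std::lock_guard<std::mutex> lock(mutex_);
    ++pending_;
  }

  // Called when an async set() completes; wakes waiters once drained.
  void filling_finished() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (--pending_ == 0) {
      cv_.notify_all();
    }
  }

  // Block until every started fill has finished; called at eviction start.
  void wait_until_filling_work_done() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return pending_ == 0; });
  }

  int pending() {
    std::lock_guard<std::mutex> lock(mutex_);
    return pending_;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  int pending_ = 0;
};
```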
no matter which eviction trigger is configured, are we able to start eviction when memory reaches a threshold?
int evict_trigger_mode,
int evict_trigger_strategy,
int64_t trigger_step_interval,
uint32_t ttl,
there is already a weight_ttl_in_hours, shall we drop that?
- Yes, we can continue to improve the trigger timing and confirm the all-reduce method for the size.
- OK
@@ -26,15 +26,34 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
    int64_t max_D,
    double uniform_init_lower,
    double uniform_init_upper,
    int evict_trigger_mode,
the torch pybind interface does not allow int or int32 type, let's change to int64_t?
OK
with this PR, the dram backend build failed; wondering if you can run the build and unit tests based on the latest branch?
Some third-party packages cannot work properly in our environment, so we did not compile.
It sounds more reasonable. We can change it.
Force-pushed from 7e1f44e to 264e35d
Force-pushed from 264e35d to 64d2558
Feature/feature evict supplement
feature evict metric fix duration
CI signals are also failing cc @ionuthristodorescu
let me try locally and see if that's the operator integration issue
@emlin has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Looks like the rebase conflicts have to be addressed before I can import the change. @ArronHZG
const std::optional<at::Tensor>& ttls_in_hour = std::nullopt,
const std::optional<at::Tensor>& count_decay_rates = std::nullopt,
const std::optional<at::Tensor>& l2_weight_thresholds = std::nullopt,
const std::optional<at::Tensor>& embedding_dims = std::nullopt,
what's the difference between this and table_dims @115?
std::copy(weights[id_index]
              .template data_ptr<weight_type>(),
          weights[id_index]
              .template data_ptr<weight_type>() +
              weights[id_index].numel(),
from here, it seems the weight id is not copied to the block; where will you copy the weight id to the block?
static uint64_t get_key(const void* block) {
  return reinterpret_cast<const MetaHeader*>(block)->key;
}
static void set_key(void* block, uint64_t key) {
seems this method is not called?
int64_t key;          // feature key, 8 bytes
uint32_t timestamp;   // 4 bytes; the unit is seconds, so uint32 covers a
                      // range of over 120 years
uint32_t count : 31;  // only 31 bits are used, max value is 2147483647
wondering whether we can make the key and count optional. On the inference side, memory is more restricted, and key/count won't be used anyway.
this can be done in the next step
      : 0.0f;
}
DLOG(INFO) << fmt::format(
    "Shard {} completed: \n"
this is just one execution finishing, not the whole shard finishing, right?
since we'll call the process frequently (on every batch), logging every time might be too much; something like DLOG_EVERY_N(INFO, 1000) can sample the logs
futures_.emplace_back(folly::via(executor_).thenValue(
    [this, shard_id](auto&&) { process_shard(shard_id); }));
wondering if there would be too many submits if we evict on every batch?
In the L2 eviction method, count and timestamp are expected to take up no space, but the struct contains the key and used information, which is at least 16 bytes under 8-byte alignment, so the space cannot be compressed any further.
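The 16-byte observation above can be checked at compile time. This is a sketch of the header using the field list from the snippet earlier in this thread; the real FBGEMM struct may differ.

```cpp
#include <cstdint>

struct MetaHeader {
  int64_t key;          // feature key, 8 bytes
  uint32_t timestamp;   // seconds, 4 bytes
  uint32_t count : 31;  // occurrence count, max 2^31 - 1
  uint32_t used : 1;    // memory-block occupancy flag
};

// Even if count/timestamp were unused (as in the L2-norm strategy), the
// 8-byte key forces 8-byte alignment, so the header cannot shrink below
// 16 bytes once the used flag is present.
static_assert(sizeof(MetaHeader) == 16, "header is 16 bytes");
static_assert(alignof(MetaHeader) == 8, "aligned to the 8-byte key");
```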
one more question: how do I get all cached values for one id, including the key, meta header, and weight + optimizer state?
also one question about the alignment.
Feature Eviction
1. Background and Design Philosophy of Feature Eviction
Background
In sparse training scenarios, sample features exhibit significant frequency disparities. Low-frequency features not only contribute minimally to model training but also lead to memory inefficiencies and overfitting. Thus, a feature eviction strategy is necessary to prune unhelpful features and optimize memory usage.
Design Goals
Develop a generic feature eviction mechanism for `dram_kv_embedding_cache` that supports multiple eviction strategies, enabling efficient utilization of memory resources.
Design Approach
2. Fundamental Features of Feature Eviction
2.1 Basic Eviction Strategies
Strategy Descriptions
Count-Based Eviction
Evict features based on their occurrence frequency (Count) in training data. Low-frequency features are deemed "useless". Supports decay mechanisms for handling time-varying data distributions (e.g., periodic business patterns).
- Increment (`count += 1`) on each feature update (capped at `2^31-1` to prevent overflow).
- Evict when `count < threshold`.
Timestamp-Based Eviction
Track the last update timestamp (in seconds) of each feature and evict those inactive for extended periods.
- Update (`timestamp = current_step` or system time) on each feature update.
- Evict when `current_step - timestamp > ttl`.
Combined Count+Timestamp Eviction (Hybrid Strategy)
Retain features that are both "frequently occurring" and "recently active" by evaluating:
- Evict when `(count < min_count) AND (current_step - timestamp > ttl)`.
- Thresholds are determined via feature distribution analysis (e.g., count-timestamp 2D histograms).
L2 Norm-Based Eviction (Contribution Strategy)
Measure feature contribution via the L2 norm of their embeddings; features with smaller norms (e.g., near-zero embeddings) are considered less impactful.
- Compute the norm (`norm = sqrt(sum(emb^2))`) during eviction.
- Evict when `norm < min_norm`.
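The four strategy predicates above can be sketched as plain functions; the threshold parameter names are illustrative, not the PR's API.

```cpp
#include <cmath>
#include <cstdint>

// Count-based: evict low-frequency features.
bool evict_by_count(uint32_t count, uint32_t threshold) {
  return count < threshold;
}

// Timestamp-based: evict features inactive for longer than ttl.
bool evict_by_ttl(uint32_t current_step, uint32_t timestamp, uint32_t ttl) {
  return current_step - timestamp > ttl;
}

// Hybrid: both low-frequency AND stale.
bool evict_by_count_and_ttl(uint32_t count, uint32_t min_count,
                            uint32_t current_step, uint32_t timestamp,
                            uint32_t ttl) {
  return count < min_count && (current_step - timestamp > ttl);
}

// L2-norm-based: evict near-zero embeddings (norm = sqrt(sum(emb^2))).
bool evict_by_l2_norm(const float* emb, int64_t dim, float min_norm) {
  float sum = 0.0f;
  for (int64_t i = 0; i < dim; ++i) {
    sum += emb[i] * emb[i];
  }
  return std::sqrt(sum) < min_norm;
}
```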
Structural Design
Implement strategies using polymorphism:
- Base class (`FeatureEvict`): Manages asynchronous task scheduling, state tracking, and sharded processing.
- Derived classes (`CounterBasedEvict` / `TimeBasedEvict`): Implement specific eviction logic (e.g., count decay, timestamp comparison).
2.2 Eviction Trigger Mechanisms
- `trigger_evict()`
2.3 Metadata Storage for Feature Eviction
Metadata Structure
`MetaHeader` (16 bytes, memory-aligned): Uses bit fields to compress the `counter` and `used` fields, supporting all three eviction strategies.
Storage Layout
- `key`: Feature ID (8 bytes).
- `timestamp`: Last update timestamp (4 bytes, range 0–4,294,967,295).
- `count`: Feature occurrence count (31 bits, max `2^31-1`).
- `used`: Memory block occupancy flag (1 bit).
3. Overlap Implementation of Feature Eviction and Training
Core Interfaces
The `FeatureEvict` class provides these key interfaces for `DramKVEmbeddingCache` to overlap eviction with training:
- `trigger_evict()`: Initiate an asynchronous eviction task.
- `pause()`: Suspend ongoing eviction.
- `resume()`: Resume paused eviction.
- `is_evicting()`: Query eviction task status.
Usage Pattern
Surround hashtable updates/queries with `pause()` and `resume()` to temporarily halt eviction during critical operations, resuming during data loading and network forward/backward passes.
Asynchronous Optimization Techniques
- Save shard-level progress in `block_cursors`, and use atomic variables for rapid pause/resume control.
- Use write locks (`wlock`) to ensure thread safety between eviction and training operations.
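The usage pattern can be sketched with a stand-in cache type that records the call order; the real class is the `DramKVEmbeddingCacheWrapper` shown earlier in this thread, and the method names follow its `feature_evict_pause()` snippet (the resume call is assumed symmetric).

```cpp
#include <string>
#include <vector>

// Stand-in for the cache; records calls so the pattern is visible.
struct RecordingCache {
  std::vector<std::string> calls;
  void feature_evict_pause() { calls.push_back("pause"); }
  void set() { calls.push_back("set"); }
  void feature_evict_resume() { calls.push_back("resume"); }
};

// Usage pattern: halt eviction around the hashtable update, then resume
// so eviction overlaps with data loading and forward/backward passes.
void guarded_set(RecordingCache& cache) {
  cache.feature_evict_pause();
  cache.set();
  cache.feature_evict_resume();
}
```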