Add feature evict for dram_kv_embedding_cache. #4187


Open · ArronHZG wants to merge 38 commits into main from feature/feature_evict

Conversation

@ArronHZG (Contributor) commented May 26, 2025

Feature Eviction

1. Background and Design Philosophy of Feature Eviction

Background
In sparse training scenarios, sample features exhibit significant frequency disparities. Low-frequency features not only contribute minimally to model training but also lead to memory inefficiencies and overfitting. Thus, a feature eviction strategy is necessary to prune unhelpful features and optimize memory usage.

Design Goals
Develop a generic feature eviction mechanism for dram_kv_embedding_cache that supports multiple eviction strategies, enabling efficient utilization of memory resources.

Design Approach

  • Strategy Selection: Implement mainstream eviction strategies such as count-based (LFU) and timestamp-based (LRU) using polymorphism for extensibility.
  • Trigger Mechanisms: Support manual triggers, fixed-step intervals, and memory threshold-based triggers for flexible configuration.
  • Metadata Storage: Integrate feature statistics (timestamps, counters, usage status) into memory block metadata to avoid additional memory overhead.
  • Execution Process: Overlap training and eviction asynchronously by pausing/resuming eviction tasks, minimizing impact on training efficiency.

2. Fundamental Features of Feature Eviction

2.1 Basic Eviction Strategies

Strategy Descriptions

  • Count-Based Eviction
    Evict features based on their occurrence frequency (Count) in training data. Low-frequency features are deemed "useless". Supports decay mechanisms for handling time-varying data distributions (e.g., periodic business patterns).

    • Increment the count (count += 1) on each feature update (capped at 2^31-1 to prevent overflow).
    • During eviction, apply decay and evict features where count < threshold (see the sketch after this list).
  • Timestamp-Based Eviction
    Track the last update timestamp (in seconds) of each feature and evict those inactive for extended periods.

    • Update timestamp (timestamp = current_step or system time) on each feature update.
    • Evict features where current_step - timestamp > ttl.
  • Combined Count+Timestamp Eviction (Hybrid Strategy)
    Retain features that are both "frequently occurring" and "recently active" by evaluating:
    (count < min_count) AND (current_step - timestamp > ttl).
    Thresholds are determined via feature distribution analysis (e.g., count-timestamp 2D histograms).

  • L2 Norm-Based Eviction (Contribution Strategy)
    Measure feature contribution via the L2 norm of their embeddings; features with smaller norms (e.g., near-zero embeddings) are considered less impactful.

    • Compute L2 norm (norm = sqrt(sum(emb^2))) during eviction.
    • Evict features where norm < min_norm.
    • Computational complexity: O(d) per feature (d = embedding dimension).
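
The four strategies above boil down to simple per-feature predicates. A minimal C++ sketch, assuming standalone fields rather than the PR's actual bit-field metadata (function and parameter names are illustrative):

```cpp
#include <cmath>
#include <cstdint>

// Count-based: apply decay at eviction time, then evict low-frequency features.
bool count_based_should_evict(uint32_t& count, float decay_rate, uint32_t threshold) {
  count = static_cast<uint32_t>(count * decay_rate);
  return count < threshold;
}

// Timestamp-based: evict features inactive for longer than the TTL.
bool time_based_should_evict(uint32_t timestamp, uint32_t current_step, uint32_t ttl) {
  return current_step - timestamp > ttl;
}

// Hybrid: evict only features that are both rare and stale.
bool hybrid_should_evict(uint32_t count, uint32_t min_count,
                         uint32_t timestamp, uint32_t current_step, uint32_t ttl) {
  return (count < min_count) && (current_step - timestamp > ttl);
}

// L2-norm-based: O(d) per feature, where d is the embedding dimension.
bool l2_norm_should_evict(const float* emb, int d, float min_norm) {
  float sum_sq = 0.0f;
  for (int i = 0; i < d; ++i) {
    sum_sq += emb[i] * emb[i];
  }
  return std::sqrt(sum_sq) < min_norm;
}
```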

Structural Design
Implement strategies using polymorphism:

  • Base Class (FeatureEvict): Manages asynchronous task scheduling, state tracking, and sharded processing.
  • Subclasses (CounterBasedEvict/TimeBasedEvict): Implement specific eviction logic (e.g., count decay, timestamp comparison).
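
A rough sketch of that hierarchy (the public method names follow the interfaces listed in section 3; everything else is an assumption about the shape of the code, not the exact PR interface):

```cpp
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <vector>

class FeatureEvict {
 public:
  virtual ~FeatureEvict() = default;
  void trigger_evict();     // launch an asynchronous eviction pass over all shards
  bool pause();             // suspend an in-flight eviction; true if one was running
  void resume();            // resume a previously paused eviction
  bool is_evicting() const; // query eviction task status

 protected:
  // Each subclass supplies the strategy-specific eviction decision.
  virtual bool should_evict(const void* block) = 0;

  std::atomic<bool> paused_{false};        // rapid pause/resume control
  std::vector<std::size_t> block_cursors_; // per-shard progress for resuming
};

class CounterBasedEvict : public FeatureEvict {
 protected:
  bool should_evict(const void* block) override; // decay count, compare to threshold
 private:
  float decay_rate_;
  int threshold_;
};

class TimeBasedEvict : public FeatureEvict {
 protected:
  bool should_evict(const void* block) override; // compare timestamp against ttl
 private:
  uint32_t ttl_;
};
```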

2.2 Eviction Trigger Mechanisms

Trigger Types

  • Manual Trigger: explicit call via trigger_evict().
  • Fixed Step Trigger: periodic invocation within the training loop.
  • Memory Threshold: automatic trigger based on memory usage monitoring.
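
How the three modes could be wired into a training loop, reusing the FeatureEvict sketch above (evict_trigger_mode and trigger_step_interval mirror parameters visible in this PR's diff; the numeric mode values and mem_util_threshold are assumed for illustration):

```cpp
#include <cstdint>

void maybe_trigger_evict(FeatureEvict& evictor,
                         int64_t step,
                         int64_t evict_trigger_mode,
                         int64_t trigger_step_interval,
                         double mem_util,
                         double mem_util_threshold) {
  switch (evict_trigger_mode) {
    case 1: // manual: the caller invokes evictor.trigger_evict() explicitly
      break;
    case 2: // fixed step interval
      if (step % trigger_step_interval == 0) {
        evictor.trigger_evict();
      }
      break;
    case 3: // memory threshold
      if (mem_util > mem_util_threshold) {
        evictor.trigger_evict();
      }
      break;
  }
}
```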

2.3 Metadata Storage for Feature Eviction

Metadata Structure

  • MetaHeader (16 bytes, memory-aligned): Uses bit fields to compress counter and used fields, supporting all three eviction strategies.

[figure: MetaHeader layout]

Storage Layout

  • key: Feature ID (8 bytes).
  • timestamp: Last update timestamp (4 bytes, range 0–4,294,967,295).
  • count: Feature occurrence count (31 bits, max 2^31-1).
  • used: Memory block occupancy flag (1 bit).
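
Putting the layout together as a struct (consistent with the diff excerpts quoted later in this thread; the exact field ordering and padding are assumptions):

```cpp
#include <cstdint>

struct MetaHeader {
  int64_t key;          // feature ID, 8 bytes
  uint32_t timestamp;   // last update time in seconds, 4 bytes
  uint32_t count : 31;  // occurrence count, max 2^31 - 1
  uint32_t used : 1;    // memory block occupancy flag
};
static_assert(sizeof(MetaHeader) == 16, "16 bytes, memory-aligned");
```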

3. Overlap Implementation of Feature Eviction and Training

Core Interfaces
The FeatureEvict class provides these key interfaces for DramKVEmbeddingCache to overlap eviction with training:

  • trigger_evict(): Initiate asynchronous eviction task.
  • pause(): Suspend ongoing eviction.
  • resume(): Resume paused eviction.
  • is_evicting(): Query eviction task status.

Usage Pattern
Surround hashtable updates/queries with pause() and resume() to temporarily halt eviction during critical operations, then let eviction resume during data loading and the network forward/backward passes (see the sketch below the figure).

[figure: training/eviction overlap pipeline]
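
A minimal sketch of this pattern, assuming a DramKVEmbeddingCache::set and the FeatureEvict interface sketched earlier (signatures are illustrative, not the PR's exact API):

```cpp
#include <ATen/ATen.h>

void guarded_set(DramKVEmbeddingCache& cache,
                 FeatureEvict& evictor,
                 const at::Tensor& indices,
                 const at::Tensor& weights,
                 const at::Tensor& count) {
  evictor.pause();                    // halt eviction before touching the hashtable
  cache.set(indices, weights, count); // critical section: update embeddings
  evictor.resume();                   // eviction overlaps data loading and fwd/bwd passes
}
```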

Asynchronous Optimization Techniques

  • Sharded Parallel Processing: Divide eviction tasks by shard, leveraging the hashtable's thread pool for parallel execution to minimize resource contention.
  • Fine-Grained Interruption: Check interrupt flags after processing each memory block, track progress with block_cursors, and use atomic variables for rapid pause/resume control.
  • Lock Optimization: Employ read-write locks (wlock) to ensure thread safety between eviction and training operations.
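
Combining the three techniques, the per-shard worker might look like this (continuing the FeatureEvict sketch; shards_, num_blocks_, and evict_block_if_needed are assumed helpers, not the PR's exact names):

```cpp
void FeatureEvict::process_shard(std::size_t shard_id) {
  auto& cursor = block_cursors_[shard_id]; // resume exactly where pause() stopped
  while (cursor < num_blocks_[shard_id]) {
    // Fine-grained interruption: check the flag after every memory block.
    if (paused_.load(std::memory_order_acquire)) {
      return;
    }
    auto wlock = shards_[shard_id].wlock();  // write lock vs. training threads
    evict_block_if_needed(shard_id, cursor); // apply the strategy's predicate
    ++cursor;
  }
}
```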

netlify bot commented May 26, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

🔨 Latest commit: d394af9
🔍 Latest deploy log: https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/68413b41ec72a10008131597
😎 Deploy Preview: https://deploy-preview-4187--pytorch-fbgemm-docs.netlify.app
@facebook-github-bot (Contributor)

@TroyGarden has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@emlin (Contributor) commented May 30, 2025

For different types of eviction policy, do we share the same metadata header structure?
For example, if we choose L2-norm-based eviction, do we still need to reserve 4 bytes for the timestamp?

@emlin (Contributor) commented May 30, 2025

One more question about the eviction policy: can I configure a max feature count or max memory threshold while enabling L2-norm-based eviction?

// Pause the eviction process; returns true if there is a task in progress, false otherwise.
// During the pause phase, check whether eviction has completed.
bool pause() {
std::lock_guard<std::mutex> lock(mutex_);
Contributor:

is this lock required?

Contributor Author:

We want feature eviction to be thread-safe regardless of how it is called. As it stands, not all interfaces can be guaranteed safe otherwise: for example, trigger may be called by one thread while pause is called by another, so locking is required.

Comment on lines 166 to 167
decay_rate_(decay_rate),
threshold_(threshold) {}
Contributor:

so the decay rate and threshold still need to be manually set?

Contributor Author:

Yeah, there's no good way to set this automatically.

CounterBasedEvict(folly::CPUThreadPoolExecutor* executor,
SynchronizedShardedMap<int64_t, float*>& kv_store,
float decay_rate,
int threshold)
Contributor:

can we have different threshold for different feature offset?

Contributor Author:

OK

@@ -67,7 +87,11 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
}

 void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-  return impl_->set(indices, weights, count);
+  impl_->feature_evict_pause();
+  impl_->set(indices, weights, count);
Contributor:

To improve performance, we'll enable async set for the dram backend, similar to ssd, here:

so we cannot directly start eviction right after set. If we want to do this, we should call wait_util_filling_work_done inside of eviction.

Contributor Author:

OK

@emlin (Contributor) commented Jun 1, 2025

no matter which eviction trigger is configured, are we able to start eviction when memory reaches threshold?

int evict_trigger_mode,
int evict_trigger_strategy,
int64_t trigger_step_interval,
uint32_t ttl,
Contributor:

there is already a weight_ttl_in_hours, shall we drop that?

Contributor Author:

  1. Yes, we can continue to improve the trigger timing and confirm the all-reduce method for the size.
  2. OK

@@ -26,15 +26,34 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
int64_t max_D,
double uniform_init_lower,
double uniform_init_upper,
int evict_trigger_mode,
Contributor:

The torch pybind interface does not allow int or int32 types; let's change it to int64_t?

Contributor Author:

OK

@emlin (Contributor) commented Jun 2, 2025

With this PR, the dram backend build failed. Wondering if you can run the build and unit tests based on the latest branch?

@ArronHZG (Contributor Author) commented Jun 3, 2025

Some third-party packages cannot work properly in our environment, so we did not compile dram_kv_embedding_cache.h and dram_kv_embedding_cache_wrapper.h, but rewrote the code directly. The compilation errors are shown below; let's take a look. @emlin

@ArronHZG (Contributor Author) commented Jun 3, 2025

no matter which eviction trigger is configured, are we able to start eviction when memory reaches threshold?

That sounds more reasonable. We can change it.

@ArronHZG ArronHZG force-pushed the feature/feature_evict branch 3 times, most recently from 7e1f44e to 264e35d Compare June 4, 2025 02:43
@ArronHZG ArronHZG changed the title feature evict Add feature evict for dram_kv_embedding_cache. Jun 4, 2025
@ArronHZG ArronHZG force-pushed the feature/feature_evict branch from 264e35d to 64d2558 Compare June 4, 2025 02:47
@kausv commented Jun 4, 2025

CI signals are also failing cc @ionuthristodorescu

10:16:30 AM: /opt/buildhome/miniconda/envs/build_docs/lib/python3.13/site-packages/fbgemm_gpu/split_table_batched_embeddings_ops_training.py:docstring of fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen:219: ERROR: Unexpected indentation.
10:16:30 AM: /opt/buildhome/miniconda/envs/build_docs/lib/python3.13/site-packages/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py:docstring of fbgemm_gpu.split_table_batched_embeddings_ops_inference.IntNBitTableBatchedEmbeddingBagsCodegen.split_embedding_weights_with_scale_bias:3: ERROR: Unexpected indentation.

@emlin (Contributor) commented Jun 4, 2025

> CI signals are also failing cc @ionuthristodorescu (CI error log quoted above)

let me try locally and see if that's the operator integration issue

@facebook-github-bot (Contributor)

@emlin has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@emlin (Contributor) commented Jun 4, 2025

> CI signals are also failing cc @ionuthristodorescu (CI error log quoted above)
> let me try locally and see if that's the operator integration issue

Looks like the rebase conflicts have to be addressed before I can import the change. @ArronHZG

const std::optional<at::Tensor>& ttls_in_hour = std::nullopt,
const std::optional<at::Tensor>& count_decay_rates = std::nullopt,
const std::optional<at::Tensor>& l2_weight_thresholds = std::nullopt,
const std::optional<at::Tensor>& embedding_dims = std::nullopt,
Contributor:

what's the difference between this and table_dims @115?

Comment on lines +340 to +344
std::copy(
    weights[id_index].template data_ptr<weight_type>(),
    weights[id_index].template data_ptr<weight_type>() +
        weights[id_index].numel(),
Contributor:

From here, it seems the weight id is not copied to the block. Where will you copy the weight id to the block?

static uint64_t get_key(const void* block) {
return reinterpret_cast<const MetaHeader*>(block)->key;
}
static void set_key(void* block, uint64_t key) {
Contributor:

seems this method is not called?

int64_t key;         // feature key, 8 bytes
uint32_t timestamp;  // 4 bytes; the unit is seconds, so uint32 covers a range of over 120 years
uint32_t count : 31; // only 31 bits are used; max value is 2147483647
Contributor:

Wondering whether we can make the key and count optional. On the inference side, memory is more restricted, and key/count won't be used anyway.

Contributor:

this can be done in the next step

: 0.0f;
}
DLOG(INFO) << fmt::format(
"Shard {} completed: \n"
Contributor:

This is just one execution finishing, not the whole shard, right?
Since we'll call this process frequently (every batch), logging every time might be too much; something like DLOG_EVERY_N(INFO, 1000) can sample the logs.

Comment on lines +187 to +188
futures_.emplace_back(folly::via(executor_).thenValue(
[this, shard_id](auto&&) { process_shard(shard_id); }));
Contributor:

Wondering if this submits too many tasks if we evict in every batch?

@ArronHZG (Contributor Author) commented Jun 5, 2025

for different type of eviction policy, do we share the same metadata header structure? for example, if we choose to use feature l2 norm based eviction, do we still need to reserve 4 bytes for timestamp?

With the L2 eviction method, count and timestamp would ideally take no space, but the struct still contains the key and used fields, which makes it at least 16 bytes under 8-byte alignment, so the space cannot be compressed any further.

@emlin (Contributor) commented Jun 5, 2025

One more question: how do I get all cached values for one id, including the key, meta header, and weight + optimizer state? That method is needed for checkpointing including metadata.

@emlin (Contributor) commented Jun 5, 2025

Also one question about alignment: if the backend is used on the inference side, the weight size may not be 8-byte aligned. All of this is about memory saving and is not immediately urgent, but after we merge this PR we can think about how to compress a bit.
