Skip to content

Commit 869b526

Browse files
allightcopybara-github
authored andcommitted
Add a 'Forced' state to lazy_dag_cache and children
This allows one to set a particular key to a known fixed value irrespective of the computations of other nodes. This Forced key acts as a break on updates propagating. PiperOrigin-RevId: 738582081
1 parent 4f73edf commit 869b526

5 files changed

+221
-10
lines changed

xls/passes/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -1370,6 +1370,7 @@ cc_test(
13701370
"@com_google_absl//absl/container:inlined_vector",
13711371
"@com_google_absl//absl/log",
13721372
"@com_google_absl//absl/status",
1373+
"@com_google_absl//absl/status:status_matchers",
13731374
"@com_google_absl//absl/status:statusor",
13741375
"@com_google_absl//absl/strings",
13751376
"@com_google_absl//absl/strings:str_format",

xls/passes/lazy_dag_cache.h

+95-7
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,20 @@ namespace xls {
4343
// This class implements a cache with a simple state machine; each key has
4444
// either no recorded value (kUnknown), a value that may be out-of-date
4545
// (kUnverified), a value that is valid if all inputs prove to be up-to-date
46-
// (kInputsUnverified), or a value that is known to be current & correct
47-
// (kKnown).
46+
// (kInputsUnverified), a value that is known to be current & correct
47+
// (kKnown), or information with a forced value (kForced).
4848
//
4949
// When a `key` is queried, we query for the values for all of its inputs, and
5050
// then re-compute the value for `key` if absent or unverified. If `key` was in
5151
// state kUnverified & this does change its associated value, we mark any direct
5252
// users that were in state kInputsUnverified as kUnverified, since their inputs
5353
// have changed.
5454
//
55+
// Any node that is in the kForced state will remain in the same state
56+
// regardless of other changes to the graph. Changing the value a node is Forced
57+
// to will invalidate all its users and force a recalculation of downstream
58+
// values.
59+
//
5560
// NOTE: If `key` is in state kInputsUnverified after we queried all of its
5661
// inputs, then their values did not change, so we have verified that
5762
// `key`'s value is up-to-date without having to recompute it! This is the
@@ -93,8 +98,45 @@ class LazyDagCache {
9398
LazyDagCache<Key, Value>& operator=(LazyDagCache<Key, Value>&& other) =
9499
delete;
95100

101+
// Erase all knowledge of the values of all keys.
96102
void Clear() { cache_.clear(); }
97-
void Forget(const Key& key) { cache_.erase(key); }
103+
// Erase all knowledge of the value of all keys except for 'Forced' values.
104+
void ClearNonForced() {
105+
absl::erase_if(cache_, [](const auto& v) {
106+
return v.second.state != CacheState::kForced;
107+
});
108+
}
109+
110+
// Entirely remove knowledge of this key. This includes erasing any Forced
111+
// data.
112+
void Forget(const Key& key) {
113+
cache_.erase(key);
114+
for (const Key& user : provider_->GetUsers(key)) {
115+
MarkInputsUnverified(user);
116+
}
117+
}
118+
119+
// Set the key as having the immutable, authoritative 'value'.
120+
//
121+
// *This is a dangerous operation and should be used with care.* It tells the
122+
// cache to never call the ComputeValue callback and to instead consider
123+
// 'value' to be associated with 'key' now and forever. This knowledge may
124+
// only be removed by calling 'Forget' or 'Clear'.
125+
void SetForced(const Key& key, std::unique_ptr<Value> value) {
126+
cache_.insert_or_assign(key, CacheEntry{.state = CacheState::kForced,
127+
.value = std::move(value)});
128+
MarkUsersUnverified(key);
129+
}
130+
131+
// Set the key as having the immutable, authoritative 'value'.
132+
//
133+
// *This is a dangerous operation and should be used with care.* It tells the
134+
// cache to never call the ComputeValue callback and to instead consider
135+
// 'value' to be associated with 'key' now and forever. This knowledge may
136+
// only be removed by calling 'Forget' or 'Clear'.
137+
void SetForced(const Key& key, Value value) {
138+
SetForced(key, std::make_unique<Value>(std::move(value)));
139+
}
98140

99141
void AddUnverified(const Key& key, Value value) {
100142
cache_.insert_or_assign(
@@ -106,6 +148,15 @@ class LazyDagCache {
106148
.value = std::move(value)});
107149
}
108150

151+
// Request recomputation of any users of this key.
152+
//
153+
// Mark as full unverified to force recomputation.
154+
void MarkUsersUnverified(const Key& key) {
155+
for (const Key& user : provider_->GetUsers(key)) {
156+
MarkUnverified(user);
157+
}
158+
}
159+
109160
void MarkUnverified(const Key& key) {
110161
auto it = cache_.find(key);
111162
if (it == cache_.end()) {
@@ -115,6 +166,11 @@ class LazyDagCache {
115166
if (state == CacheState::kUnverified) {
116167
return;
117168
}
169+
if (state == CacheState::kForced) {
170+
VLOG(1) << "Mark unverified called on forced entry "
171+
<< provider_->GetName(key);
172+
return;
173+
}
118174
state = CacheState::kUnverified;
119175
for (const Key& user : provider_->GetUsers(key)) {
120176
if (GetCacheState(user) == CacheState::kKnown) {
@@ -155,7 +211,8 @@ class LazyDagCache {
155211
Value* QueryValue(const Key& key) {
156212
// If `key` is already known, return a pointer to the cached value.
157213
if (auto it = cache_.find(key);
158-
it != cache_.end() && it->second.state == CacheState::kKnown) {
214+
it != cache_.end() && (it->second.state == CacheState::kKnown ||
215+
it->second.state == CacheState::kForced)) {
159216
return it->second.value.get();
160217
}
161218

@@ -189,6 +246,7 @@ class LazyDagCache {
189246
return cached_value.get();
190247
}
191248

249+
CHECK_NE(state, CacheState::kForced);
192250
state = CacheState::kKnown;
193251
cached_value = std::make_unique<Value>(*std::move(new_value));
194252

@@ -208,7 +266,7 @@ class LazyDagCache {
208266
// all nodes in the DAG.
209267
absl::Status EagerlyPopulate(absl::Span<const Key> topo_sorted_keys) {
210268
for (const Key& key : topo_sorted_keys) {
211-
if (GetCacheState(key) == CacheState::kKnown) {
269+
if (GetNonForcedCacheState(key) == CacheState::kKnown) {
212270
continue;
213271
}
214272
std::vector<const Value*> input_values;
@@ -232,6 +290,8 @@ class LazyDagCache {
232290
// consistent regardless. This is an expensive operation, intended for use in
233291
// tests. `topo_sorted_keys` must be a topological sort of the keys for all
234292
// nodes in the DAG.
293+
//
294+
// Note that Forced values are always considered consistent.
235295
absl::Status CheckConsistency(absl::Span<const Key> topo_sorted_keys) const;
236296

237297
enum class CacheState : uint8_t {
@@ -241,6 +301,14 @@ class LazyDagCache {
241301
kUnverified,
242302
kInputsUnverified,
243303
kKnown,
304+
// A value has been provided via external information and should be
305+
// considered authoritatively known.
306+
//
307+
// This node can never be put to unverified state.
308+
//
309+
// It is possible that the value this key is forced to is one that cannot be
310+
// arrived at through the normal update sequence.
311+
kForced,
244312
};
245313

246314
template <typename Sink>
@@ -258,6 +326,9 @@ class LazyDagCache {
258326
case CacheState::kKnown:
259327
absl::Format(&sink, "KNOWN");
260328
return;
329+
case CacheState::kForced:
330+
absl::Format(&sink, "FORCED");
331+
return;
261332
}
262333
LOG(FATAL) << "Unknown CacheState: " << static_cast<int>(state);
263334
}
@@ -272,6 +343,14 @@ class LazyDagCache {
272343
}
273344
return it->second.state;
274345
}
346+
// Get the cache state with forced values being considered kKnown
347+
CacheState GetNonForcedCacheState(const Key& key) const {
348+
CacheState s = GetCacheState(key);
349+
if (s == CacheState::kForced) {
350+
return CacheState::kKnown;
351+
}
352+
return s;
353+
}
275354
const Value* GetCachedValue(const Key& key) const {
276355
auto it = cache_.find(key);
277356
if (it == cache_.end() || it->second.value == nullptr) {
@@ -320,9 +399,10 @@ absl::Status LazyDagCache<Key, Value>::CheckConsistency(
320399

321400
if (state == CacheState::kKnown) {
322401
for (const Key& input : provider_->GetInputs(key)) {
323-
XLS_RET_CHECK_EQ(GetCacheState(input), CacheState::kKnown)
402+
XLS_RET_CHECK_EQ(GetNonForcedCacheState(input), CacheState::kKnown)
324403
<< "Non-KNOWN input for KNOWN key " << provider_->GetName(key)
325-
<< ": " << provider_->GetName(input);
404+
<< ": " << provider_->GetName(input) << " (input is "
405+
<< GetCacheState(input) << ")";
326406
}
327407
}
328408
if (state == CacheState::kInputsUnverified) {
@@ -333,6 +413,8 @@ absl::Status LazyDagCache<Key, Value>::CheckConsistency(
333413
}
334414
}
335415

416+
// NB state FORCED & UNVERIFIED has no requirements on its inputs.
417+
336418
if (absl::c_any_of(provider_->GetInputs(key), [&](const Key& input) {
337419
return !correct_values.contains(input);
338420
})) {
@@ -341,6 +423,12 @@ absl::Status LazyDagCache<Key, Value>::CheckConsistency(
341423
continue;
342424
}
343425

426+
if (state == CacheState::kForced) {
427+
// Forced values are definitionally correct.
428+
correct_values.insert(key);
429+
continue;
430+
}
431+
344432
std::vector<const Value*> input_values;
345433
absl::Span<const Key> inputs = provider_->GetInputs(key);
346434
input_values.reserve(inputs.size());

xls/passes/lazy_node_info.h

+53-3
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ namespace xls {
5555
// This class implements a cache with a simple state machine; each node has
5656
// either no recorded information (kUnknown), information that may be
5757
// out-of-date (kUnverified), information that is valid if all inputs prove to
58-
// be up-to-date (kInputsUnverified), or information that is known to be current
59-
// & correct (kKnown).
58+
// be up-to-date (kInputsUnverified), information that is known to be current
59+
// & correct (kKnown), or information with a forced value (kForced).
6060
//
6161
// On populating with a FunctionBase, this class listens for change events. When
6262
// a node changes, we mark its information as potentially out-of-date
@@ -71,6 +71,12 @@ namespace xls {
7171
// information, we mark any direct users that were in state kInputsUnverified as
7272
// kUnverified, since their inputs have changed.
7373
//
74+
// Any node that is in the kForced state will remain in the same state
75+
// regardless of other changes to the graph. Changing the value a node is Forced
76+
// to will invalidate all its users.
77+
//
78+
// A node may not have both 'given' data and 'forced' data at the same time.
79+
//
7480
// NOTE: If `n` is in state kInputsUnverified after we queried all of its
7581
// operands, then their values did not change, so we have verified that
7682
// `n`'s information is up-to-date without having to recompute it! This is
@@ -177,13 +183,51 @@ class LazyNodeInfo
177183
return rf;
178184
}
179185

186+
// Set the node to a single immutable forced value.
187+
//
188+
// This is different from Givens since it is not combined with the calculated
189+
// values from earlier in the tree but instead considered a-priori known.
190+
//
191+
// Note that any forced value may not have a given associated with it as the
192+
// given value will be ignored.
193+
//
194+
// Care should be taken when using this since existing information is utterly
195+
// ignored. In general AddGiven is a better choice.
196+
absl::StatusOr<ReachedFixpoint> SetForced(Node* node,
197+
LeafTypeTree<Info> forced_ltt) {
198+
XLS_RET_CHECK(givens_.find(node) == givens_.end())
199+
<< node << " already has given information.";
200+
if (cache_.GetCacheState(node) == CacheState::kForced &&
201+
*cache_.GetCachedValue(node) == forced_ltt) {
202+
return ReachedFixpoint::Unchanged;
203+
}
204+
cache_.SetForced(node, std::move(forced_ltt));
205+
return ReachedFixpoint::Changed;
206+
}
207+
208+
absl::StatusOr<ReachedFixpoint> RemoveForced(Node* node) {
209+
XLS_RET_CHECK_EQ(cache_.GetCacheState(node), CacheState::kForced)
210+
<< node << " has no forced value associated with it.";
211+
if (cache_.GetCacheState(node) != CacheState::kForced) {
212+
return ReachedFixpoint::Unchanged;
213+
}
214+
cache_.Forget(node);
215+
return ReachedFixpoint::Changed;
216+
}
217+
180218
// Bind the node info to the given function.
181219
absl::StatusOr<ReachedFixpoint> Attach(FunctionBase* f) {
182220
return AttachWithGivens(f, {});
183221
}
184222

223+
// Set the value 'given' as being assumed for the given node. This data is
224+
// combined with the already calculated data. If one needs to set the state
225+
// directly to the given value use SetForced instead. A node may not have both
226+
// 'Forced' and 'Given' data associated with it at the same time.
185227
absl::StatusOr<ReachedFixpoint> AddGiven(Node* node,
186228
LeafTypeTree<Info> given_ltt) {
229+
XLS_RET_CHECK_NE(cache_.GetCacheState(node), CacheState::kForced)
230+
<< node << " has a forced value associated with it.";
187231
auto it = givens_.find(node);
188232
if (it == givens_.end()) {
189233
givens_.emplace(node, std::move(given_ltt));
@@ -207,9 +251,13 @@ class LazyNodeInfo
207251
cache_.MarkUnverified(node);
208252
return ReachedFixpoint::Changed;
209253
}
210-
ReachedFixpoint ReplaceGiven(Node* node, LeafTypeTree<Info> given) {
254+
255+
absl::StatusOr<ReachedFixpoint> ReplaceGiven(Node* node,
256+
LeafTypeTree<Info> given) {
211257
auto it = givens_.find(node);
212258
if (it == givens_.end()) {
259+
XLS_RET_CHECK_NE(cache_.GetCacheState(node), CacheState::kForced)
260+
<< node << " has a forced value associated with it.";
213261
givens_.emplace(node, std::move(given));
214262
cache_.MarkUnverified(node);
215263
return ReachedFixpoint::Changed;
@@ -287,6 +335,8 @@ class LazyNodeInfo
287335
// lazy query engines, checks that the current state of the cache is correct
288336
// where expected & consistent regardless. This is an expensive operation,
289337
// intended for use in tests.
338+
//
339+
// Note that Forced values are always considered consistent.
290340
absl::Status CheckCacheConsistency() const {
291341
XLS_RET_CHECK(f_ != nullptr) << "Unattached info";
292342
return cache_.CheckConsistency(TopoSort(f_));

xls/passes/lazy_query_engine.h

+21
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ class LazyQueryEngine : public QueryEngine {
136136
return node->function_base() == info_.bound_function();
137137
}
138138

139+
// Check that the query engine is consistent. Note that Forced values are
140+
// always considered consistent.
139141
absl::Status CheckConsistency() const override {
140142
return info_.CheckCacheConsistency();
141143
}
@@ -172,6 +174,25 @@ class LazyQueryEngine : public QueryEngine {
172174
}
173175
ReachedFixpoint RemoveGiven(Node* node) { return info_.RemoveGiven(node); }
174176

177+
// Set the node to a single immutable forced value.
178+
//
179+
// This is different from Givens since it is not combined with the calculated
180+
// values from earlier in the tree but instead considered a-priori known.
181+
//
182+
// Note that any forced value may not have a given associated with it as the
183+
// given value will be ignored.
184+
//
185+
// Care should be taken when using this since existing information is utterly
186+
// ignored. In general AddGiven is a better choice.
187+
absl::StatusOr<ReachedFixpoint> SetForced(Node* node,
188+
LeafTypeTree<Info> forced_ltt) {
189+
return info_.SetForced(node, std::move(forced_ltt));
190+
}
191+
// Removed forced information.
192+
absl::StatusOr<ReachedFixpoint> RemoveForced(Node* node) {
193+
return info_.RemoveForced(node);
194+
}
195+
175196
protected:
176197
virtual LeafTypeTree<Info> ComputeInfo(
177198
Node* node,

0 commit comments

Comments
 (0)