-
Notifications
You must be signed in to change notification settings - Fork 239
Expand file tree
/
Copy pathweight_cache.h
More file actions
498 lines (409 loc) · 17.2 KB
/
weight_cache.h
File metadata and controls
498 lines (409 loc) · 17.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_WEIGHT_CACHE_H_
#define TENSORFLOW_LITE_DELEGATES_XNNPACK_WEIGHT_CACHE_H_
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "xnnpack.h" // from @XNNPACK
#include "tflite/c/common.h"
#include "tflite/delegates/xnnpack/file_util.h"
#include "tflite/delegates/xnnpack/mmap_handle.h"
#include "tflite/delegates/xnnpack/weight_cache_schema_generated.h"
// WARNING: the interface in this file is still under experimentation and WILL
// CHANGE. Do not rely on it.
// TFLite doesn't use absl hashing utilities, so the hash functors in this file
// combine hashes by hand.
namespace tflite {
namespace xnnpack {
// Reserved path value that asks the delegate to keep the weight cache in
// memory rather than saving it to disk.
//
// Useful when disk space is unavailable, or when managing the freshness of a
// cache file is too complicated. The in-memory cache still deduplicates
// constant buffers that are reused across graph signatures.
inline constexpr char kInMemoryCachePath[]{":memory"};
// Header stored at the very beginning of every cache file.
//
// Any change to this struct, or to anything else in the cache file layout,
// must bump `kVersion` by one.
//
// Writers set `version` to `kVersion` when creating a new cache file. Readers
// must reject a cache file whose `version` differs from `kVersion`.
//
// The field order below is part of the on-disk format; do not reorder.
struct XNNPackCacheHeader {
  enum : uint64_t {
    kVersion = 4,
    kInvalidHeader = 0,
  };
  uint64_t version;
  uint64_t buffer_list_offset;
  uint64_t buffer_list_size;
  uint8_t stale;
};
// Checks whether the file at `path` is compatible with the current XNNPack
// weight cache.
//
// Returns true when the file can be used as a cache by this implementation.
// NOTE(review): presumably this validates the file's `XNNPackCacheHeader`
// (e.g. `version == kVersion`); confirm in the implementation.
bool IsCompatibleCacheFile(const char* path);

// Checks whether the already opened file `fd` is compatible with the current
// XNNPack weight cache.
//
// The position in the file may be changed while the function runs, but it is
// restored before the function returns.
//
// Note: the file descriptor must be open and valid.
bool IsCompatibleCacheFile(FileDescriptorView fd);
// Identifies a packed buffer by the packing algorithm and by the ids of the
// weights and bias it was packed from. Ids that are not set hold `kNoId`.
struct PackIdentifier {
  enum { kNoId = SIZE_MAX };
  uint64_t pack_algorithm_id = kNoId;
  uint64_t weights_id = kNoId;
  uint64_t bias_id = kNoId;

  friend bool operator==(const PackIdentifier& a, const PackIdentifier& b) {
    return a.pack_algorithm_id == b.pack_algorithm_id &&
           a.weights_id == b.weights_id && a.bias_id == b.bias_id;
  }

  // Hash functor so that `PackIdentifier` can key unordered containers.
  struct Hash {
    size_t operator()(const PackIdentifier& p) const {
      // Combine the field hashes in an order-sensitive way.
      //
      // A plain XOR of the three hashes is symmetric (swapping `weights_id`
      // and `bias_id` collides) and self-cancelling (`weights_id == bias_id`
      // leaves only `pack_algorithm_id`). When `std::hash` is the identity
      // for integers (as in libstdc++), that yields many collisions between
      // distinct identifiers.
      //
      // The multiplier is odd, hence invertible modulo 2^N, so each step keeps
      // the information from the previous fields.
      constexpr size_t kMultiplier =
          static_cast<size_t>(0x9e3779b97f4a7c15ull);
      std::hash<uint64_t> hasher;
      size_t seed = hasher(p.pack_algorithm_id);
      seed = seed * kMultiplier + hasher(p.weights_id);
      seed = seed * kMultiplier + hasher(p.bias_id);
      return seed;
    }
  };
};
// Position and length of a buffer within the cache.
struct BufferLocation {
  uint64_t offset;
  uint64_t size;

  // Sentinel location: both fields set to SIZE_MAX.
  static constexpr BufferLocation Invalid() {
    return BufferLocation{/*offset=*/SIZE_MAX, /*size=*/SIZE_MAX};
  }

  // Returns true only when both fields match the `Invalid()` sentinel.
  constexpr bool IsInvalid() const {
    return offset == Invalid().offset && size == Invalid().size;
  }
};
// Owns the storage that packed buffers are written into and saves those
// buffers to disk.
//
// WARNING: the interface in this file is still under experimentation and WILL
// CHANGE. Do not rely on it.
class WeightCacheBuilder {
 public:
  WeightCacheBuilder() = default;
  ~WeightCacheBuilder() = default;

  // Copy operations are disabled.
  WeightCacheBuilder(const WeightCacheBuilder&) = delete;
  WeightCacheBuilder& operator=(const WeightCacheBuilder&) = delete;

  // Move operations are defined out of line.
  WeightCacheBuilder(WeightCacheBuilder&&);
  WeightCacheBuilder& operator=(WeightCacheBuilder&&);

  // Starts the builder for the cache file at `path` using `fd`. Returns false
  // on failure.
  [[nodiscard /*Starting the builder may fail.*/]]
  bool Start(const char* path, const FileDescriptor& fd);

  // Returns true when the builder holds a valid file descriptor.
  [[nodiscard]]
  bool IsStarted() const {
    const bool has_valid_fd = fd_.IsValid();
    return has_valid_fd;
  }

  // Returns true while a build step is in progress.
  [[nodiscard]]
  bool IsBuilding() const {
    return is_build_step_.load();
  }

  // Reopens the cache file so that more data can be added to it.
  //
  // Intended to be called only from the weight cache provider.
  [[nodiscard /*Starting a build step may fail.*/]]
  bool StartBuildStep();

  // Resets the builder; any data that hasn't been written is discarded.
  void Reset();

  // Reserves `size` bytes in the builder's data buffer and returns the address
  // of that space.
  //
  // Every call to `Reserve` should be followed by a call to `Append`.
  [[nodiscard /*The pointer to reserved space should be used.*/]]
  void* Reserve(size_t size);

  // Adds a buffer to the cache and returns its location.
  //
  // The space is expected to have been obtained from `Reserve` beforehand. If
  // it was not, `Reserve` is called again and the data is copied over.
  [[nodiscard /*The location to the appended data should be saved.*/]]
  BufferLocation Append(PackIdentifier pack_id, const void* data, uint64_t size,
                        int fingerprint_id);

  // Writes the flatbuffer to disk.
  [[nodiscard /*Writing the weight cache can fail.*/]]
  bool StopBuildStep();

  // Offset in the cache file where the data written by the last build step
  // begins.
  //
  // The data covers the appended buffers as well as the whole buffer mapping.
  [[nodiscard]]
  size_t LastBuildStepStart() const {
    return build_segment_start_;
  }

  // Number of bytes written during the last build step.
  //
  // The data covers the appended buffers as well as the whole buffer mapping.
  [[nodiscard]]
  size_t LastBuildStepSize() const {
    return build_segment_size_;
  }

  // Returns the builder's file descriptor view.
  FileDescriptorView GetFileDescriptor() const { return fd_; }

  // Capacity of the underlying reserved buffer.
  //
  // WARNING: exposes an implementation detail for testing purposes and may be
  // removed at any time.
  size_t capacity() const { return capacity_; }

  // Address of the underlying reserved buffer.
  //
  // Regular callers should obtain addresses from `Reserve` instead.
  //
  // WARNING: exposes an implementation detail for testing purposes and may be
  // removed at any time.
  uint8_t* data() const { return data_.get(); }

 private:
  // Underlying reserved buffer (null until something is reserved).
  std::unique_ptr<uint8_t[]> data_;
  cache::schema::BufferListT schema_;
  size_t capacity_{0};
  // Number of bytes written between StartBuildStep and StopBuildStep.
  size_t build_segment_size_{0};
  // Cache file offset recorded when StartBuildStep was called.
  size_t build_segment_start_{0};
  // StopBuildStep may return early when nothing was written to the cache. The
  // file header must still be correct for reloading to work, so this flag
  // records whether that first write has happened.
  bool first_write_done_{false};
  // File descriptor view.
  FileDescriptorView fd_;
  std::string file_path_;
  std::atomic<bool> is_build_step_{false};
};
// Handles cache misses when a cache is loaded (as opposed to being built).
//
// It stands in for the cache file by storing packed data in buffers that are
// allocated on the fly and are not persisted to the cache file.
class CacheMissHandler {
 public:
  // Reserves space for `size` bytes and returns the address of that space.
  //
  // A call to `Reserve` should always be followed by a call to `Append`.
  [[nodiscard /*The pointer to reserved space should be used.*/]]
  void* Reserve(size_t size);

  // Adds a buffer to the cache and returns its location.
  //
  // The buffer space should have been reserved beforehand with `Reserve`. If
  // not, a new call to `Reserve` will be done and the data will be copied over.
  [[nodiscard /*The location to the appended data should be saved.*/]]
  BufferLocation Append(PackIdentifier pack_id, const void* data, uint64_t size,
                        int fingerprint_id);

  // Sets the reference offset from which all cache miss offsets are derived.
  // This avoids reusing an offset that is already in use.
  void SetMinOffset(size_t min_offset) { min_offset_ = min_offset; }

  // Checks that the numbers of `Reserve` and `Append` calls are equal.
  bool ConsistentReserveAppendCalls() const noexcept {
    return reserve_count_ == append_count_;
  }

  // Checks whether cache misses have happened (non-zero append count).
  bool HasCacheMisses() const noexcept { return append_count_ != 0; }

  // Returns the number of buffers currently held by the handler.
  size_t BufferCount() const noexcept { return buffers_.size(); }

 private:
  struct Buffer {
    // Unaligned allocation holding at least `loc.size` bytes.
    std::unique_ptr<uint8_t[]> data;
    // Location reported for this buffer. Value-initialized so that a
    // default-constructed `Buffer` never carries an indeterminate location.
    BufferLocation loc{};
    // Pointer within `data`, aligned to the cache's minimum alignment, with at
    // least `loc.size` bytes available.
    uint8_t* ptr = nullptr;
    // True if an Append operation matching this buffer was called.
    bool used = false;
  };

  std::vector<Buffer> buffers_;
  // Holds a location that is bigger than all of the locations held in the
  // cache. This is to avoid clashing with the stored offsets.
  size_t min_offset_ = 0;
  size_t append_count_ = 0;
  size_t reserve_count_ = 0;
};
// Allows XNNPack to directly load packed weights from disk instead of having to
// repack them every time.
//
// XNNPack kernels do not have knowledge of the TFLite context. The only thing
// they can access is the buffers' addresses. We rely on the fact that the
// address provided by TFLite is unique in order to find out the buffer
// identifier.
//
// To use the cache you need to:
//
// - Map the buffer addresses to their identifier with `MapTensorIdentifiers`.
// - Load the cache file.
// - Finalize the cache before calling the run functions of XNNPack (setup and
//   reshape are ok).
class MMapWeightCacheProvider {
 public:
  MMapWeightCacheProvider() = default;
  MMapWeightCacheProvider(const MMapWeightCacheProvider&) = delete;
  MMapWeightCacheProvider& operator=(const MMapWeightCacheProvider&) = delete;
  MMapWeightCacheProvider(MMapWeightCacheProvider&&);
  MMapWeightCacheProvider& operator=(MMapWeightCacheProvider&&);
  ~MMapWeightCacheProvider();

  // Changes the file path to save the cache to.
  //
  // WARNING: Can only be called if the cache isn't finalized.
  void SetFilePath(const char* file_path);

  // Returns the current cache file path.
  const std::string& GetFilePath() const { return file_path_; }

  // Tries to load the given file. If the file doesn't exist, starts building
  // the cache for it.
  //
  // If `fd` is provided, use that instead of reopening the file at the given
  // path.
  [[nodiscard /*Loading a cache file may fail.*/]]
  bool LoadOrStartBuild(const char* file_path,
                        FileDescriptor fd = FileDescriptor());

  // Starts building a new cache file at `file_path` (or using `fd` when one is
  // provided).
  [[nodiscard /*Starting to build a cache file may fail.*/]]
  bool StartBuild(const char* file_path, FileDescriptor fd = FileDescriptor());

  // If the cache is still being built, this signals that all of the building
  // operations are done and that `CanStartBuildStep()` should now return
  // `false`.
  void StopBuild() { builder_.Reset(); }

  // Sets the weight file path and loads it.
  [[nodiscard /*Loading a cache file may fail.*/]]
  bool Load(const std::string& path, FileDescriptor fd = FileDescriptor());

  // Loads the weight cache previously set with `SetFilePath`.
  [[nodiscard /*Loading cache data may fail.*/]]
  bool Load();

  // Attempts to lock the cache in memory. Only applicable when the OS supports
  // memory locking and the cache is mapped.
  [[nodiscard /*Locking cache data may fail.*/]]
  bool LockMemory();

  // Attempts to unlock the cache in memory. Only applicable when the OS
  // supports memory locking and the cache is mapped and locked.
  [[nodiscard /*Unlocking cache data may fail.*/]]
  bool UnlockMemory();

  // Returns true while the underlying cache builder is started. `StopBuild()`
  // resets the builder, after which this returns false.
  [[nodiscard]]
  bool CanStartBuildStep() const {
    return builder_.IsStarted();
  }

  // Prepares to add new data to the cache.
  [[nodiscard /*Updating cache data may fail.*/]]
  bool StartBuildStep();

  // Prepares to use data that was added to the cache during a build step.
  [[nodiscard /*Updating cache data may fail.*/]]
  bool StopBuildStep();

  // Creates the tensor map.
  void MapTensorIdentifiers(
      const TfLiteTensor* tensors, size_t size,
      const std::unordered_map<size_t, size_t>& tensor_index_to_identifier);

  // In case a constant buffer data needs to be moved for some reason, this will
  // map the new buffer data to its identifier.
  void RemapDataBuffer(const void* buffer, const void* new_buffer);

  // Returns the offset of the buffer identified by `cache_key`.
  //
  // If the buffer isn't found, returns SIZE_MAX.
  [[nodiscard]]
  size_t LookUp(const xnn_weights_cache_look_up_key* cache_key);

  // Reserves space for a buffer of given size and returns a pointer to it.
  //
  // The buffer data should be filled and `LookUpOrInsert` should be immediately
  // called.
  [[nodiscard]]
  void* ReserveSpace(size_t size);

  // Returns the offset of the buffer identified by `cache_key`. If the lookup
  // fails, inserts the span `[ptr, ptr+size)`.
  //
  // This should be called after ReserveSpace and `ptr` should be the result of
  // that call with the given `size`.
  //
  // WARNING: The cache key cannot be null.
  [[nodiscard]]
  size_t LookUpOrInsert(const xnn_weights_cache_look_up_key* cache_key,
                        void* ptr, size_t size);

  // Gets the pointer to the buffer at the given offset.
  //
  // WARNING: This requires the buffer to be finalized.
  // WARNING: This does not check the validity of the passed offset.
  void* OffsetToAddr(size_t offset);

  // Releases the weight cache's memory.
  void Release();

  // Returns true if the underlying builder is ready to add weights to the
  // cache.
  [[nodiscard]]
  bool IsBuilding() const {
    return builder_.IsBuilding();
  }

  // Returns true if at least one cache file is mapped or the cache builder has
  // been started.
  [[nodiscard]]
  bool IsActive() const {
    return !mmap_handles_.empty() || builder_.IsStarted();
  }

  // Returns the cache provider expected by XNNPack.
  xnn_weights_cache_provider& GetCacheProvider() { return cache_provider_; }

  // C interface: `xnn_weights_cache_provider` callback.
  static size_t look_up(void* context,
                        const xnn_weights_cache_look_up_key* cache_key);

  // C interface: `xnn_weights_cache_provider` callback.
  static void* reserve_space(void* context, size_t n);

  // C interface: `xnn_weights_cache_provider` callback.
  static size_t look_up_or_insert(
      void* context, const xnn_weights_cache_look_up_key* cache_key, void* ptr,
      size_t size);

  // C interface: `xnn_weights_cache_provider` callback.
  static bool is_finalized(void* context);

  // C interface: `xnn_weights_cache_provider` callback.
  static void* offset_to_addr(void* context, size_t offset);

  // C interface: `xnn_weights_cache_provider` callback.
  static enum xnn_status delete_cache(void* context);

  // C interface: `xnn_weights_cache_provider` callback.
  static enum xnn_status alias_data(void* context, void* alias, void* original);

  // Checks if cache misses have happened and updates the cache file stale
  // flag.
  bool WriteCacheMissFlag();

 private:
  // Builds the key used to look up `cache_key_to_offset_`.
  PackIdentifier BuildPackIdentifier(const xnn_weights_cache_look_up_key& key);

  // Loads the data written by the last call to `builder_.StopBuildStep()`.
  [[nodiscard /*Loading cache data may fail.*/]]
  bool LoadLastBuildStep();

  // Cache provider implementation for XNNPack.
  xnn_weights_cache_provider cache_provider_{
      /*context=*/this,
      /*look_up=*/MMapWeightCacheProvider::look_up,
      /*reserve_space=*/MMapWeightCacheProvider::reserve_space,
      /*look_up_or_insert=*/MMapWeightCacheProvider::look_up_or_insert,
      /*is_finalized=*/MMapWeightCacheProvider::is_finalized,
      /*offset_to_addr=*/MMapWeightCacheProvider::offset_to_addr,
      /*delete_cache=*/MMapWeightCacheProvider::delete_cache,
      /*alias_data=*/MMapWeightCacheProvider::alias_data,
  };

  // Path to the cache file.
  std::string file_path_;

  // Maps buffer addresses to buffer identifiers.
  std::unordered_map<const void*, uint64_t> buffer_address_to_identifier_;

  // Maps one buffer address to another.
  // NOTE(review): presumably filled by `RemapDataBuffer`; confirm the
  // direction of the mapping in the implementation.
  std::unordered_map<const void*, const void*> buffer_remaps_;

  // Maps pack identifiers to the locations of the corresponding packed
  // buffers.
  std::unordered_multimap<PackIdentifier, BufferLocation, PackIdentifier::Hash>
      cache_key_to_offset_;

  // MMap allocation handler.
  std::vector<MMapHandle> mmap_handles_;

  // The offset to the first buffer data in the MMap allocation.
  // Zero-initialized so a default-constructed (or moved-from-before-load)
  // provider never holds an indeterminate value.
  size_t mmap_buffer_base_offset_ = 0;

  // Holds a file descriptor to the cache file.
  FileDescriptor file_descriptor_;

  // Used to build the cache.
  WeightCacheBuilder builder_;

  // Handles cache misses when not building the cache file.
  CacheMissHandler cache_miss_handler_;

  // Stores the loaded buffer addresses corresponding to the given offset in the
  // cache file.
  std::map<size_t, void*> offset_to_addr_;
};
} // namespace xnnpack
} // namespace tflite
#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_WEIGHT_CACHE_H_