Skip to content

Commit 68a6ea7

Browse files
committed
Fix merge conflicts
1 parent 4bcfa33 commit 68a6ea7

File tree

4 files changed

+12
-10
lines changed

4 files changed

+12
-10
lines changed

src/beam_search_scorer.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void BeamHypotheses::Add(std::span<int32_t> hypothesis, float sum_logprobs) {
2727
return;
2828
}
2929
} else {
30-
beams_used_++;
30+
beams_used_++;
3131
}
3232

3333
// Rotate existing elements over while the new element scores higher

src/generators.h

-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
#include "leakcheck.h"
2828
#include "make_string.h"
29-
#include "smartptrs.h"
3029
#include "models/onnxruntime_api.h"
3130
#include "smartptrs.h"
3231
#include "models/debugging.h"

src/models/windowed_kv_cache.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,21 @@ WindowedKeyValueCache::WindowedKeyValueCache(State& state)
4747

4848
for (int i = 0; i < layer_count_; ++i) {
4949
key_caches_in_.push_back(
50-
OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_in_, type_));
50+
OrtValue::CreateTensor(Allocator(), key_cache_shape_in_, type_));
5151
std::fill_n(key_caches_in_[i]->GetTensorMutableData<uint8_t>(),
5252
ElementCountFromShape(key_cache_shape_in_),
5353
static_cast<uint8_t>(model_.config_->model.decoder.sliding_window->pad_value));
5454

5555
value_caches_in_.push_back(
56-
OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_in_, type_));
56+
OrtValue::CreateTensor(Allocator(), value_cache_shape_in_, type_));
5757
std::fill_n(value_caches_in_[i]->GetTensorMutableData<uint8_t>(),
5858
ElementCountFromShape(value_cache_shape_in_),
5959
static_cast<uint8_t>(model_.config_->model.decoder.sliding_window->pad_value));
6060

6161
key_caches_out_.push_back(
62-
OrtValue::CreateTensor(*model_.allocator_device_, key_cache_shape_out_, type_));
62+
OrtValue::CreateTensor(Allocator(), key_cache_shape_out_, type_));
6363
value_caches_out_.push_back(
64-
OrtValue::CreateTensor(*model_.allocator_device_, value_cache_shape_out_, type_));
64+
OrtValue::CreateTensor(Allocator(), value_cache_shape_out_, type_));
6565
}
6666
}
6767

@@ -187,7 +187,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
187187

188188
ThreadPool thread_pool{static_cast<size_t>(layer_count_)};
189189
thread_pool.Compute([&](size_t layer_idx) {
190-
std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_in, type_);
190+
std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_in, type_);
191191

192192
uint8_t* key_cache_data = key_cache->GetTensorMutableData<uint8_t>();
193193
uint8_t* key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData<uint8_t>();
@@ -213,9 +213,9 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
213213
}
214214

215215
key_caches_in_[layer_idx] = std::move(key_cache);
216-
key_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_key_cache_shape_out, type_);
216+
key_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_key_cache_shape_out, type_);
217217

218-
std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_in, type_);
218+
std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_in, type_);
219219

220220
uint8_t* value_cache_data = value_cache->GetTensorMutableData<uint8_t>();
221221
uint8_t* value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData<uint8_t>();
@@ -241,7 +241,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
241241
}
242242

243243
value_caches_in_[layer_idx] = std::move(value_cache);
244-
value_caches_out_[layer_idx] = OrtValue::CreateTensor(*model_.allocator_device_, updated_value_cache_shape_out, type_);
244+
value_caches_out_[layer_idx] = OrtValue::CreateTensor(Allocator(), updated_value_cache_shape_out, type_);
245245
});
246246

247247
window_size_ = 1;

src/models/windowed_kv_cache.h

+3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ struct WindowedKeyValueCache : KeyValueCache {
3131
void SlideAllLayers();
3232
void SlideLayers(std::span<const size_t> layer_indices);
3333

34+
DeviceInterface& Device() { return *model_.p_device_kvcache_; }
35+
Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
36+
3437
State& state_;
3538
const Model& model_{state_.model_};
3639
int layer_count_{};

0 commit comments

Comments
 (0)