@@ -47,21 +47,21 @@ WindowedKeyValueCache::WindowedKeyValueCache(State& state)
47
47
48
48
for (int i = 0 ; i < layer_count_; ++i) {
49
49
key_caches_in_.push_back (
50
- OrtValue::CreateTensor (*model_. allocator_device_ , key_cache_shape_in_, type_));
50
+ OrtValue::CreateTensor (Allocator () , key_cache_shape_in_, type_));
51
51
std::fill_n (key_caches_in_[i]->GetTensorMutableData <uint8_t >(),
52
52
ElementCountFromShape (key_cache_shape_in_),
53
53
static_cast <uint8_t >(model_.config_ ->model .decoder .sliding_window ->pad_value ));
54
54
55
55
value_caches_in_.push_back (
56
- OrtValue::CreateTensor (*model_. allocator_device_ , value_cache_shape_in_, type_));
56
+ OrtValue::CreateTensor (Allocator () , value_cache_shape_in_, type_));
57
57
std::fill_n (value_caches_in_[i]->GetTensorMutableData <uint8_t >(),
58
58
ElementCountFromShape (value_cache_shape_in_),
59
59
static_cast <uint8_t >(model_.config_ ->model .decoder .sliding_window ->pad_value ));
60
60
61
61
key_caches_out_.push_back (
62
- OrtValue::CreateTensor (*model_. allocator_device_ , key_cache_shape_out_, type_));
62
+ OrtValue::CreateTensor (Allocator () , key_cache_shape_out_, type_));
63
63
value_caches_out_.push_back (
64
- OrtValue::CreateTensor (*model_. allocator_device_ , value_cache_shape_out_, type_));
64
+ OrtValue::CreateTensor (Allocator () , value_cache_shape_out_, type_));
65
65
}
66
66
}
67
67
@@ -187,7 +187,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
187
187
188
188
ThreadPool thread_pool{static_cast <size_t >(layer_count_)};
189
189
thread_pool.Compute ([&](size_t layer_idx) {
190
- std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor (*model_. allocator_device_ , updated_key_cache_shape_in, type_);
190
+ std::unique_ptr<OrtValue> key_cache = OrtValue::CreateTensor (Allocator () , updated_key_cache_shape_in, type_);
191
191
192
192
uint8_t * key_cache_data = key_cache->GetTensorMutableData <uint8_t >();
193
193
uint8_t * key_cache_in_data = key_caches_in_[layer_idx]->GetTensorMutableData <uint8_t >();
@@ -213,9 +213,9 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
213
213
}
214
214
215
215
key_caches_in_[layer_idx] = std::move (key_cache);
216
- key_caches_out_[layer_idx] = OrtValue::CreateTensor (*model_. allocator_device_ , updated_key_cache_shape_out, type_);
216
+ key_caches_out_[layer_idx] = OrtValue::CreateTensor (Allocator () , updated_key_cache_shape_out, type_);
217
217
218
- std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor (*model_. allocator_device_ , updated_value_cache_shape_in, type_);
218
+ std::unique_ptr<OrtValue> value_cache = OrtValue::CreateTensor (Allocator () , updated_value_cache_shape_in, type_);
219
219
220
220
uint8_t * value_cache_data = value_cache->GetTensorMutableData <uint8_t >();
221
221
uint8_t * value_cache_in_data = value_caches_in_[layer_idx]->GetTensorMutableData <uint8_t >();
@@ -241,7 +241,7 @@ void WindowedKeyValueCache::Update(DeviceSpan<int32_t> /* beam_indices */, int c
241
241
}
242
242
243
243
value_caches_in_[layer_idx] = std::move (value_cache);
244
- value_caches_out_[layer_idx] = OrtValue::CreateTensor (*model_. allocator_device_ , updated_value_cache_shape_out, type_);
244
+ value_caches_out_[layer_idx] = OrtValue::CreateTensor (Allocator () , updated_value_cache_shape_out, type_);
245
245
});
246
246
247
247
window_size_ = 1 ;
0 commit comments