Skip to content

Commit 991b01a

Browse files
committed
Change how this is done; just check input value
1 parent f598888 commit 991b01a

2 files changed

Lines changed: 11 additions & 6 deletions

File tree

tensorflow_quantum/core/ops/noise/tfq_noisy_samples.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include <stdlib.h>
1717

18+
#include <limits>
1819
#include <string>
1920

2021
#include "../qsim/lib/channel.h"
@@ -78,6 +79,12 @@ class TfqNoisySamplesOp : public tensorflow::OpKernel {
7879

7980
int num_samples = 0;
8081
OP_REQUIRES_OK(context, GetIndividualSample(context, &num_samples));
82+
OP_REQUIRES(
83+
context,
84+
num_samples >= 0 && num_samples < std::numeric_limits<int>::max(),
85+
tensorflow::errors::InvalidArgument(
86+
absl::StrCat("num_samples must be between 0 and ",
87+
std::numeric_limits<int>::max(), ".")));
8188

8289
// Construct qsim circuits.
8390
std::vector<NoisyQsimCircuit> qsim_circuits(programs.size(),
@@ -228,8 +235,7 @@ class TfqNoisySamplesOp : public tensorflow::OpKernel {
228235
num_threads, std::vector<long>(output_dim_batch_size, 0));
229236

230237
for (int i = 0; i < output_dim_batch_size; i++) {
231-
uint64_t p_reps =
232-
(static_cast<uint64_t>(num_samples) + num_threads - 1) / num_threads;
238+
int p_reps = (num_samples + num_threads - 1) / num_threads;
233239
offset_prefix_sum[0][i] = rep_offsets[0][i] + p_reps;
234240
for (int j = 1; j < num_threads; j++) {
235241
offset_prefix_sum[j][i] += offset_prefix_sum[j - 1][i];

tensorflow_quantum/core/src/util_balance_trajectory.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,9 @@ void BalanceTrajectory(const int& num_samples, const int& num_threads,
7575

7676
int prev_max_height = -1;
7777
for (size_t j = 0; j < (*thread_offsets)[0].size(); j++) {
78-
uint64_t run_ceiling =
79-
((static_cast<uint64_t>(num_samples) + num_threads - 1) / num_threads);
80-
uint64_t num_lo = num_threads * run_ceiling - num_samples;
81-
uint64_t num_hi = num_threads - num_lo;
78+
int run_ceiling = ((num_samples + num_threads - 1) / num_threads);
79+
int num_lo = num_threads * run_ceiling - num_samples;
80+
int num_hi = num_threads - num_lo;
8281
int cur_max = prev_max_height;
8382
for (int i = 0; i < num_threads; i++) {
8483
if (height[i] == cur_max && num_lo) {

0 commit comments

Comments
 (0)