|
18 | 18 | #include <cstdint> |
19 | 19 | #include <limits> |
20 | 20 |
|
| 21 | +#include "paddle/common/flags.h" |
21 | 22 | #include "paddle/phi/core/kernel_registry.h" |
22 | 23 |
|
| 24 | +COMMON_DECLARE_bool(use_accuracy_compatible_kernel); |
| 25 | + |
23 | 26 | namespace phi { |
24 | 27 |
|
25 | 28 | // --------------------------------------------------------------------------- |
@@ -101,36 +104,48 @@ void RandpermKernel(const Context& dev_ctx, |
101 | 104 | DenseTensor* out) { |
102 | 105 | T* out_data = dev_ctx.template Alloc<T>(out); |
103 | 106 |
|
104 | | - // MT19937 engine with that seed so the random sequence is identical. |
105 | | - uint64_t seed = dev_ctx.GetGenerator()->GetCurrentSeed(); |
106 | | - TorchMT19937Engine engine(seed); |
| 107 | + if (FLAGS_use_accuracy_compatible_kernel) { |
| 108 | + // MT19937 engine with that seed so the random sequence is identical. |
| 109 | + uint64_t seed = dev_ctx.GetGenerator()->GetCurrentSeed(); |
| 110 | + TorchMT19937Engine engine(seed); |
| 111 | + |
| 112 | + if (n < static_cast<int>(std::numeric_limits<uint32_t>::max() / 20)) { |
| 113 | + // For small n: classic Fisher-Yates shuffle using 32-bit random values |
| 114 | + for (int i = 0; i < n; ++i) { |
| 115 | + out_data[i] = static_cast<T>(i); |
| 116 | + } |
| 117 | + for (int i = 0; i < n - 1; i++) { |
| 118 | + int64_t z = engine() % (n - i); |
| 119 | + T save = out_data[i]; |
| 120 | + out_data[i] = out_data[z + i]; |
| 121 | + out_data[z + i] = save; |
| 122 | + } |
| 123 | + } else { |
| 124 | + // For large n: inside-out Fisher-Yates using 64-bit random values |
| 125 | + for (int i = 0; i < n; i++) { |
| 126 | + int64_t z = static_cast<int64_t>(engine.random64() % (i + 1)); |
| 127 | + out_data[i] = out_data[z]; |
| 128 | + out_data[z] = static_cast<T>(i); |
| 129 | + } |
| 130 | + } |
107 | 131 |
|
108 | | - if (n < static_cast<int>(std::numeric_limits<uint32_t>::max() / 20)) { |
109 | | - // For small n: classic Fisher-Yates shuffle using 32-bit random values |
| 132 | + // Advance the generator state so that successive randperm calls within the |
| 133 | + // same run produce different results |
| 134 | + dev_ctx.GetGenerator()->SetCurrentSeed(engine()); |
| 135 | + } else { |
| 136 | + int seed = 0; |
| 137 | + std::shared_ptr<std::mt19937_64> engine; |
| 138 | + if (seed) { |
| 139 | + engine = std::make_shared<std::mt19937_64>(); |
| 140 | + engine->seed(seed); |
| 141 | + } else { |
| 142 | + engine = dev_ctx.GetGenerator()->GetCPUEngine(); |
| 143 | + } |
110 | 144 | for (int i = 0; i < n; ++i) { |
111 | 145 | out_data[i] = static_cast<T>(i); |
112 | 146 | } |
113 | | - for (int i = 0; i < n - 1; i++) { |
114 | | - int64_t z = engine() % (n - i); |
115 | | - T save = out_data[i]; |
116 | | - out_data[i] = out_data[z + i]; |
117 | | - out_data[z + i] = save; |
118 | | - } |
119 | | - } else { |
120 | | - // For large n: inside-out Fisher-Yates using 64-bit random values |
121 | | - for (int i = 0; i < n; i++) { |
122 | | - int64_t z = static_cast<int64_t>(engine.random64() % (i + 1)); |
123 | | - out_data[i] = out_data[z]; |
124 | | - out_data[z] = static_cast<T>(i); |
125 | | - } |
| 147 | + std::shuffle(out_data, out_data + n, *engine); |
126 | 148 | } |
127 | | - |
128 | | - // Advance the generator state so that successive randperm calls within the |
129 | | - // same run produce different results (mirrors torch's stateful generator |
130 | | - // behaviour: torch's CPUGeneratorImpl advances its internal MT19937 engine |
131 | | - // on every random()/random64() call, so consecutive ops see different |
132 | | - // states). |
133 | | - dev_ctx.GetGenerator()->SetCurrentSeed(engine()); |
134 | 149 | } |
135 | 150 |
|
136 | 151 | } // namespace phi |
|
0 commit comments