Skip to content

Commit f63dab1

Browse files
add FLAGS_use_accuracy_compatible_kernel
1 parent b042d81 commit f63dab1

File tree

2 files changed

+54
-26
lines changed

2 files changed

+54
-26
lines changed

paddle/phi/kernels/cpu/randperm_kernel.cc

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@
1818
#include <cstdint>
1919
#include <limits>
2020

21+
#include "paddle/common/flags.h"
2122
#include "paddle/phi/core/kernel_registry.h"
2223

24+
COMMON_DECLARE_bool(use_accuracy_compatible_kernel);
25+
2326
namespace phi {
2427

2528
// ---------------------------------------------------------------------------
@@ -101,36 +104,48 @@ void RandpermKernel(const Context& dev_ctx,
101104
DenseTensor* out) {
102105
T* out_data = dev_ctx.template Alloc<T>(out);
103106

104-
// MT19937 engine with that seed so the random sequence is identical.
105-
uint64_t seed = dev_ctx.GetGenerator()->GetCurrentSeed();
106-
TorchMT19937Engine engine(seed);
107+
if (FLAGS_use_accuracy_compatible_kernel) {
108+
// Seed a TorchMT19937Engine from the generator's current seed so the random
// sequence is identical (presumably to PyTorch's for the same seed - confirm).
109+
uint64_t seed = dev_ctx.GetGenerator()->GetCurrentSeed();
110+
TorchMT19937Engine engine(seed);
111+
112+
if (n < static_cast<int>(std::numeric_limits<uint32_t>::max() / 20)) {
113+
// For small n: classic Fisher-Yates shuffle using 32-bit random values
114+
for (int i = 0; i < n; ++i) {
115+
out_data[i] = static_cast<T>(i);
116+
}
117+
for (int i = 0; i < n - 1; i++) {
118+
int64_t z = engine() % (n - i);
119+
T save = out_data[i];
120+
out_data[i] = out_data[z + i];
121+
out_data[z + i] = save;
122+
}
123+
} else {
124+
// For large n: inside-out Fisher-Yates using 64-bit random values
125+
for (int i = 0; i < n; i++) {
126+
int64_t z = static_cast<int64_t>(engine.random64() % (i + 1));
127+
out_data[i] = out_data[z];
128+
out_data[z] = static_cast<T>(i);
129+
}
130+
}
107131

108-
if (n < static_cast<int>(std::numeric_limits<uint32_t>::max() / 20)) {
109-
// For small n: classic Fisher-Yates shuffle using 32-bit random values
132+
// Advance the generator state so that successive randperm calls within the
133+
// same run produce different results
134+
dev_ctx.GetGenerator()->SetCurrentSeed(engine());
135+
} else {
136+
int seed = 0;
137+
std::shared_ptr<std::mt19937_64> engine;
138+
if (seed) {
139+
engine = std::make_shared<std::mt19937_64>();
140+
engine->seed(seed);
141+
} else {
142+
engine = dev_ctx.GetGenerator()->GetCPUEngine();
143+
}
110144
for (int i = 0; i < n; ++i) {
111145
out_data[i] = static_cast<T>(i);
112146
}
113-
for (int i = 0; i < n - 1; i++) {
114-
int64_t z = engine() % (n - i);
115-
T save = out_data[i];
116-
out_data[i] = out_data[z + i];
117-
out_data[z + i] = save;
118-
}
119-
} else {
120-
// For large n: inside-out Fisher-Yates using 64-bit random values
121-
for (int i = 0; i < n; i++) {
122-
int64_t z = static_cast<int64_t>(engine.random64() % (i + 1));
123-
out_data[i] = out_data[z];
124-
out_data[z] = static_cast<T>(i);
125-
}
147+
std::shuffle(out_data, out_data + n, *engine);
126148
}
127-
128-
// Advance the generator state so that successive randperm calls within the
129-
// same run produce different results (mirrors torch's stateful generator
130-
// behaviour: torch's CPUGeneratorImpl advances its internal MT19937 engine
131-
// on every random()/random64() call, so consecutive ops see different
132-
// states).
133-
dev_ctx.GetGenerator()->SetCurrentSeed(engine());
134149
}
135150

136151
} // namespace phi

test/legacy_test/test_randperm_op.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,14 +535,27 @@ def test_pin_memory_error_cases(self):
535535
paddle.randperm([2, 3], device=paddle.CPUPlace(), pin_memory=True)
536536

537537

538-
class TestRandpermLargeN(unittest.TestCase):
538+
class TestRandperm_compatible(unittest.TestCase):
539539
"""Test CPU randperm with FLAGS_use_accuracy_compatible_kernel enabled.

The large-n case covers the inside-out Fisher-Yates
540540
path using 64-bit random values in CPU randperm_kernel.cc.
541541
The threshold is uint32_max / 20 = 214748364, so n >= 214748364
542542
triggers the large-n branch.
543543
"""
544544

545+
def test_small_n_cpu(self):
    """Check that a small-n CPU randperm is a valid permutation of 0..n-1.

    Runs with FLAGS_use_accuracy_compatible_kernel enabled. The flag is
    process-wide, so its previous value is restored after the test to
    avoid changing the kernel path used by unrelated tests.
    """
    flag = 'FLAGS_use_accuracy_compatible_kernel'
    # Register the restore before changing the flag so it runs even if
    # the test body fails.
    previous = paddle.get_flags([flag])[flag]
    self.addCleanup(paddle.set_flags, {flag: previous})
    paddle.set_flags({flag: 1})

    n = 10
    with dygraph_guard():
        paddle.set_device("cpu")
        x = paddle.randperm(n, dtype="int32")
        data_np = x.numpy()
        # A permutation of range(n): right length, bounds 0 and n-1,
        # and no repeated values.
        self.assertEqual(data_np.shape, (n,))
        self.assertEqual(data_np.min(), 0)
        self.assertEqual(data_np.max(), n - 1)
        self.assertEqual(len(np.unique(data_np)), n)
556+
545557
def test_large_n_cpu(self):
558+
paddle.set_flags({'FLAGS_use_accuracy_compatible_kernel': 1})
546559
# uint32_max // 20 + 1 = 214748365, just exceeds the threshold
547560
n = 214748365
548561
with dygraph_guard():

0 commit comments

Comments
 (0)