|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
| 18 | +#include <future> |
| 19 | +#include "arrow/acero/concurrent_queue_internal.h" |
18 | 20 | #include "arrow/acero/hash_join_node.h" |
19 | 21 | #include "arrow/acero/schema_util.h" |
20 | 22 | #include "arrow/testing/extension_type.h" |
21 | 23 | #include "arrow/testing/gtest_util.h" |
22 | 24 | #include "arrow/testing/matchers.h" |
23 | | - |
24 | 25 | using testing::Eq; |
25 | 26 |
|
26 | 27 | namespace arrow { |
@@ -184,5 +185,116 @@ TEST(FieldMap, ExtensionTypeHashJoin) { |
184 | 185 | EXPECT_EQ(i.get(0), 0); |
185 | 186 | } |
186 | 187 |
|
| 188 | +template <typename Queue> |
| 189 | +void ConcurrentQueueBasicTest(Queue& queue) { |
| 190 | + ASSERT_TRUE(queue.Empty()); |
| 191 | + queue.Push(1); |
| 192 | + ASSERT_FALSE(queue.Empty()); |
| 193 | + ASSERT_EQ(queue.Pop(), 1); |
| 194 | + ASSERT_TRUE(queue.Empty()); |
| 195 | + |
| 196 | + auto fut_pop = std::async(std::launch::async, [&]() { return queue.WaitAndPop(); }); |
| 197 | + ASSERT_EQ(fut_pop.wait_for(std::chrono::milliseconds(10)), std::future_status::timeout); |
| 198 | + queue.Push(2); |
| 199 | + queue.Push(3); |
| 200 | + queue.Push(4); |
| 201 | + ASSERT_EQ(fut_pop.wait_for(std::chrono::milliseconds(10)), std::future_status::ready); |
| 202 | + ASSERT_EQ(fut_pop.get(), 2); |
| 203 | + fut_pop = std::async(std::launch::async, [&]() { return queue.WaitAndPop(); }); |
| 204 | + ASSERT_EQ(fut_pop.wait_for(std::chrono::milliseconds(10)), std::future_status::ready); |
| 205 | + ASSERT_EQ(fut_pop.get(), 3); |
| 206 | + ASSERT_FALSE(queue.Empty()); |
| 207 | + ASSERT_EQ(queue.TryPop(), std::make_optional(4)); |
| 208 | + ASSERT_EQ(queue.TryPop(), std::nullopt); |
| 209 | + queue.Push(5); |
| 210 | + ASSERT_FALSE(queue.Empty()); |
| 211 | + ASSERT_EQ(queue.Front(), 5); |
| 212 | + ASSERT_FALSE(queue.Empty()); |
| 213 | + queue.Clear(); |
| 214 | + ASSERT_TRUE(queue.Empty()); |
| 215 | +} |
| 216 | + |
| 217 | +TEST(ConcurrentQueue, BasicTest) { |
| 218 | + ConcurrentQueue<int> queue; |
| 219 | + ConcurrentQueueBasicTest(queue); |
| 220 | +} |
| 221 | + |
| 222 | +class BackpressureTestExecNode : public ExecNode { |
| 223 | + public: |
| 224 | + BackpressureTestExecNode() : ExecNode(nullptr, {}, {}, nullptr) {} |
| 225 | + const char* kind_name() const { return "BackpressureTestNode"; } |
| 226 | + Status InputReceived(ExecNode* input, ExecBatch batch) override { |
| 227 | + return Status::NotImplemented("Test only node"); |
| 228 | + } |
| 229 | + Status InputFinished(ExecNode* input, int total_batches) override { |
| 230 | + return Status::NotImplemented("Test only node"); |
| 231 | + } |
| 232 | + Status StartProducing() override { return Status::NotImplemented("Test only node"); } |
| 233 | + |
| 234 | + protected: |
| 235 | + Status StopProducingImpl() override { |
| 236 | + stopped = true; |
| 237 | + return Status::OK(); |
| 238 | + } |
| 239 | + |
| 240 | + public: |
| 241 | + void PauseProducing(ExecNode* output, int32_t counter) override { paused = true; } |
| 242 | + void ResumeProducing(ExecNode* output, int32_t counter) override { paused = false; } |
| 243 | + bool paused{false}; |
| 244 | + bool stopped{false}; |
| 245 | +}; |
| 246 | + |
| 247 | +class TestBackpressureControl : public BackpressureControl { |
| 248 | + public: |
| 249 | + TestBackpressureControl(BackpressureTestExecNode* testNode) : testNode(testNode) {} |
| 250 | + virtual void Pause() { testNode->PauseProducing(nullptr, 0); } |
| 251 | + virtual void Resume() { testNode->ResumeProducing(nullptr, 0); } |
| 252 | + BackpressureTestExecNode* testNode; |
| 253 | +}; |
| 254 | + |
| 255 | +TEST(BackpressureConcurrentQueue, BasicTest) { |
| 256 | + BackpressureTestExecNode dummyNode; |
| 257 | + auto ctrl = std::make_unique<TestBackpressureControl>(&dummyNode); |
| 258 | + ASSERT_OK_AND_ASSIGN(auto handler, |
| 259 | + BackpressureHandler::Make(&dummyNode, 2, 4, std::move(ctrl))); |
| 260 | + BackpressureConcurrentQueue<int> queue(std::move(handler)); |
| 261 | + |
| 262 | + ConcurrentQueueBasicTest(queue); |
| 263 | + ASSERT_FALSE(dummyNode.paused); |
| 264 | + ASSERT_FALSE(dummyNode.stopped); |
| 265 | +} |
| 266 | + |
| 267 | +TEST(BackpressureConcurrentQueue, BackpressureTest) { |
| 268 | + BackpressureTestExecNode dummyNode; |
| 269 | + auto ctrl = std::make_unique<TestBackpressureControl>(&dummyNode); |
| 270 | + ASSERT_OK_AND_ASSIGN(auto handler, |
| 271 | + BackpressureHandler::Make(&dummyNode, 2, 4, std::move(ctrl))); |
| 272 | + BackpressureConcurrentQueue<int> queue(std::move(handler)); |
| 273 | + |
| 274 | + queue.Push(6); |
| 275 | + queue.Push(7); |
| 276 | + queue.Push(8); |
| 277 | + ASSERT_FALSE(dummyNode.paused); |
| 278 | + ASSERT_FALSE(dummyNode.stopped); |
| 279 | + queue.Push(9); |
| 280 | + ASSERT_TRUE(dummyNode.paused); |
| 281 | + ASSERT_FALSE(dummyNode.stopped); |
| 282 | + ASSERT_EQ(queue.Pop(), 6); |
| 283 | + ASSERT_TRUE(dummyNode.paused); |
| 284 | + ASSERT_FALSE(dummyNode.stopped); |
| 285 | + ASSERT_EQ(queue.Pop(), 7); |
| 286 | + ASSERT_FALSE(dummyNode.paused); |
| 287 | + ASSERT_FALSE(dummyNode.stopped); |
| 288 | + queue.Push(10); |
| 289 | + ASSERT_FALSE(dummyNode.paused); |
| 290 | + ASSERT_FALSE(dummyNode.stopped); |
| 291 | + queue.Push(11); |
| 292 | + ASSERT_TRUE(dummyNode.paused); |
| 293 | + ASSERT_FALSE(dummyNode.stopped); |
| 294 | + ASSERT_OK(queue.ForceShutdown()); |
| 295 | + ASSERT_FALSE(dummyNode.paused); |
| 296 | + ASSERT_TRUE(dummyNode.stopped); |
| 297 | +} |
| 298 | + |
187 | 299 | } // namespace acero |
188 | 300 | } // namespace arrow |
0 commit comments