31 | 31 | sys.path.append("../common")
32 | 32 |
33 | 33 | import os
| 34 | +import random |
34 | 35 | import threading
35 | 36 | import time
36 | 37 | import unittest
37 | 38 | from builtins import str
| 39 | +from functools import partial |
38 | 40 |
39 | 41 | import numpy as np
40 | 42 | import sequence_util as su
@@ -3432,5 +3434,185 @@ def test_send_request_after_timeout(self):
3432 | 3434 | raise last_err
3433 | 3435 |
3434 | 3436 |
| 3437 | +class SequenceBatcherPreserveOrderingTest(su.SequenceBatcherTestUtil): |
| 3438 | +    def setUp(self): |
| 3439 | +        super().setUp() |
| 3440 | +        # By default, find tritonserver on "localhost", but can be overridden |
| 3441 | +        # with the TRITONSERVER_IPADDR envvar |
| 3442 | +        self.server_address_ = ( |
| 3443 | +            os.environ.get("TRITONSERVER_IPADDR", "localhost") + ":8001" |
| 3444 | +        ) |
| 3445 | + |
| 3446 | +        # Prepare the input and expected output based on the model and the |
| 3447 | +        # inference sequence sent for testing. If the test is extended to other |
| 3448 | +        # sequences or models, proper grouping should be added. |
| 3449 | +        self.model_name_ = "sequence_py" |
| 3450 | +        self.tensor_data_ = np.ones(shape=[1, 1], dtype=np.int32) |
| 3451 | +        self.inputs_ = [grpcclient.InferInput("INPUT0", [1, 1], "INT32")] |
| 3452 | +        self.inputs_[0].set_data_from_numpy(self.tensor_data_) |
| 3453 | +        self.triton_client = grpcclient.InferenceServerClient(self.server_address_) |
| 3454 | + |
| 3455 | +        # Monotonically increasing request ID, lock-protected for multi-threaded inference |
| 3456 | +        self.request_id_lock = threading.Lock() |
| 3457 | +        self.request_id = 1 |
| 3458 | + |
| 3459 | +    def send_sequence(self, seq_id, seq_id_map, req_id_map): |
| 3460 | +        if seq_id not in seq_id_map: |
| 3461 | +            seq_id_map[seq_id] = [] |
| 3462 | + |
| 3463 | +        start, middle, end = (True, False), (False, False), (False, True) |
| 3464 | +        # Send the sequence as one start, one middle, and one end request |
| 3465 | +        seq_flags = [start, middle, end] |
| 3466 | +        for start_flag, end_flag in seq_flags: |
| 3467 | +            # Sleep for a random interval to better interleave requests from different sequences |
| 3468 | +            time.sleep(random.uniform(0.0, 1.0)) |
| 3469 | + |
| 3470 | +            # Serialize sending requests so request IDs are assigned in send order |
| 3471 | +            with self.request_id_lock: |
| 3472 | +                req_id = self.request_id |
| 3473 | +                self.request_id += 1 |
| 3474 | + |
| 3475 | +                # Store metadata to validate results later |
| 3476 | +                req_id_map[req_id] = seq_id |
| 3477 | +                seq_id_map[seq_id].append(req_id) |
| 3478 | + |
| 3479 | +                self.triton_client.async_stream_infer( |
| 3480 | +                    self.model_name_, |
| 3481 | +                    self.inputs_, |
| 3482 | +                    sequence_id=seq_id, |
| 3483 | +                    sequence_start=start_flag, |
| 3484 | +                    sequence_end=end_flag, |
| 3485 | +                    timeout=None, |
| 3486 | +                    request_id=str(req_id), |
| 3487 | +                ) |
| 3488 | + |
| 3489 | +    def _test_sequence_ordering(self, preserve_ordering, decoupled): |
| 3490 | +        # 1. Send a few grpc streaming sequence requests to the model. |
| 3491 | +        # 2. With grpc streaming, the model should receive the requests in the |
| 3492 | +        # same order they are sent from the client, and the client should |
| 3493 | +        # receive the responses in the same order they are sent back by the |
| 3494 | +        # model/server. With the sequence scheduler, all requests in a sequence |
| 3495 | +        # should be routed to the same model instance, and no two requests from |
| 3496 | +        # the same sequence should be batched together. |
| 3497 | +        # 3. With preserve_ordering=False, responses may come back in a different |
| 3498 | +        # order than the requests, but with grpc streaming the responses within each sequence should still be ordered. |
| 3499 | +        # 4. Assert that the sequence values are ordered and that the response IDs per sequence are ordered. |
| 3500 | +        class SequenceResult: |
| 3501 | +            def __init__(self, seq_id, result, request_id): |
| 3502 | +                self.seq_id = seq_id |
| 3503 | +                self.result = result |
| 3504 | +                self.request_id = int(request_id) |
| 3505 | + |
| 3506 | +        def full_callback(sequence_dict, sequence_list, result, error): |
| 3507 | +            # We expect no model errors for this test |
| 3508 | +            if error: |
| 3509 | +                self.assertTrue(False, error) |
| 3510 | + |
| 3511 | +            # Gather all the necessary metadata for validation |
| 3512 | +            request_id = int(result.get_response().id) |
| 3513 | +            sequence_id = request_id_map[request_id] |
| 3514 | +            # Overall list of results in the order received, regardless of sequence ID |
| 3515 | +            sequence_list.append(SequenceResult(sequence_id, result, request_id)) |
| 3516 | +            # Results grouped by sequence ID, in the order received |
| 3517 | +            sequence_dict[sequence_id].append(result) |
| 3518 | + |
| 3519 | +        # Ordered list of responses in the order the client receives them |
| 3520 | +        sequence_list = [] |
| 3521 | +        # Mapping of sequence ID to response results |
| 3522 | +        sequence_dict = {} |
| 3523 | +        # Mappings of sequence ID to request IDs and vice versa |
| 3524 | +        sequence_id_map = {} |
| 3525 | +        request_id_map = {} |
| 3526 | + |
| 3527 | +        # Start the stream |
| 3528 | +        seq_callback = partial(full_callback, sequence_dict, sequence_list) |
| 3529 | +        self.triton_client.start_stream(callback=seq_callback) |
| 3530 | + |
| 3531 | +        # Send N sequences concurrently |
| 3532 | +        threads = [] |
| 3533 | +        num_sequences = 10 |
| 3534 | +        for i in range(num_sequences): |
| 3535 | +            # Sequence IDs are 1-indexed |
| 3536 | +            sequence_id = i + 1 |
| 3537 | +            # Add a result list for each sequence |
| 3538 | +            sequence_dict[sequence_id] = [] |
| 3539 | +            threads.append( |
| 3540 | +                threading.Thread( |
| 3541 | +                    target=self.send_sequence, |
| 3542 | +                    args=(sequence_id, sequence_id_map, request_id_map), |
| 3543 | +                ) |
| 3544 | +            ) |
| 3545 | + |
| 3546 | +        # Start all sequence threads |
| 3547 | +        for t in threads: |
| 3548 | +            t.start() |
| 3549 | + |
| 3550 | +        # Wait for all threads to finish sending |
| 3551 | +        for t in threads: |
| 3552 | +            t.join() |
| 3553 | + |
| 3554 | +        # Block until all requests are completed |
| 3555 | +        self.triton_client.stop_stream() |
| 3556 | + |
| 3557 | +        # Make sure some inferences occurred and metadata was collected |
| 3558 | +        self.assertGreater(len(sequence_dict), 0) |
| 3559 | +        self.assertGreater(len(sequence_list), 0) |
| 3560 | + |
| 3561 | +        # Validate that model outputs are sorted within each sequence (model-specific logic) |
| 3562 | +        print(f"=== {preserve_ordering=} {decoupled=} ===") |
| 3563 | +        print("Outputs per Sequence:") |
| 3564 | +        for seq_id, sequence in sequence_dict.items(): |
| 3565 | +            seq_outputs = [ |
| 3566 | +                result.as_numpy("OUTPUT0").flatten().tolist() for result in sequence |
| 3567 | +            ] |
| 3568 | +            print(f"{seq_id}: {seq_outputs}") |
| 3569 | +            self.assertEqual(seq_outputs, sorted(seq_outputs)) |
| 3570 | + |
| 3571 | +        # Validate that the request/response IDs within each sequence are sorted. |
| 3572 | +        # This should hold regardless of the preserve_ordering setting. |
| 3573 | +        print("Request IDs per Sequence:") |
| 3574 | +        for seq_id in sequence_id_map: |
| 3575 | +            per_seq_request_ids = sequence_id_map[seq_id] |
| 3576 | +            print(f"{seq_id}: {per_seq_request_ids}") |
| 3577 | +            self.assertEqual(per_seq_request_ids, sorted(per_seq_request_ids)) |
| 3578 | + |
| 3579 | +        # Validate that results arrive in request order when preserve_ordering is True |
| 3580 | +        if preserve_ordering: |
| 3581 | +            request_ids = [s.request_id for s in sequence_list] |
| 3582 | +            print(f"Request IDs overall:\n{request_ids}") |
| 3583 | +            sequence_ids = [s.seq_id for s in sequence_list] |
| 3584 | +            print(f"Sequence IDs overall:\n{sequence_ids}") |
| 3585 | +            self.assertEqual(request_ids, sorted(request_ids)) |
| 3586 | + |
| 3587 | +        # Assert that some dynamic batching of requests occurred |
| 3588 | +        stats = self.triton_client.get_inference_statistics( |
| 3589 | +            model_name=self.model_name_, headers={}, as_json=True |
| 3590 | +        ) |
| 3591 | +        model_stats = stats["model_stats"][0] |
| 3592 | +        self.assertEqual(model_stats["name"], self.model_name_) |
| 3593 | +        self.assertLess( |
| 3594 | +            int(model_stats["execution_count"]), int(model_stats["inference_count"]) |
| 3595 | +        ) |
| 3596 | + |
| 3597 | +    def test_sequence_with_preserve_ordering(self): |
| 3598 | +        self.model_name_ = "seqpy_preserve_ordering_nondecoupled" |
| 3599 | +        self._test_sequence_ordering(preserve_ordering=True, decoupled=False) |
| 3600 | + |
| 3601 | +    def test_sequence_without_preserve_ordering(self): |
| 3602 | +        self.model_name_ = "seqpy_no_preserve_ordering_nondecoupled" |
| 3603 | +        self._test_sequence_ordering(preserve_ordering=False, decoupled=False) |
| 3604 | + |
| 3605 | +    # FIXME [DLIS-5280]: This may fail for decoupled models if writes to the GRPC |
| 3606 | +    # stream are done out of order in the server, so the tests are disabled for now. |
| 3607 | +    # def test_sequence_with_preserve_ordering_decoupled(self): |
| 3608 | +    #     self.model_name_ = "seqpy_preserve_ordering_decoupled" |
| 3609 | +    #     self._test_sequence_ordering(preserve_ordering=True, decoupled=True) |
| 3610 | + |
| 3611 | +    # FIXME [DLIS-5280] |
| 3612 | +    # def test_sequence_without_preserve_ordering_decoupled(self): |
| 3613 | +    #     self.model_name_ = "seqpy_no_preserve_ordering_decoupled" |
| 3614 | +    #     self._test_sequence_ordering(preserve_ordering=False, decoupled=True) |
| 3615 | + |
| 3616 | + |
3435 | 3617 | if __name__ == "__main__":
3436 | 3618 |     unittest.main()
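
Note on the models under test: the per-sequence `sorted()` assertions above assume that the `sequence_py` family of models returns an OUTPUT0 value that grows with each request in a sequence, and that the `preserve_ordering` / `no_preserve_ordering` variants differ only in the batcher's `preserve_ordering` setting in their model config. The sketch below is a hypothetical illustration of a Triton Python backend model with that accumulating behavior; it is not the model shipped with this test, and its names and details are assumptions.

```python
# model.py -- illustrative sketch only (assumed behavior, not the real sequence_py model).
# Keeps a running per-sequence accumulator keyed by correlation ID so that OUTPUT0
# is strictly increasing within a sequence, which is what the per-sequence
# sorted() assertions in the test rely on.
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        # Running total per correlation (sequence) ID
        self.accumulators = {}

    def execute(self, requests):
        responses = []
        for request in requests:
            corr_id = request.correlation_id()
            value = pb_utils.get_input_tensor_by_name(request, "INPUT0").as_numpy()
            # Add this request's input to the sequence's running total so each
            # later request in the sequence yields a larger output value.
            self.accumulators[corr_id] = self.accumulators.get(corr_id, 0) + int(
                value.sum()
            )
            out = np.full((1, 1), self.accumulators[corr_id], dtype=np.int32)
            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[pb_utils.Tensor("OUTPUT0", out)]
                )
            )
        return responses
```

With `tensor_data_` set to ones, a model like this would return 1, 2, 3 for the three requests of each sequence, so the per-sequence ordering checks pass no matter how responses from different sequences interleave; only the overall request-ID check depends on `preserve_ordering`.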