@@ -1870,5 +1870,188 @@ def test_model_status_provides_valid_garbage_collection(self):
     self.assertEqual(0, len(tags))
 
 
+def _always_retry(e: Exception) -> bool:
+  return True
+
+
+class FakeRemoteModelHandler(base.RemoteModelHandler[int, int, FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    # Assumed default of False; the original assignment referenced an
+    # undefined 'multi_process_shared' name.
+    self._multi_process_shared = kwargs.get('multi_process_shared', False)
+    super().__init__(
+        namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    responses = []
+    for example in batch:
+      responses.append(model.predict(example))
+    return responses
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class FakeAlwaysFailsRemoteModelHandler(base.RemoteModelHandler[int,
+                                                                int,
+                                                                FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    super().__init__(
+        namespace='FakeRemoteModelHandler',
+        retry_filter=retry_filter,
+        num_retries=2,
+        throttle_delay_secs=1)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    raise Exception
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class FakeFailsOnceRemoteModelHandler(base.RemoteModelHandler[int,
+                                                              int,
+                                                              FakeModel]):
+  def __init__(
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      retry_filter=_always_retry,
+      **kwargs):
+    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size
+    self._env_vars = kwargs.get('env_vars', {})
+    self._should_fail = True
+    super().__init__(
+        namespace='FakeRemoteModelHandler',
+        retry_filter=retry_filter,
+        num_retries=2,
+        throttle_delay_secs=1)
+
+  def create_client(self):
+    return FakeModel()
+
+  def request(self, batch, model, inference_args=None) -> Iterable[int]:
+    if self._should_fail:
+      self._should_fail = False
+      raise Exception
+    else:
+      self._should_fail = True
+      responses = []
+      for example in batch:
+        responses.append(model.predict(example))
+      return responses
+
+  def batch_elements_kwargs(self):
+    return {
+        'min_batch_size': self._min_batch_size,
+        'max_batch_size': self._max_batch_size
+    }
+
+
+class RunInferenceRemoteTest(unittest.TestCase):
+  def test_normal_model_execution(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeRemoteModelHandler())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_repeated_requests_fail(self):
+    test_pipeline = TestPipeline()
+    with self.assertRaises(Exception):
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(FakeAlwaysFailsRemoteModelHandler()))
+      test_pipeline.run()
+
+  def test_works_on_retry(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(FakeFailsOnceRemoteModelHandler())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_exception_on_load_model_override(self):
+    with self.assertRaises(Exception):
+
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
+        def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
+          self._fake_clock = clock
+          self._min_batch_size = 1
+          self._max_batch_size = 1
+          self._env_vars = kwargs.get('env_vars', {})
+          super().__init__(
+              namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+        def load_model(self):
+          return FakeModel()
+
+        def request(self, batch, model, inference_args=None) -> Iterable[int]:
+          responses = []
+          for example in batch:
+            responses.append(model.predict(example))
+          return responses
+
+  def test_exception_on_run_inference_override(self):
+    with self.assertRaises(Exception):
+
+      class _(base.RemoteModelHandler[int, int, FakeModel]):
+        def __init__(self, clock=None, retry_filter=_always_retry, **kwargs):
+          self._fake_clock = clock
+          self._min_batch_size = 1
+          self._max_batch_size = 1
+          self._env_vars = kwargs.get('env_vars', {})
+          super().__init__(
+              namespace='FakeRemoteModelHandler', retry_filter=retry_filter)
+
+        def create_client(self):
+          return FakeModel()
+
+        def run_inference(self,
+                          batch,
+                          model,
+                          inference_args=None) -> Iterable[int]:
+          responses = []
+          for example in batch:
+            responses.append(model.predict(example))
+          return responses
+
+
 if __name__ == '__main__':
   unittest.main()