44from unittest import TestCase , skipIf
55
66from mqlog import MqttHandler
7+ from tests .utils import AsyncMock , Mock , call
78
89
910class FakeClient :
1011 """A fake MQTT client for testing purposes."""
1112
12- def __init__ (self ):
13- self .calls = []
14-
15- # Store the call like a mock so we can check it later
16- async def publish (self , topic , payload , qos = 0 ):
17- self .calls .append ((topic , payload , qos ))
13+ publish : None
1814
1915
2016class TestMqttHandler (TestCase ):
2117 def setUp (self ):
2218 self .client = FakeClient ()
19+ self .client .publish = AsyncMock ()
2320 self .handler = MqttHandler (self .client , "test_topic" )
2421 self .handler .setFormatter (logging .Formatter ("%(message)s" ))
2522 self .logger = logging .getLogger ("test" )
@@ -50,32 +47,46 @@ def test_flush_full_buffer(self):
5047 self .logger .info (f"Test message { i } " )
5148 self .assertTrue (self .handler .will_flush .is_set ())
5249
53- @skipIf (os .getenv ("CI" ), "Hangs in CI" )
50+ def test_flush_fail (self ):
51+ """Flushing should log an error if publish fails"""
52+ self .handler ._logger .error = Mock ()
53+ self .client .publish = AsyncMock (side_effect = Exception ("Publish failed" ))
54+
55+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
56+ asyncio .create_task (handler .run ())
57+ await asyncio .sleep (0.1 ) # Allow handler to start
58+ logger .error ("Test message" ) # should trigger flush
59+ await asyncio .sleep (0.1 )
60+
61+ asyncio .run (do_test (self .handler , self .logger ))
62+
63+ self .handler ._logger .error .assert_called ()
64+
5465 def test_publish (self ):
5566 """Flushing should publish messages to MQTT topic"""
5667 self .handler .flush_level = logging .ERROR
5768
58- async def do_test (logger ):
69+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
70+ asyncio .create_task (handler .run ())
71+ await asyncio .sleep (0.1 )
5972 logger .info ("Test message 1" )
6073 await asyncio .sleep (0.1 )
6174 logger .error ("Test message 2" )
6275 await asyncio .sleep (0.1 )
6376
64- async def main (handler , logger ):
65- await asyncio .gather (handler .run (), do_test (logger ))
66-
67- asyncio .run (main (self .handler , self .logger ))
77+ asyncio .run (do_test (self .handler , self .logger ))
6878
69- self .assertEqual (
70- self . client . calls , [( "test_topic" , "Test message 1\n Test message 2" , 0 )]
79+ self .client . publish . assert_called_with (
80+ "test_topic" , "Test message 1\n Test message 2" , qos = 0
7181 )
7282
73- @skipIf (os .getenv ("CI" ), "Hangs in CI" )
7483 def test_flush_multiple (self ):
7584 """Flushing multiple times should publish separate messages"""
7685 self .handler .capacity = 2
7786
78- async def do_test (logger ):
87+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
88+ asyncio .create_task (handler .run ())
89+ await asyncio .sleep (0.1 )
7990 logger .info ("Test message 1" )
8091 await asyncio .sleep (0.1 )
8192 logger .info ("Test message 2" )
@@ -85,15 +96,11 @@ async def do_test(logger):
8596 logger .info ("Test message 4" )
8697 await asyncio .sleep (0.1 )
8798
88- async def main (handler , logger ):
89- await asyncio .gather (handler .run (), do_test (logger ))
90-
91- asyncio .run (main (self .handler , self .logger ))
99+ asyncio .run (do_test (self .handler , self .logger ))
92100
93- self .assertEqual (
94- self .client .calls ,
101+ self .client .publish .assert_has_calls (
95102 [
96- ("test_topic" , "Test message 1\n Test message 2" , 0 ),
97- ("test_topic" , "Test message 3\n Test message 4" , 0 ),
98- ],
103+ call ("test_topic" , "Test message 1\n Test message 2" , qos = 0 ),
104+ call ("test_topic" , "Test message 3\n Test message 4" , qos = 0 ),
105+ ]
99106 )
0 commit comments