44from unittest import TestCase , skipIf
55
66from mqlog import MqttHandler
7+ from tests .utils import AsyncMock , Mock , call
78
89
910class FakeClient :
1011 """A fake MQTT client for testing purposes."""
1112
12- def __init__ (self ):
13- self .calls = []
14-
15- # Store the call like a mock so we can check it later
16- async def publish (self , topic , payload , qos = 0 ):
17- self .calls .append ((topic , payload , qos ))
13+ publish : None
1814
1915
2016class TestMqttHandler (TestCase ):
2117 def setUp (self ):
2218 self .client = FakeClient ()
19+ self .client .publish = AsyncMock ()
2320 self .handler = MqttHandler (self .client , "test_topic" )
2421 self .handler .setFormatter (logging .Formatter ("%(message)s" ))
2522 self .logger = logging .getLogger ("test" )
@@ -50,32 +47,48 @@ def test_flush_full_buffer(self):
5047 self .logger .info (f"Test message { i } " )
5148 self .assertTrue (self .handler .will_flush .is_set ())
5249
50+ def test_flush_fail (self ):
51+ """Flushing should log an error if publish fails"""
52+ self .handler ._logger .error = Mock ()
53+ self .client .publish = AsyncMock (side_effect = Exception ("Publish failed" ))
54+
55+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
56+ asyncio .create_task (handler .run ())
57+ await asyncio .sleep (0.1 ) # Allow handler to start
58+ logger .error ("Test message" ) # should trigger flush
59+ await asyncio .sleep (0.1 )
60+
61+ asyncio .run (do_test (self .handler , self .logger ))
62+
63+ self .handler ._logger .error .assert_called ()
64+
5365 @skipIf (os .getenv ("CI" ), "Hangs in CI" )
5466 def test_publish (self ):
5567 """Flushing should publish messages to MQTT topic"""
5668 self .handler .flush_level = logging .ERROR
5769
58- async def do_test (logger ):
70+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
71+ asyncio .create_task (handler .run ())
72+ await asyncio .sleep (0.1 )
5973 logger .info ("Test message 1" )
6074 await asyncio .sleep (0.1 )
6175 logger .error ("Test message 2" )
6276 await asyncio .sleep (0.1 )
6377
64- async def main (handler , logger ):
65- await asyncio .gather (handler .run (), do_test (logger ))
66-
67- asyncio .run (main (self .handler , self .logger ))
78+ asyncio .run (do_test (self .handler , self .logger ))
6879
69- self .assertEqual (
70- self . client . calls , [( "test_topic" , "Test message 1\n Test message 2" , 0 )]
80+ self .client . publish . assert_called_with (
81+ "test_topic" , "Test message 1\n Test message 2" , qos = 0
7182 )
7283
7384 @skipIf (os .getenv ("CI" ), "Hangs in CI" )
7485 def test_flush_multiple (self ):
7586 """Flushing multiple times should publish separate messages"""
7687 self .handler .capacity = 2
7788
78- async def do_test (logger ):
89+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
90+ asyncio .create_task (handler .run ())
91+ await asyncio .sleep (0.1 )
7992 logger .info ("Test message 1" )
8093 await asyncio .sleep (0.1 )
8194 logger .info ("Test message 2" )
@@ -85,15 +98,11 @@ async def do_test(logger):
8598 logger .info ("Test message 4" )
8699 await asyncio .sleep (0.1 )
87100
88- async def main (handler , logger ):
89- await asyncio .gather (handler .run (), do_test (logger ))
90-
91- asyncio .run (main (self .handler , self .logger ))
101+ asyncio .run (do_test (self .handler , self .logger ))
92102
93- self .assertEqual (
94- self .client .calls ,
103+ self .client .publish .assert_has_calls (
95104 [
96- ("test_topic" , "Test message 1\n Test message 2" , 0 ),
97- ("test_topic" , "Test message 3\n Test message 4" , 0 ),
98- ],
105+ call ("test_topic" , "Test message 1\n Test message 2" , qos = 0 ),
106+ call ("test_topic" , "Test message 3\n Test message 4" , qos = 0 ),
107+ ]
99108 )
0 commit comments