11import asyncio
22import logging
3- import os
4- from unittest import TestCase , skipIf
3+ from unittest import TestCase
54
65from mqlog import MqttHandler
6+ from tests .utils import AsyncMock , Mock , call
77
88
99class FakeClient :
1010 """A fake MQTT client for testing purposes."""
1111
12- def __init__ (self ):
13- self .calls = []
14-
15- # Store the call like a mock so we can check it later
16- async def publish (self , topic , payload , qos = 0 ):
17- self .calls .append ((topic , payload , qos ))
12+ publish : None
1813
1914
2015class TestMqttHandler (TestCase ):
2116 def setUp (self ):
2217 self .client = FakeClient ()
18+ self .client .publish = AsyncMock ()
2319 self .handler = MqttHandler (self .client , "test_topic" )
2420 self .handler .setFormatter (logging .Formatter ("%(message)s" ))
2521 self .logger = logging .getLogger ("test" )
@@ -50,32 +46,46 @@ def test_flush_full_buffer(self):
5046 self .logger .info (f"Test message { i } " )
5147 self .assertTrue (self .handler .will_flush .is_set ())
5248
53- @skipIf (os .getenv ("CI" ), "Hangs in CI" )
49+ def test_flush_fail (self ):
50+ """Flushing should log an error if publish fails"""
51+ self .handler ._logger .error = Mock ()
52+ self .client .publish = AsyncMock (side_effect = Exception ("Publish failed" ))
53+
54+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
55+ asyncio .create_task (handler .run ())
56+ await asyncio .sleep (0.1 ) # Allow handler to start
57+ logger .error ("Test message" ) # should trigger flush
58+ await asyncio .sleep (0.1 )
59+
60+ asyncio .run (do_test (self .handler , self .logger ))
61+
62+ self .handler ._logger .error .assert_called ()
63+
5464 def test_publish (self ):
5565 """Flushing should publish messages to MQTT topic"""
5666 self .handler .flush_level = logging .ERROR
5767
58- async def do_test (logger ):
68+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
69+ asyncio .create_task (handler .run ())
70+ await asyncio .sleep (0.1 )
5971 logger .info ("Test message 1" )
6072 await asyncio .sleep (0.1 )
6173 logger .error ("Test message 2" )
6274 await asyncio .sleep (0.1 )
6375
64- async def main (handler , logger ):
65- await asyncio .gather (handler .run (), do_test (logger ))
66-
67- asyncio .run (main (self .handler , self .logger ))
76+ asyncio .run (do_test (self .handler , self .logger ))
6877
69- self .assertEqual (
70- self . client . calls , [( "test_topic" , "Test message 1\n Test message 2" , 0 )]
78+ self .client . publish . assert_called_with (
79+ "test_topic" , "Test message 1\n Test message 2" , qos = 0
7180 )
7281
73- @skipIf (os .getenv ("CI" ), "Hangs in CI" )
7482 def test_flush_multiple (self ):
7583 """Flushing multiple times should publish separate messages"""
7684 self .handler .capacity = 2
7785
78- async def do_test (logger ):
86+ async def do_test (handler : logging .Handler , logger : logging .Logger ):
87+ asyncio .create_task (handler .run ())
88+ await asyncio .sleep (0.1 )
7989 logger .info ("Test message 1" )
8090 await asyncio .sleep (0.1 )
8191 logger .info ("Test message 2" )
@@ -85,15 +95,11 @@ async def do_test(logger):
8595 logger .info ("Test message 4" )
8696 await asyncio .sleep (0.1 )
8797
88- async def main (handler , logger ):
89- await asyncio .gather (handler .run (), do_test (logger ))
90-
91- asyncio .run (main (self .handler , self .logger ))
98+ asyncio .run (do_test (self .handler , self .logger ))
9299
93- self .assertEqual (
94- self .client .calls ,
100+ self .client .publish .assert_has_calls (
95101 [
96- ("test_topic" , "Test message 1\n Test message 2" , 0 ),
97- ("test_topic" , "Test message 3\n Test message 4" , 0 ),
98- ],
102+ call ("test_topic" , "Test message 1\n Test message 2" , qos = 0 ),
103+ call ("test_topic" , "Test message 3\n Test message 4" , qos = 0 ),
104+ ]
99105 )
0 commit comments