1313import signal
1414import subprocess
1515import sys
16+ from pathlib import Path
17+ from urllib .parse import quote_plus , urlencode
1618
17- from charms .data_platform_libs .v0 .data_interfaces import DatabaseRequires
19+ from charms .data_platform_libs .v0 .data_interfaces import DatabaseRequires , DatabaseCreatedEvent
1820from ops .charm import ActionEvent , CharmBase
1921from ops .main import main
2022from ops .model import ActiveStatus , Relation , WaitingStatus
2123from pymongo import MongoClient
24+ from pymongo .uri_parser import parse_uri
2225from tenacity import RetryError , Retrying , stop_after_delay , wait_fixed
2326
2427logger = logging .getLogger (__name__ )
3033PROC_PID_KEY = "proc-pid"
3134LAST_WRITTEN_FILE = "last_written_value"
3235
36+ CA_PATH = Path ("/tmp/ca.crt" )
37+
3338
3439class ContinuousWritesApplication (CharmBase ):
3540 """Application charm that continuously writes to MongoDB."""
@@ -51,13 +56,35 @@ def __init__(self, *args):
5156 )
5257
5358 # Database related events
54- self .database = DatabaseRequires (self , "database" , DATABASE_NAME )
59+ self .database = DatabaseRequires (self , "mongodb" , self .database_name )
60+ # Database related events
61+ self .mongos_database = DatabaseRequires (self , "mongos" , self .database_name , external_node_connectivity = True )
62+
5563 self .framework .observe (self .database .on .database_created , self ._on_database_created )
64+ self .framework .observe (self .mongos_database .on .database_created , self ._on_database_created )
65+
66+ if (data := list (self .database .fetch_relation_data ().values ())):
67+ if (tls_ca := data [0 ].get ("tls-ca" )):
68+ CA_PATH .write_text (tls_ca )
69+ return
70+
71+ if (data := list (self .mongos_database .fetch_relation_data ().values ())):
72+ if (tls_ca := data [0 ].get ("tls-ca" )):
73+ CA_PATH .write_text (tls_ca )
74+ return
75+
76+ if tls_ca := self .model .config .get ("tls-ca" , None ):
77+ CA_PATH .write_text (tls_ca )
78+ return
5679
5780 # ==============
5881 # Properties
5982 # ==============
6083
84+ @property
85+ def database_name (self ) -> str :
86+ return self .model .config .get ("database-name" , DATABASE_NAME )
87+
6188 @property
6289 def _peers (self ) -> Relation | None :
6390 """Retrieve the peer relation (`ops.model.Relation`)."""
@@ -76,33 +103,60 @@ def _database_config(self) -> dict[str, str]:
76103 """Returns the database config to use to connect to the MongoDB cluster."""
77104 # In some tests we want to write directly to mongos, but the config-server does not
78105 # support integrations to client applications, so the data to connect is set via config.
79- if not (data := list (self .database .fetch_relation_data ().values ())):
80- return {"uris" : self .model .config .get ("mongos-uri" , None )}
106+ if not self .database .relations and not self .mongos_database .relations :
107+ uri = self .model .config .get ("mongos-uri" , "" )
108+ if self .model .config .get ("tls-ca" ):
109+ uri = self ._build_tls_uri (uri )
110+
111+ return {"uris" : uri }
112+
113+ if self .database .relations :
114+ data = list (self .database .fetch_relation_data ().values ())[0 ]
115+ elif self .mongos_database .relations :
116+ data = list (self .mongos_database .fetch_relation_data ().values ())[0 ]
117+ else :
118+ return {}
81119
82- data = data [0 ]
83- username , password , endpoints , replset , uris = (
120+ username , password , endpoints , replset , uris , tls = (
84121 data .get ("username" ),
85122 data .get ("password" ),
86123 data .get ("endpoints" ),
87124 data .get ("replset" ),
88125 data .get ("uris" ),
126+ data .get ("tls" )
89127 )
90128
91- if None in [username , password , endpoints , replset , uris ]:
129+ if None in [username , password , endpoints , uris ]:
92130 return {}
93131
132+ if tls :
133+ uris = self ._build_tls_uri (uris )
134+
94135 return {
95136 "user" : username ,
96137 "password" : password ,
97138 "endpoints" : endpoints ,
98- "replset" : replset ,
139+ "replset" : replset or "" ,
99140 "uris" : uris ,
100141 }
101142
102143 # ==============
103144 # Helpers
104145 # ==============
105146
147+ def _build_tls_uri (self , uris : str ) -> str :
148+ parsed_uri = parse_uri (uris )
149+ params = parsed_uri ["options" ]
150+ params ["tls" ] = "true"
151+ params ["tlsCaFile" ] = f"{ CA_PATH } "
152+ hosts = "," .join (f"{ host } :{ port } " for host , port in parsed_uri ["nodelist" ])
153+ return (
154+ f"mongodb://{ quote_plus (parsed_uri ['username' ])} :"
155+ f"{ quote_plus (parsed_uri ['password' ])} @"
156+ f"{ hosts } /{ quote_plus (parsed_uri ['database' ])} ?"
157+ f"{ urlencode (params )} "
158+ )
159+
106160 def _start_continuous_writes (
107161 self , starting_number : int , db_name : str , collection_name : str
108162 ) -> None :
@@ -148,8 +202,9 @@ def _stop_continuous_writes(self, db_name: str, collection_name: str) -> int | N
148202 logger .info (
149203 f"Process { self .proc_id_key (db_name , collection_name )} was killed already (or never existed)"
150204 )
151-
152- del self .app_peer_data [self .proc_id_key (db_name , collection_name )]
205+ return - 1
206+ finally :
207+ del self .app_peer_data [self .proc_id_key (db_name , collection_name )]
153208
154209 # read the last written_value
155210 try :
@@ -170,7 +225,7 @@ def proc_id_key(self, db_name: str, collection_name: str) -> str:
170225 return f"{ PROC_PID_KEY } -{ db_name } -{ collection_name } "
171226
172227 def last_written_filename (self , db_name : str , collection_name : str ) -> str :
173- """Returns a process id key for the continuous writes process to a given db and coll."""
228+ """Returns the filename for the written data for a given db and coll."""
174229 return f"{ LAST_WRITTEN_FILE } -{ db_name } -{ collection_name } "
175230
176231 # ==============
@@ -210,7 +265,7 @@ def _on_start_continuous_writes_action(self, event) -> None:
210265 if not self ._database_config :
211266 return
212267
213- db_name = event .params .get ("db-name" ) or DATABASE_NAME
268+ db_name = event .params .get ("db-name" ) or self . database_name
214269 collection_name = event .params .get ("collection-name" ) or COLLECTION_NAME
215270 self ._start_continuous_writes (1 , db_name , collection_name )
216271
@@ -225,10 +280,12 @@ def _on_stop_continuous_writes_action(self, event: ActionEvent) -> None:
225280 event .set_results ({"writes" : writes or - 1 })
226281 return None
227282
228- def _on_database_created (self , _ ) -> None :
283+ def _on_database_created (self , event : DatabaseCreatedEvent ) -> None :
229284 """Handle the database created event."""
230- self .unit .status = ActiveStatus ()
285+ if event .tls == "True" :
286+ CA_PATH .write_text (event .tls_ca )
231287
288+ self .unit .status = ActiveStatus ()
232289
233290if __name__ == "__main__" :
234291 main (ContinuousWritesApplication )
0 commit comments