1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import atexit
1415import json
1516import logging
1617import os
@@ -213,6 +214,16 @@ def __create(self) -> "SparkSession":
213214 logger .info (
214215 "Creating Spark session. It may take few minutes."
215216 )
217+ atexit .register (
218+ atexit .register (
219+ lambda : ServerlessSessionHelper .terminate_s8s_session (
220+ self ._project_id ,
221+ self ._region ,
222+ session_id ,
223+ self ._client_options ,
224+ )
225+ )
226+ )
216227 operation = SessionControllerClient (
217228 client_options = self ._client_options
218229 ).create_session (session_request )
@@ -473,42 +484,12 @@ def _remove_stoped_session_from_file(self):
473484 def stop (self ) -> None :
474485 with GoogleSparkSession ._lock :
475486 if GoogleSparkSession ._active_s8s_session_id is not None :
476- from google .cloud .dataproc_v1 import SessionControllerClient
477-
478- logger .debug (
479- f"Terminating serverless session: { GoogleSparkSession ._active_s8s_session_id } "
487+ ServerlessSessionHelper .terminate_s8s_session (
488+ GoogleSparkSession ._project_id ,
489+ GoogleSparkSession ._region ,
490+ GoogleSparkSession ._active_s8s_session_id ,
491+ self ._client_options ,
480492 )
481- terminate_session_request = TerminateSessionRequest ()
482- session_name = f"projects/{ GoogleSparkSession ._project_id } /locations/{ GoogleSparkSession ._region } /sessions/{ GoogleSparkSession ._active_s8s_session_id } "
483- terminate_session_request .name = session_name
484- state = None
485- try :
486- SessionControllerClient (
487- client_options = self ._client_options
488- ).terminate_session (terminate_session_request )
489- get_session_request = GetSessionRequest ()
490- get_session_request .name = session_name
491- state = Session .State .ACTIVE
492- while (
493- state != Session .State .TERMINATING
494- and state != Session .State .TERMINATED
495- and state != Session .State .FAILED
496- ):
497- session = SessionControllerClient (
498- client_options = self ._client_options
499- ).get_session (get_session_request )
500- state = session .state
501- sleep (1 )
502- except NotFound :
503- logger .debug (
504- f"Session { GoogleSparkSession ._active_s8s_session_id } already deleted"
505- )
506- except FailedPrecondition :
507- logger .debug (
508- f"Session { GoogleSparkSession ._active_s8s_session_id } already terminated manually or terminated automatically through session ttl limits"
509- )
510- if state is not None and state == Session .State .FAILED :
511- raise RuntimeError ("Serverless session termination failed" )
512493
513494 self ._remove_stoped_session_from_file ()
514495 GoogleSparkSession ._active_s8s_session_uuid = None
@@ -524,3 +505,43 @@ def stop(self) -> None:
524505 GoogleSparkSession ._active_session , "session" , None
525506 ):
526507 GoogleSparkSession ._active_session .session = None
508+
509+
510+ class ServerlessSessionHelper :
511+
512+ @staticmethod
513+ def terminate_s8s_session (
514+ project_id , region , active_s8s_session_id , client_options = None
515+ ):
516+ from google .cloud .dataproc_v1 import SessionControllerClient
517+
518+ logger .debug (f"Terminating serverless session: { active_s8s_session_id } " )
519+ terminate_session_request = TerminateSessionRequest ()
520+ session_name = f"projects/{ project_id } /locations/{ region } /sessions/{ active_s8s_session_id } "
521+ terminate_session_request .name = session_name
522+ state = None
523+ try :
524+ SessionControllerClient (
525+ client_options = client_options
526+ ).terminate_session (terminate_session_request )
527+ get_session_request = GetSessionRequest ()
528+ get_session_request .name = session_name
529+ state = Session .State .ACTIVE
530+ while (
531+ state != Session .State .TERMINATING
532+ and state != Session .State .TERMINATED
533+ and state != Session .State .FAILED
534+ ):
535+ session = SessionControllerClient (
536+ client_options = client_options
537+ ).get_session (get_session_request )
538+ state = session .state
539+ sleep (1 )
540+ except NotFound :
541+ logger .debug (f"Session { active_s8s_session_id } already deleted" )
542+ except FailedPrecondition :
543+ logger .debug (
544+ f"Session { active_s8s_session_id } already terminated manually or terminated automatically through session ttl limits"
545+ )
546+ if state is not None and state == Session .State .FAILED :
547+ raise RuntimeError ("Serverless session termination failed" )
0 commit comments