Skip to content

Commit 8883caa

Browse files
committed
fix: terminate s8s session on kernel termination
1 parent 61e7b47 commit 8883caa

File tree

1 file changed

+56
-35
lines changed

1 file changed

+56
-35
lines changed

google/cloud/spark_connect/session.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import atexit
1415
import json
1516
import logging
1617
import os
@@ -213,6 +214,16 @@ def __create(self) -> "SparkSession":
213214
logger.info(
214215
"Creating Spark session. It may take few minutes."
215216
)
217+
atexit.register(
218+
atexit.register(
219+
lambda: ServerlessSessionHelper.terminate_s8s_session(
220+
self._project_id,
221+
self._region,
222+
session_id,
223+
self._client_options,
224+
)
225+
)
226+
)
216227
operation = SessionControllerClient(
217228
client_options=self._client_options
218229
).create_session(session_request)
@@ -473,42 +484,12 @@ def _remove_stoped_session_from_file(self):
473484
def stop(self) -> None:
474485
with GoogleSparkSession._lock:
475486
if GoogleSparkSession._active_s8s_session_id is not None:
476-
from google.cloud.dataproc_v1 import SessionControllerClient
477-
478-
logger.debug(
479-
f"Terminating serverless session: {GoogleSparkSession._active_s8s_session_id}"
487+
ServerlessSessionHelper.terminate_s8s_session(
488+
GoogleSparkSession._project_id,
489+
GoogleSparkSession._region,
490+
GoogleSparkSession._active_s8s_session_id,
491+
self._client_options,
480492
)
481-
terminate_session_request = TerminateSessionRequest()
482-
session_name = f"projects/{GoogleSparkSession._project_id}/locations/{GoogleSparkSession._region}/sessions/{GoogleSparkSession._active_s8s_session_id}"
483-
terminate_session_request.name = session_name
484-
state = None
485-
try:
486-
SessionControllerClient(
487-
client_options=self._client_options
488-
).terminate_session(terminate_session_request)
489-
get_session_request = GetSessionRequest()
490-
get_session_request.name = session_name
491-
state = Session.State.ACTIVE
492-
while (
493-
state != Session.State.TERMINATING
494-
and state != Session.State.TERMINATED
495-
and state != Session.State.FAILED
496-
):
497-
session = SessionControllerClient(
498-
client_options=self._client_options
499-
).get_session(get_session_request)
500-
state = session.state
501-
sleep(1)
502-
except NotFound:
503-
logger.debug(
504-
f"Session {GoogleSparkSession._active_s8s_session_id} already deleted"
505-
)
506-
except FailedPrecondition:
507-
logger.debug(
508-
f"Session {GoogleSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
509-
)
510-
if state is not None and state == Session.State.FAILED:
511-
raise RuntimeError("Serverless session termination failed")
512493

513494
self._remove_stoped_session_from_file()
514495
GoogleSparkSession._active_s8s_session_uuid = None
@@ -524,3 +505,43 @@ def stop(self) -> None:
524505
GoogleSparkSession._active_session, "session", None
525506
):
526507
GoogleSparkSession._active_session.session = None
508+
509+
510+
class ServerlessSessionHelper:
511+
512+
@staticmethod
513+
def terminate_s8s_session(
514+
project_id, region, active_s8s_session_id, client_options=None
515+
):
516+
from google.cloud.dataproc_v1 import SessionControllerClient
517+
518+
logger.debug(f"Terminating serverless session: {active_s8s_session_id}")
519+
terminate_session_request = TerminateSessionRequest()
520+
session_name = f"projects/{project_id}/locations/{region}/sessions/{active_s8s_session_id}"
521+
terminate_session_request.name = session_name
522+
state = None
523+
try:
524+
SessionControllerClient(
525+
client_options=client_options
526+
).terminate_session(terminate_session_request)
527+
get_session_request = GetSessionRequest()
528+
get_session_request.name = session_name
529+
state = Session.State.ACTIVE
530+
while (
531+
state != Session.State.TERMINATING
532+
and state != Session.State.TERMINATED
533+
and state != Session.State.FAILED
534+
):
535+
session = SessionControllerClient(
536+
client_options=client_options
537+
).get_session(get_session_request)
538+
state = session.state
539+
sleep(1)
540+
except NotFound:
541+
logger.debug(f"Session {active_s8s_session_id} already deleted")
542+
except FailedPrecondition:
543+
logger.debug(
544+
f"Session {active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
545+
)
546+
if state is not None and state == Session.State.FAILED:
547+
raise RuntimeError("Serverless session termination failed")

0 commit comments

Comments
 (0)