Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
*~
.*.sw[nop]
.idea
.DS_Store
__pycache__
build/
dist/

6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ If you are running the client outside of Google Cloud, you must set following en

.. code-block:: python

from google.cloud.dataproc_spark_connect import DataprocSparkSession
from google.cloud.spark_connect import GoogleSparkSession

3. There are two ways to create a spark session,

1. Start a Spark session using properties defined in `DATAPROC_SPARK_CONNECT_SESSION_DEFAULT_CONFIG`:

.. code-block:: python

spark = DataprocSparkSession.builder.getOrCreate()
spark = GoogleSparkSession.builder.getOrCreate()

2. Start a Spark session with the following code instead of using a config file:

Expand All @@ -59,7 +59,7 @@ If you are running the client outside of Google Cloud, you must set following en
dataproc_config.spark_connect_session = SparkConnectConfig()
dataproc_config.environment_config.execution_config.subnetwork_uri = "<subnet>"
dataproc_config.runtime_config.version = '3.0'
spark = DataprocSparkSession.builder.dataprocConfig(dataproc_config).getOrCreate()
spark = GoogleSparkSession.builder.dataprocConfig(dataproc_config).getOrCreate()

## Billing
As this client runs the spark workload on Dataproc, your project will be billed as per [Dataproc Serverless Pricing](https://cloud.google.com/dataproc-serverless/pricing).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .session import DataprocSparkSession
from .session import GoogleSparkSession
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from google.api_core.exceptions import FailedPrecondition, InvalidArgument, NotFound
from google.cloud.dataproc_v1.types import sessions

from google.cloud.dataproc_spark_connect.client import DataprocChannelBuilder
from google.cloud.spark_connect.client import DataprocChannelBuilder
from google.cloud.dataproc_v1 import (
CreateSessionRequest,
GetSessionRequest,
Expand All @@ -47,7 +47,7 @@
logger = logging.getLogger(__name__)


class DataprocSparkSession(SparkSession):
class GoogleSparkSession(SparkSession):
"""The entry point to programming Spark with the Dataset and DataFrame API.

A DataprocRemoteSparkSession can be used to create :class:`DataFrame`, register :class:`DataFrame` as
Expand All @@ -59,7 +59,7 @@ class DataprocSparkSession(SparkSession):
Create a Spark session with Dataproc Spark Connect.

>>> spark = (
... DataprocSparkSession.builder
... GoogleSparkSession.builder
... .appName("Word Count")
... .dataprocConfig(Session())
... .getOrCreate()
Expand Down Expand Up @@ -130,25 +130,23 @@ def dataprocConfig(self, dataproc_config: Session):
def remote(self, url: Optional[str] = None) -> "SparkSession.Builder":
if url:
raise NotImplemented(
"DataprocSparkSession does not support connecting to an existing remote server"
"GoogleSparkSession does not support connecting to an existing remote server"
)
else:
return self

def create(self) -> "SparkSession":
raise NotImplemented(
"DataprocSparkSession allows session creation only through getOrCreate"
"GoogleSparkSession allows session creation only through getOrCreate"
)

def __create_spark_connect_session_from_s8s(
self, session_response
) -> "SparkSession":
DataprocSparkSession._active_s8s_session_uuid = (
session_response.uuid
)
DataprocSparkSession._project_id = self._project_id
DataprocSparkSession._region = self._region
DataprocSparkSession._client_options = self._client_options
GoogleSparkSession._active_s8s_session_uuid = session_response.uuid
GoogleSparkSession._project_id = self._project_id
GoogleSparkSession._region = self._region
GoogleSparkSession._client_options = self._client_options
spark_connect_url = session_response.runtime_info.endpoints.get(
"Spark Connect Server"
)
Expand All @@ -158,9 +156,9 @@ def __create_spark_connect_session_from_s8s(
self._channel_builder = DataprocChannelBuilder(url)

assert self._channel_builder is not None
session = DataprocSparkSession(connection=self._channel_builder)
session = GoogleSparkSession(connection=self._channel_builder)

DataprocSparkSession._set_default_and_active_session(session)
GoogleSparkSession._set_default_and_active_session(session)
self.__apply_options(session)
return session

Expand All @@ -169,7 +167,7 @@ def __create(self) -> "SparkSession":

if self._options.get("spark.remote", False):
raise NotImplemented(
"DataprocSparkSession does not support connecting to an existing remote server"
"GoogleSparkSession does not support connecting to an existing remote server"
)

from google.cloud.dataproc_v1 import SessionControllerClient
Expand Down Expand Up @@ -202,7 +200,7 @@ def __create(self) -> "SparkSession":
)

logger.debug("Creating serverless session")
DataprocSparkSession._active_s8s_session_id = session_id
GoogleSparkSession._active_s8s_session_id = session_id
s8s_creation_start_time = time.time()
try:
session_polling = retry.Retry(
Expand Down Expand Up @@ -246,12 +244,12 @@ def __create(self) -> "SparkSession":
f"Exception while writing active session to file {file_path} , {e}"
)
except InvalidArgument as e:
DataprocSparkSession._active_s8s_session_id = None
GoogleSparkSession._active_s8s_session_id = None
raise RuntimeError(
f"Error while creating serverless session: {e}"
) from None
except Exception as e:
DataprocSparkSession._active_s8s_session_id = None
GoogleSparkSession._active_s8s_session_id = None
raise RuntimeError(
f"Error while creating serverless session https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id} : {e}"
) from None
Expand Down Expand Up @@ -286,12 +284,12 @@ def _is_s8s_session_active(
return None

def _get_exiting_active_session(self) -> Optional["SparkSession"]:
s8s_session_id = DataprocSparkSession._active_s8s_session_id
s8s_session_id = GoogleSparkSession._active_s8s_session_id
session_response = self._is_s8s_session_active(s8s_session_id)

session = DataprocSparkSession.getActiveSession()
session = GoogleSparkSession.getActiveSession()
if session is None:
session = DataprocSparkSession._default_session
session = GoogleSparkSession._default_session

if session_response is not None:
print(
Expand All @@ -312,7 +310,7 @@ def _get_exiting_active_session(self) -> Optional["SparkSession"]:
return None

def getOrCreate(self) -> "SparkSession":
with DataprocSparkSession._lock:
with GoogleSparkSession._lock:
session = self._get_exiting_active_session()
if session is None:
session = self.__create()
Expand Down Expand Up @@ -413,7 +411,7 @@ def _get_and_validate_version(self, dataproc_config, session_template):
"dataproc-spark-connect"
)
client_version = importlib.metadata.version("pyspark")
version_message = f"Dataproc Spark Connect: {dataproc_connect_version} (PySpark: {client_version}) Dataproc Session Runtime: {version} (Spark: {server_version})"
version_message = f"Spark Connect: {dataproc_connect_version} (PySpark: {client_version}) Session Runtime: {version} (Spark: {server_version})"
logger.info(version_message)
if trimmed_version(client_version) != trimmed_version(
server_version
Expand Down Expand Up @@ -454,7 +452,7 @@ def _repr_html_(self) -> str:
<div>
<p><b>Spark Connect</b></p>

<p><a href="{s8s_session}">Dataproc Session</a></p>
<p><a href="{s8s_session}">Serverless Session</a></p>
<p><a href="{ui}">Spark UI</a></p>
</div>
"""
Expand All @@ -473,15 +471,15 @@ def _remove_stoped_session_from_file(self):
)

def stop(self) -> None:
with DataprocSparkSession._lock:
if DataprocSparkSession._active_s8s_session_id is not None:
with GoogleSparkSession._lock:
if GoogleSparkSession._active_s8s_session_id is not None:
from google.cloud.dataproc_v1 import SessionControllerClient

logger.debug(
f"Terminating serverless session: {DataprocSparkSession._active_s8s_session_id}"
f"Terminating serverless session: {GoogleSparkSession._active_s8s_session_id}"
)
terminate_session_request = TerminateSessionRequest()
session_name = f"projects/{DataprocSparkSession._project_id}/locations/{DataprocSparkSession._region}/sessions/{DataprocSparkSession._active_s8s_session_id}"
session_name = f"projects/{GoogleSparkSession._project_id}/locations/{GoogleSparkSession._region}/sessions/{GoogleSparkSession._active_s8s_session_id}"
terminate_session_request.name = session_name
state = None
try:
Expand All @@ -503,26 +501,26 @@ def stop(self) -> None:
sleep(1)
except NotFound:
logger.debug(
f"Session {DataprocSparkSession._active_s8s_session_id} already deleted"
f"Session {GoogleSparkSession._active_s8s_session_id} already deleted"
)
except FailedPrecondition:
logger.debug(
f"Session {DataprocSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
f"Session {GoogleSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
)
if state is not None and state == Session.State.FAILED:
raise RuntimeError("Serverless session termination failed")

self._remove_stoped_session_from_file()
DataprocSparkSession._active_s8s_session_uuid = None
DataprocSparkSession._active_s8s_session_id = None
DataprocSparkSession._project_id = None
DataprocSparkSession._region = None
DataprocSparkSession._client_options = None
GoogleSparkSession._active_s8s_session_uuid = None
GoogleSparkSession._active_s8s_session_id = None
GoogleSparkSession._project_id = None
GoogleSparkSession._region = None
GoogleSparkSession._client_options = None

self.client.close()
if self is DataprocSparkSession._default_session:
DataprocSparkSession._default_session = None
if self is GoogleSparkSession._default_session:
GoogleSparkSession._default_session = None
if self is getattr(
DataprocSparkSession._active_session, "session", None
GoogleSparkSession._active_session, "session", None
):
DataprocSparkSession._active_session.session = None
GoogleSparkSession._active_session.session = None
Loading
Loading