Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
*~
.*.sw[nop]
.idea
.DS_Store
__pycache__
build/
dist/

6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ If you are running the client outside of Google Cloud, you must set following en

.. code-block:: python

from google.cloud.dataproc_spark_connect import DataprocSparkSession
from google.cloud.spark_connect import GoogleSparkSession

3. There are two ways to create a Spark session:

1. Start a Spark session using properties defined in `DATAPROC_SPARK_CONNECT_SESSION_DEFAULT_CONFIG`:

.. code-block:: python

spark = DataprocSparkSession.builder.getOrCreate()
spark = GoogleSparkSession.builder.getOrCreate()

2. Start a Spark session with the following code instead of using a config file:

Expand All @@ -59,7 +59,7 @@ If you are running the client outside of Google Cloud, you must set following en
dataproc_config.spark_connect_session = SparkConnectConfig()
dataproc_config.environment_config.execution_config.subnetwork_uri = "<subnet>"
dataproc_config.runtime_config.version = '3.0'
spark = DataprocSparkSession.builder.dataprocConfig(dataproc_config).getOrCreate()
spark = GoogleSparkSession.builder.dataprocConfig(dataproc_config).getOrCreate()

## Billing
Because this client runs Spark workloads on Dataproc, your project is billed according to [Dataproc Serverless Pricing](https://cloud.google.com/dataproc-serverless/pricing).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .session import DataprocSparkSession
from .session import GoogleSparkSession
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from google.api_core.exceptions import FailedPrecondition, InvalidArgument, NotFound
from google.cloud.dataproc_v1.types import sessions

from google.cloud.dataproc_spark_connect.client import DataprocChannelBuilder
from google.cloud.spark_connect.client import DataprocChannelBuilder
from google.cloud.dataproc_v1 import (
CreateSessionRequest,
GetSessionRequest,
Expand All @@ -47,7 +47,7 @@
logger = logging.getLogger(__name__)


class DataprocSparkSession(SparkSession):
class GoogleSparkSession(SparkSession):
"""The entry point to programming Spark with the Dataset and DataFrame API.

A GoogleSparkSession can be used to create :class:`DataFrame`, register :class:`DataFrame` as
Expand All @@ -59,7 +59,7 @@ class DataprocSparkSession(SparkSession):
Create a Spark session with Dataproc Spark Connect.

>>> spark = (
... DataprocSparkSession.builder
... GoogleSparkSession.builder
... .appName("Word Count")
... .dataprocConfig(Session())
... .getOrCreate()
Expand Down Expand Up @@ -130,25 +130,25 @@ def dataprocConfig(self, dataproc_config: Session):
def remote(self, url: Optional[str] = None) -> "SparkSession.Builder":
    """Reject attempts to point the builder at an existing remote server.

    Args:
        url: Remote server URL. Any truthy value is rejected; ``None`` or an
            empty string is accepted and ignored.

    Returns:
        This builder, unchanged, so that calls can still be chained.

    Raises:
        NotImplementedError: If ``url`` is truthy.
    """
    if url:
        # ``NotImplemented`` is a non-callable sentinel meant for binary
        # operator fallbacks; calling it raises ``TypeError`` and the message
        # never reaches the user. ``NotImplementedError`` is the exception.
        raise NotImplementedError(
            "GoogleSparkSession does not support connecting to an existing remote server"
        )
    return self

def create(self) -> "SparkSession":
    """Always refuse direct session creation.

    Raises:
        NotImplementedError: Unconditionally; the message directs callers to
            ``getOrCreate``.
    """
    # Raise the exception class, not the ``NotImplemented`` sentinel: calling
    # the sentinel fails with ``TypeError`` and hides the intended message.
    raise NotImplementedError(
        "GoogleSparkSession allows session creation only through getOrCreate"
    )

def __create_spark_connect_session_from_s8s(
self, session_response
) -> "SparkSession":
DataprocSparkSession._active_s8s_session_uuid = (
GoogleSparkSession._active_s8s_session_uuid = (
session_response.uuid
)
DataprocSparkSession._project_id = self._project_id
DataprocSparkSession._region = self._region
DataprocSparkSession._client_options = self._client_options
GoogleSparkSession._project_id = self._project_id
GoogleSparkSession._region = self._region
GoogleSparkSession._client_options = self._client_options
spark_connect_url = session_response.runtime_info.endpoints.get(
"Spark Connect Server"
)
Expand All @@ -158,9 +158,9 @@ def __create_spark_connect_session_from_s8s(
self._channel_builder = DataprocChannelBuilder(url)

assert self._channel_builder is not None
session = DataprocSparkSession(connection=self._channel_builder)
session = GoogleSparkSession(connection=self._channel_builder)

DataprocSparkSession._set_default_and_active_session(session)
GoogleSparkSession._set_default_and_active_session(session)
self.__apply_options(session)
return session

Expand All @@ -169,7 +169,7 @@ def __create(self) -> "SparkSession":

if self._options.get("spark.remote", False):
raise NotImplemented(
"DataprocSparkSession does not support connecting to an existing remote server"
"GoogleSparkSession does not support connecting to an existing remote server"
)

from google.cloud.dataproc_v1 import SessionControllerClient
Expand Down Expand Up @@ -202,7 +202,7 @@ def __create(self) -> "SparkSession":
)

logger.debug("Creating serverless session")
DataprocSparkSession._active_s8s_session_id = session_id
GoogleSparkSession._active_s8s_session_id = session_id
s8s_creation_start_time = time.time()
try:
session_polling = retry.Retry(
Expand Down Expand Up @@ -246,12 +246,12 @@ def __create(self) -> "SparkSession":
f"Exception while writing active session to file {file_path} , {e}"
)
except InvalidArgument as e:
DataprocSparkSession._active_s8s_session_id = None
GoogleSparkSession._active_s8s_session_id = None
raise RuntimeError(
f"Error while creating serverless session: {e}"
) from None
except Exception as e:
DataprocSparkSession._active_s8s_session_id = None
GoogleSparkSession._active_s8s_session_id = None
raise RuntimeError(
f"Error while creating serverless session https://console.cloud.google.com/dataproc/interactive/{self._region}/{session_id} : {e}"
) from None
Expand Down Expand Up @@ -286,12 +286,12 @@ def _is_s8s_session_active(
return None

def _get_exiting_active_session(self) -> Optional["SparkSession"]:
s8s_session_id = DataprocSparkSession._active_s8s_session_id
s8s_session_id = GoogleSparkSession._active_s8s_session_id
session_response = self._is_s8s_session_active(s8s_session_id)

session = DataprocSparkSession.getActiveSession()
session = GoogleSparkSession.getActiveSession()
if session is None:
session = DataprocSparkSession._default_session
session = GoogleSparkSession._default_session

if session_response is not None:
print(
Expand All @@ -312,7 +312,7 @@ def _get_exiting_active_session(self) -> Optional["SparkSession"]:
return None

def getOrCreate(self) -> "SparkSession":
with DataprocSparkSession._lock:
with GoogleSparkSession._lock:
session = self._get_exiting_active_session()
if session is None:
session = self.__create()
Expand Down Expand Up @@ -413,7 +413,7 @@ def _get_and_validate_version(self, dataproc_config, session_template):
"dataproc-spark-connect"
)
client_version = importlib.metadata.version("pyspark")
version_message = f"Dataproc Spark Connect: {dataproc_connect_version} (PySpark: {client_version}) Dataproc Session Runtime: {version} (Spark: {server_version})"
version_message = f"Spark Connect: {dataproc_connect_version} (PySpark: {client_version}) Session Runtime: {version} (Spark: {server_version})"
logger.info(version_message)
if trimmed_version(client_version) != trimmed_version(
server_version
Expand Down Expand Up @@ -454,7 +454,7 @@ def _repr_html_(self) -> str:
<div>
<p><b>Spark Connect</b></p>

<p><a href="{s8s_session}">Dataproc Session</a></p>
<p><a href="{s8s_session}">Serverless Session</a></p>
<p><a href="{ui}">Spark UI</a></p>
</div>
"""
Expand All @@ -473,15 +473,15 @@ def _remove_stoped_session_from_file(self):
)

def stop(self) -> None:
with DataprocSparkSession._lock:
if DataprocSparkSession._active_s8s_session_id is not None:
with GoogleSparkSession._lock:
if GoogleSparkSession._active_s8s_session_id is not None:
from google.cloud.dataproc_v1 import SessionControllerClient

logger.debug(
f"Terminating serverless session: {DataprocSparkSession._active_s8s_session_id}"
f"Terminating serverless session: {GoogleSparkSession._active_s8s_session_id}"
)
terminate_session_request = TerminateSessionRequest()
session_name = f"projects/{DataprocSparkSession._project_id}/locations/{DataprocSparkSession._region}/sessions/{DataprocSparkSession._active_s8s_session_id}"
session_name = f"projects/{GoogleSparkSession._project_id}/locations/{GoogleSparkSession._region}/sessions/{GoogleSparkSession._active_s8s_session_id}"
terminate_session_request.name = session_name
state = None
try:
Expand All @@ -503,26 +503,26 @@ def stop(self) -> None:
sleep(1)
except NotFound:
logger.debug(
f"Session {DataprocSparkSession._active_s8s_session_id} already deleted"
f"Session {GoogleSparkSession._active_s8s_session_id} already deleted"
)
except FailedPrecondition:
logger.debug(
f"Session {DataprocSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
f"Session {GoogleSparkSession._active_s8s_session_id} already terminated manually or terminated automatically through session ttl limits"
)
if state is not None and state == Session.State.FAILED:
raise RuntimeError("Serverless session termination failed")

self._remove_stoped_session_from_file()
DataprocSparkSession._active_s8s_session_uuid = None
DataprocSparkSession._active_s8s_session_id = None
DataprocSparkSession._project_id = None
DataprocSparkSession._region = None
DataprocSparkSession._client_options = None
GoogleSparkSession._active_s8s_session_uuid = None
GoogleSparkSession._active_s8s_session_id = None
GoogleSparkSession._project_id = None
GoogleSparkSession._region = None
GoogleSparkSession._client_options = None

self.client.close()
if self is DataprocSparkSession._default_session:
DataprocSparkSession._default_session = None
if self is GoogleSparkSession._default_session:
GoogleSparkSession._default_session = None
if self is getattr(
DataprocSparkSession._active_session, "session", None
GoogleSparkSession._active_session, "session", None
):
DataprocSparkSession._active_session.session = None
GoogleSparkSession._active_session.session = None
Loading
Loading