Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

### Bug Fixes

* `dbutils.fs` methods now accept `pathlib.Path` arguments in addition to strings ([#1461](https://github.com/databricks/databricks-sdk-py/pull/1461)).

### Documentation

### Breaking Changes
Expand Down
52 changes: 29 additions & 23 deletions databricks/sdk/dbutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import threading
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

from .core import ApiClient, Config, DatabricksError
from .mixins import compute as compute_ext
Expand All @@ -17,6 +18,11 @@
_LOG = logging.getLogger("databricks.sdk")


def _as_str(path: Union[str, Path]) -> str:
"""Convert a pathlib.Path to str; leave plain strings untouched."""
return str(path) if isinstance(path, Path) else path


class FileInfo(
    namedtuple(
        "FileInfo",
        [
            "path",
            "name",
            "size",
            "modificationTime",
        ],
    )
):
    """Immutable record with the fields ``path``, ``name``, ``size`` and ``modificationTime``."""

Expand Down Expand Up @@ -45,17 +51,17 @@ def __init__(
self._dbfs = dbfs_ext
self._proxy_factory = proxy_factory

def cp(self, from_: str, to: str, recurse: bool = False) -> bool:
def cp(self, from_: Union[str, Path], to: Union[str, Path], recurse: bool = False) -> bool:
"""Copies a file or directory, possibly across FileSystems"""
self._dbfs.copy(from_, to, recursive=recurse)
self._dbfs.copy(_as_str(from_), _as_str(to), recursive=recurse)
return True

def head(self, file: str, maxBytes: int = 65536) -> str:
def head(self, file: Union[str, Path], maxBytes: int = 65536) -> str:
"""Returns up to the first 'maxBytes' bytes of the given file as a String encoded in UTF-8"""
with self._dbfs.download(file) as f:
with self._dbfs.download(_as_str(file)) as f:
return f.read(maxBytes).decode("utf8")

def ls(self, dir: str) -> List[FileInfo]:
def ls(self, dir: Union[str, Path]) -> List[FileInfo]:
"""Lists the contents of a directory"""
return [
FileInfo(
Expand All @@ -64,34 +70,34 @@ def ls(self, dir: str) -> List[FileInfo]:
f.file_size,
f.modification_time,
)
for f in self._dbfs.list(dir)
for f in self._dbfs.list(_as_str(dir))
]

def mkdirs(self, dir: str) -> bool:
def mkdirs(self, dir: Union[str, Path]) -> bool:
"""Creates the given directory if it does not exist, also creating any necessary parent directories"""
self._dbfs.mkdirs(dir)
self._dbfs.mkdirs(_as_str(dir))
return True

def mv(self, from_: str, to: str, recurse: bool = False) -> bool:
def mv(self, from_: Union[str, Path], to: Union[str, Path], recurse: bool = False) -> bool:
"""Moves a file or directory, possibly across FileSystems"""
self._dbfs.move_(from_, to, recursive=recurse, overwrite=True)
self._dbfs.move_(_as_str(from_), _as_str(to), recursive=recurse, overwrite=True)
return True

def put(self, file: str, contents: str, overwrite: bool = False) -> bool:
def put(self, file: Union[str, Path], contents: str, overwrite: bool = False) -> bool:
"""Writes the given String out to a file, encoded in UTF-8"""
with self._dbfs.open(file, write=True, overwrite=overwrite) as f:
with self._dbfs.open(_as_str(file), write=True, overwrite=overwrite) as f:
f.write(contents.encode("utf8"))
return True

def rm(self, dir: str, recurse: bool = False) -> bool:
def rm(self, dir: Union[str, Path], recurse: bool = False) -> bool:
"""Removes a file or directory"""
self._dbfs.delete(dir, recursive=recurse)
self._dbfs.delete(_as_str(dir), recursive=recurse)
return True

def mount(
self,
source: str,
mount_point: str,
source: Union[str, Path],
mount_point: Union[str, Path],
encryption_type: str = None,
owner: str = None,
extra_configs: Dict[str, str] = None,
Expand All @@ -105,17 +111,17 @@ def mount(
kwargs["owner"] = owner
if extra_configs:
kwargs["extra_configs"] = extra_configs
return fs.mount(source, mount_point, **kwargs)
return fs.mount(_as_str(source), _as_str(mount_point), **kwargs)

def unmount(self, mount_point: str) -> bool:
def unmount(self, mount_point: Union[str, Path]) -> bool:
"""Deletes a DBFS mount point"""
fs = self._proxy_factory("fs")
return fs.unmount(mount_point)
return fs.unmount(_as_str(mount_point))

def updateMount(
self,
source: str,
mount_point: str,
source: Union[str, Path],
mount_point: Union[str, Path],
encryption_type: str = None,
owner: str = None,
extra_configs: Dict[str, str] = None,
Expand All @@ -129,7 +135,7 @@ def updateMount(
kwargs["owner"] = owner
if extra_configs:
kwargs["extra_configs"] = extra_configs
return fs.updateMount(source, mount_point, **kwargs)
return fs.updateMount(_as_str(source), _as_str(mount_point), **kwargs)

def mounts(self) -> List[MountInfo]:
"""Displays information about what is mounted within DBFS"""
Expand Down
95 changes: 95 additions & 0 deletions tests/test_dbutils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pytest as pytest

from databricks.sdk.dbutils import FileInfo as DBUtilsFileInfo
Expand Down Expand Up @@ -307,6 +309,99 @@ def test_dbutils_credentials_get_service_credential_provider(config, mocker):
dbutils.credentials.getServiceCredentialsProvider("creds")


def test_fs_ls_accepts_path_object(dbutils, mocker):
    """``dbutils.fs.ls`` hands a ``pathlib.Path`` to ``DbfsAPI.list`` as a plain string."""
    listed_entry = FileInfo(path="a/b", file_size=10, modification_time=20)
    list_mock = mocker.patch(
        "databricks.sdk.service.files.DbfsAPI.list",
        return_value=[listed_entry],
    )
    # Successive ``get_status`` calls return these values in order.
    statuses = [
        FileInfo(path="a", is_dir=True, file_size=5),
        FileInfo(path="a/b", is_dir=False, file_size=5),
    ]
    mocker.patch(
        "databricks.sdk.service.files.DbfsAPI.get_status",
        side_effect=statuses,
    )

    entries = dbutils.fs.ls(Path("a"))

    list_mock.assert_called_with("a")
    assert len(entries) == 1
    assert entries[0] == DBUtilsFileInfo("a/b", "b", 10, 20)


def test_fs_cp_accepts_path_objects(dbutils, mocker):
    """``dbutils.fs.cp`` passes both ``Path`` arguments to ``DbfsExt.copy`` as strings."""
    copy_mock = mocker.patch("databricks.sdk.mixins.files.DbfsExt.copy")
    source, destination = Path("a"), Path("b")

    dbutils.fs.cp(source, destination, recurse=True)

    copy_mock.assert_called_with("a", "b", recursive=True)


def test_fs_head_accepts_path_object(dbutils, mocker):
    """``dbutils.fs.head`` works when the file is given as a ``pathlib.Path``."""
    # "aGVsbG8=" is the base64 encoding of b"hello".
    read_response = ReadResponse(data="aGVsbG8=", bytes_read=5)
    file_status = FileInfo(path="a", is_dir=False, file_size=5)
    mocker.patch(
        "databricks.sdk.service.files.DbfsAPI.read",
        return_value=read_response,
    )
    mocker.patch(
        "databricks.sdk.service.files.DbfsAPI.get_status",
        return_value=file_status,
    )

    text = dbutils.fs.head(Path("a"))

    assert text == "hello"


def test_fs_mkdirs_accepts_path_object(dbutils, mocker):
    """``dbutils.fs.mkdirs`` passes a ``Path`` directory to ``DbfsAPI.mkdirs`` as a string."""
    mkdirs_mock = mocker.patch("databricks.sdk.service.files.DbfsAPI.mkdirs")
    directory = Path("a")

    dbutils.fs.mkdirs(directory)

    mkdirs_mock.assert_called_with("a")


def test_fs_mv_accepts_path_objects(dbutils, mocker):
    """``dbutils.fs.mv`` passes both ``Path`` arguments to ``DbfsExt.move_`` as strings."""
    move_mock = mocker.patch("databricks.sdk.mixins.files.DbfsExt.move_")
    source, destination = Path("a"), Path("b")

    dbutils.fs.mv(source, destination)

    # ``recurse`` was not supplied, so the call is expected with recursive=False.
    move_mock.assert_called_with("a", "b", recursive=False, overwrite=True)


def test_fs_put_accepts_path_object(dbutils, mocker):
    """``dbutils.fs.put`` opens the target given as a ``Path`` using its string form."""

    class _RecordingFile:
        """Context-manager stand-in for an open file that remembers the last write."""

        written = None

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            return None

        def write(self, payload):
            self.written = payload

    fake_file = _RecordingFile()
    open_mock = mocker.patch(
        "databricks.sdk.mixins.files.DbfsExt.open",
        return_value=fake_file,
    )

    dbutils.fs.put(Path("a"), "b")

    open_mock.assert_called_with("a", overwrite=False, write=True)
    assert fake_file.written == b"b"


def test_fs_rm_accepts_path_object(dbutils, mocker):
    """``dbutils.fs.rm`` passes a ``Path`` target to ``DbfsAPI.delete`` as a string."""
    delete_mock = mocker.patch("databricks.sdk.service.files.DbfsAPI.delete")
    file_status = FileInfo(path="a", is_dir=False, file_size=5)
    mocker.patch(
        "databricks.sdk.service.files.DbfsAPI.get_status",
        return_value=file_status,
    )

    dbutils.fs.rm(Path("a"))

    # ``recurse`` was not supplied, so the call is expected with recursive=False.
    delete_mock.assert_called_with("a", recursive=False)


def test_dbutils_adds_user_agent(config):
from databricks.sdk.dbutils import RemoteDbUtils

Expand Down
Loading