diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 0994b0603..a3ebc490e 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -8,6 +8,8 @@ ### Bug Fixes +* `dbutils.fs` methods now accept `pathlib.Path` arguments in addition to strings ([#1461](https://github.com/databricks/databricks-sdk-py/pull/1461)). + ### Documentation ### Breaking Changes diff --git a/databricks/sdk/dbutils.py b/databricks/sdk/dbutils.py index b69668950..f7312ec65 100644 --- a/databricks/sdk/dbutils.py +++ b/databricks/sdk/dbutils.py @@ -7,7 +7,8 @@ import threading from collections import namedtuple from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union from .core import ApiClient, Config, DatabricksError from .mixins import compute as compute_ext @@ -17,6 +18,11 @@ _LOG = logging.getLogger("databricks.sdk") +def _as_str(path: Union[str, Path]) -> str: + """Convert a pathlib.Path to str; leave plain strings untouched.""" + return str(path) if isinstance(path, Path) else path + + class FileInfo(namedtuple("FileInfo", ["path", "name", "size", "modificationTime"])): pass @@ -45,17 +51,17 @@ def __init__( self._dbfs = dbfs_ext self._proxy_factory = proxy_factory - def cp(self, from_: str, to: str, recurse: bool = False) -> bool: + def cp(self, from_: Union[str, Path], to: Union[str, Path], recurse: bool = False) -> bool: """Copies a file or directory, possibly across FileSystems""" - self._dbfs.copy(from_, to, recursive=recurse) + self._dbfs.copy(_as_str(from_), _as_str(to), recursive=recurse) return True - def head(self, file: str, maxBytes: int = 65536) -> str: + def head(self, file: Union[str, Path], maxBytes: int = 65536) -> str: """Returns up to the first 'maxBytes' bytes of the given file as a String encoded in UTF-8""" - with self._dbfs.download(file) as f: + with self._dbfs.download(_as_str(file)) as f: return f.read(maxBytes).decode("utf8") - def ls(self, dir: str) -> List[FileInfo]: + def ls(self, dir: Union[str, Path]) -> List[FileInfo]: """Lists the contents of a directory""" return [ FileInfo( @@ -64,34 +70,34 @@ def ls(self, dir: str) -> List[FileInfo]: f.file_size, f.modification_time, ) - for f in self._dbfs.list(dir) + for f in self._dbfs.list(_as_str(dir)) ] - def mkdirs(self, dir: str) -> bool: + def mkdirs(self, dir: Union[str, Path]) -> bool: """Creates the given directory if it does not exist, also creating any necessary parent directories""" - self._dbfs.mkdirs(dir) + self._dbfs.mkdirs(_as_str(dir)) return True - def mv(self, from_: str, to: str, recurse: bool = False) -> bool: + def mv(self, from_: Union[str, Path], to: Union[str, Path], recurse: bool = False) -> bool: """Moves a file or directory, possibly across FileSystems""" - self._dbfs.move_(from_, to, recursive=recurse, overwrite=True) + self._dbfs.move_(_as_str(from_), _as_str(to), recursive=recurse, overwrite=True) return True - def put(self, file: str, contents: str, overwrite: bool = False) -> bool: + def put(self, file: Union[str, Path], contents: str, overwrite: bool = False) -> bool: """Writes the given String out to a file, encoded in UTF-8""" - with self._dbfs.open(file, write=True, overwrite=overwrite) as f: + with self._dbfs.open(_as_str(file), write=True, overwrite=overwrite) as f: f.write(contents.encode("utf8")) return True - def rm(self, dir: str, recurse: bool = False) -> bool: + def rm(self, dir: Union[str, Path], recurse: bool = False) -> bool: """Removes a file or directory""" - self._dbfs.delete(dir, recursive=recurse) + self._dbfs.delete(_as_str(dir), recursive=recurse) return True def mount( self, - source: str, - mount_point: str, + source: Union[str, Path], + mount_point: Union[str, Path], encryption_type: str = None, owner: str = None, extra_configs: Dict[str, str] = None, @@ -105,17 +111,17 @@ def mount( kwargs["owner"] = owner if extra_configs: kwargs["extra_configs"] = extra_configs - return fs.mount(source, mount_point, **kwargs) + return fs.mount(_as_str(source), _as_str(mount_point), **kwargs) - def unmount(self, mount_point: str) -> bool: + def unmount(self, mount_point: Union[str, Path]) -> bool: """Deletes a DBFS mount point""" fs = self._proxy_factory("fs") - return fs.unmount(mount_point) + return fs.unmount(_as_str(mount_point)) def updateMount( self, - source: str, - mount_point: str, + source: Union[str, Path], + mount_point: Union[str, Path], encryption_type: str = None, owner: str = None, extra_configs: Dict[str, str] = None, @@ -129,7 +135,7 @@ def updateMount( kwargs["owner"] = owner if extra_configs: kwargs["extra_configs"] = extra_configs - return fs.updateMount(source, mount_point, **kwargs) + return fs.updateMount(_as_str(source), _as_str(mount_point), **kwargs) def mounts(self) -> List[MountInfo]: """Displays information about what is mounted within DBFS""" diff --git a/tests/test_dbutils.py b/tests/test_dbutils.py index 9792d8de5..50e591e34 100644 --- a/tests/test_dbutils.py +++ b/tests/test_dbutils.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest as pytest from databricks.sdk.dbutils import FileInfo as DBUtilsFileInfo @@ -307,6 +309,99 @@ def test_dbutils_credentials_get_service_credential_provider(config, mocker): dbutils.credentials.getServiceCredentialsProvider("creds") +def test_fs_ls_accepts_path_object(dbutils, mocker): + inner = mocker.patch( + "databricks.sdk.service.files.DbfsAPI.list", + return_value=[FileInfo(path="a/b", file_size=10, modification_time=20)], + ) + mocker.patch( + "databricks.sdk.service.files.DbfsAPI.get_status", + side_effect=[ + FileInfo(path="a", is_dir=True, file_size=5), + FileInfo(path="a/b", is_dir=False, file_size=5), + ], + ) + + result = dbutils.fs.ls(Path("a")) + + inner.assert_called_with("a") + assert len(result) == 1 + assert result[0] == DBUtilsFileInfo("a/b", "b", 10, 20) + + +def test_fs_cp_accepts_path_objects(dbutils, mocker): + inner = mocker.patch("databricks.sdk.mixins.files.DbfsExt.copy") + + dbutils.fs.cp(Path("a"), Path("b"), recurse=True) + + inner.assert_called_with("a", "b", recursive=True) + + +def test_fs_head_accepts_path_object(dbutils, mocker): + mocker.patch( + "databricks.sdk.service.files.DbfsAPI.read", + return_value=ReadResponse(data="aGVsbG8=", bytes_read=5), + ) + mocker.patch( + "databricks.sdk.service.files.DbfsAPI.get_status", + return_value=FileInfo(path="a", is_dir=False, file_size=5), + ) + + result = dbutils.fs.head(Path("a")) + + assert result == "hello" + + +def test_fs_mkdirs_accepts_path_object(dbutils, mocker): + inner = mocker.patch("databricks.sdk.service.files.DbfsAPI.mkdirs") + + dbutils.fs.mkdirs(Path("a")) + + inner.assert_called_with("a") + + +def test_fs_mv_accepts_path_objects(dbutils, mocker): + inner = mocker.patch("databricks.sdk.mixins.files.DbfsExt.move_") + + dbutils.fs.mv(Path("a"), Path("b")) + + inner.assert_called_with("a", "b", recursive=False, overwrite=True) + + +def test_fs_put_accepts_path_object(dbutils, mocker): + class _MockOpen: + _written = None + + def __enter__(self): + return self + + def __exit__(self, *ignored): + pass + + def write(self, contents): + self._written = contents + + mock_open = _MockOpen() + inner = mocker.patch("databricks.sdk.mixins.files.DbfsExt.open", return_value=mock_open) + + dbutils.fs.put(Path("a"), "b") + + inner.assert_called_with("a", overwrite=False, write=True) + assert mock_open._written == b"b" + + +def test_fs_rm_accepts_path_object(dbutils, mocker): + inner = mocker.patch("databricks.sdk.service.files.DbfsAPI.delete") + mocker.patch( + "databricks.sdk.service.files.DbfsAPI.get_status", + return_value=FileInfo(path="a", is_dir=False, file_size=5), + ) + + dbutils.fs.rm(Path("a")) + + inner.assert_called_with("a", recursive=False) + + def test_dbutils_adds_user_agent(config): from databricks.sdk.dbutils import RemoteDbUtils