Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions src/notebooklm/cli/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

import asyncio
import re
from collections.abc import Awaitable
from pathlib import Path
from typing import Literal

import click
from rich.table import Table

from .._url_utils import is_youtube_url
from ..auth import AuthTokens
from ..client import NotebookLMClient
from ..types import source_status_to_str
from .helpers import (
Expand Down Expand Up @@ -230,9 +233,26 @@ async def _run():
)
@click.option("--title", help="Title for text sources")
@click.option("--mime-type", help="MIME type for file sources")
@click.option(
"--timeout",
default=30,
type=click.IntRange(min=1),
show_default=True,
help="HTTP request timeout in seconds for adding the source",
)
@click.option("--json", "json_output", is_flag=True, help="Output as JSON")
@with_client
def source_add(ctx, content, notebook_id, source_type, title, mime_type, json_output, client_auth):
def source_add(
ctx: click.Context,
content: str,
notebook_id: str | None,
source_type: Literal["url", "text", "file", "youtube"] | None,
title: str | None,
mime_type: str | None,
timeout: int,
json_output: bool,
client_auth: AuthTokens,
) -> Awaitable[None]:
"""Add a source to a notebook.

\b
Expand Down Expand Up @@ -272,7 +292,7 @@ def source_add(ctx, content, notebook_id, source_type, title, mime_type, json_ou
file_title = title or "Pasted Text"

async def _run():
async with NotebookLMClient(client_auth) as client:
async with NotebookLMClient(client_auth, timeout=float(timeout)) as client:
nb_id_resolved = await resolve_notebook_id(client, nb_id)
if detected_type == "url" or detected_type == "youtube":
src = await client.sources.add_url(nb_id_resolved, content)
Expand Down Expand Up @@ -540,15 +560,30 @@ async def _run():
help="Search mode (default: fast)",
)
@click.option("--import-all", is_flag=True, help="Import all found sources")
@click.option(
"--timeout",
default=1800,
type=click.IntRange(min=1),
show_default=True,
help="Maximum seconds for --import-all source import retries after research completes",
)
@click.option(
"--no-wait",
is_flag=True,
help="Start research and return immediately (use 'research status/wait' to monitor)",
)
@with_client
def source_add_research(
ctx, query, notebook_id, search_source, mode, import_all, no_wait, client_auth
):
ctx: click.Context,
query: str,
notebook_id: str | None,
search_source: Literal["web", "drive"],
mode: Literal["fast", "deep"],
import_all: bool,
timeout: int,
no_wait: bool,
client_auth: AuthTokens,
) -> Awaitable[None]:
"""Search web or drive and add sources from results.

\b
Expand Down Expand Up @@ -606,6 +641,7 @@ async def _run():
nb_id_resolved,
task_id,
sources,
max_elapsed=float(timeout),
)
console.print(f"[green]Imported {len(imported)} sources[/green]")
else:
Expand Down
80 changes: 80 additions & 0 deletions tests/unit/cli/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,34 @@ def test_source_add_json_output(self, runner, mock_auth):
data = json.loads(result.output)
assert data["source"]["id"] == "src_new"

def test_source_add_passes_timeout_to_client(self, runner, mock_auth):
with patch_client_for_module("source") as mock_client_cls:
mock_client = create_mock_client()
mock_client.sources.add_url = AsyncMock(
return_value=Source(
id="src_new",
title="Example",
url="https://example.com",
)
)
mock_client_cls.return_value = mock_client

with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli,
["source", "add", "https://example.com", "-n", "nb_123", "--timeout", "90"],
)

assert result.exit_code == 0
assert mock_client_cls.call_args.kwargs["timeout"] == 90.0

def test_source_add_rejects_non_positive_timeout(self, runner, mock_auth):
result = runner.invoke(cli, ["source", "add", "https://example.com", "--timeout", "0"])

assert result.exit_code == 2
assert "x>=1" in result.output


# =============================================================================
# SOURCE GET TESTS
Expand Down Expand Up @@ -666,8 +694,60 @@ def test_add_research_with_import_all_uses_retry_helper(self, runner, mock_auth)
"nb_123",
"task_123",
[{"title": "Source 1", "url": "http://example.com"}],
max_elapsed=1800.0,
)

def test_add_research_with_import_all_passes_timeout_budget(self, runner, mock_auth):
with (
patch_client_for_module("source") as mock_client_cls,
patch.object(source_module, "import_with_retry", new_callable=AsyncMock) as mock_import,
):
mock_client = create_mock_client()
mock_client.research.start = AsyncMock(return_value={"task_id": "task_123"})
mock_client.research.poll = AsyncMock(
return_value={
"status": "completed",
"task_id": "task_123",
"sources": [{"title": "Source 1", "url": "http://example.com"}],
"report": "# Report",
}
)
mock_import.return_value = [{"id": "src_1", "title": "Source 1"}]
mock_client_cls.return_value = mock_client

with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli,
[
"source",
"add-research",
"AI papers",
"--mode",
"deep",
"--import-all",
"--timeout",
"90",
"-n",
"nb_123",
],
)

assert result.exit_code == 0
mock_import.assert_awaited_once_with(
mock_client,
"nb_123",
"task_123",
[{"title": "Source 1", "url": "http://example.com"}],
max_elapsed=90.0,
)

def test_add_research_rejects_non_positive_timeout(self, runner, mock_auth):
result = runner.invoke(cli, ["source", "add-research", "AI papers", "--timeout", "0"])

assert result.exit_code == 2
assert "x>=1" in result.output


# =============================================================================
# COMMAND EXISTENCE TESTS
Expand Down
Loading