Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/notebooklm/_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,8 @@ async def generate_mind_map(
self,
notebook_id: str,
source_ids: builtins.list[str] | None = None,
language: str = "en",
instructions: str | None = None,
) -> dict[str, Any]:
"""Generate an interactive mind map.

Expand All @@ -997,6 +999,8 @@ async def generate_mind_map(
Args:
notebook_id: The notebook ID.
source_ids: Source IDs to include. If None, uses all sources.
language: Output language code.
instructions: Optional generation instructions.

Returns:
Dictionary with 'mind_map' (JSON data) and 'note_id'.
Expand All @@ -1014,7 +1018,7 @@ async def generate_mind_map(
None,
None,
None,
["interactive_mindmap", [["[CONTEXT]", ""]], ""],
["interactive_mindmap", [["[CONTEXT]", instructions or ""]], language],
None,
[2, None, [1]],
]
Expand Down
9 changes: 3 additions & 6 deletions src/notebooklm/_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
SourceNotFoundError,
SourceProcessingError,
SourceTimeoutError,
_extract_source_url,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,12 +103,8 @@ async def list(self, notebook_id: str) -> list[Source]:
src_id = src[0][0] if isinstance(src[0], list) else src[0]
title = src[1] if len(src) > 1 else None

# Extract URL if present (at src[2][7])
url = None
if len(src) > 2 and isinstance(src[2], list) and len(src[2]) > 7:
url_list = src[2][7]
if isinstance(url_list, list) and len(url_list) > 0:
url = url_list[0]
# Extract URL if present
url = _extract_source_url(src[2] if len(src) > 2 else None)

# Extract timestamp from src[2][2] - [seconds, nanoseconds]
created_at = None
Expand Down
22 changes: 19 additions & 3 deletions src/notebooklm/cli/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,38 +987,54 @@ async def _generate():


@generate.command("mind-map")
@click.argument("description", default="", required=False)
@click.option(
"-n",
"--notebook",
"notebook_id",
default=None,
help="Notebook ID (uses current if not set)",
)
@click.option("--language", default=None, help="Output language (default: from config or 'en')")
@click.option("--source", "-s", "source_ids", multiple=True, help="Limit to specific source IDs")
@json_option
@with_client
def generate_mind_map(ctx, notebook_id, source_ids, json_output, client_auth):
def generate_mind_map(
ctx, description, notebook_id, language, source_ids, json_output, client_auth
):
"""Generate mind map.

\b
Use --json for machine-readable output.

\b
Example:
notebooklm generate mind-map "focus on chronology" --language zh_Hans
"""
nb_id = require_notebook(notebook_id)

async def _run():
async with NotebookLMClient(client_auth) as client:
nb_id_resolved = await resolve_notebook_id(client, nb_id)
sources = await resolve_source_ids(client, nb_id_resolved, source_ids)
resolved_language = resolve_language(language)
instructions = description or None

# Show status spinner only for console output
if json_output:
result = await client.artifacts.generate_mind_map(
nb_id_resolved, source_ids=sources
nb_id_resolved,
source_ids=sources,
language=resolved_language,
instructions=instructions,
)
else:
with console.status("Generating mind map..."):
result = await client.artifacts.generate_mind_map(
nb_id_resolved, source_ids=sources
nb_id_resolved,
source_ids=sources,
language=resolved_language,
instructions=instructions,
)

_output_mind_map_result(result, json_output)
Expand Down
47 changes: 35 additions & 12 deletions src/notebooklm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,39 @@ def _safe_source_type(type_code: int | None) -> SourceType:
return result


def _extract_source_url(metadata: list[Any] | None) -> str | None:
"""Extract a source URL from NotebookLM source metadata.

NotebookLM stores URLs in different slots depending on the source type:
- metadata[7][0] for web/PDF-style sources
- metadata[5][0] for YouTube sources
- metadata[0] as a fallback in some nested response shapes
"""
if not isinstance(metadata, list):
return None

if len(metadata) > 7:
url_list = metadata[7]
if isinstance(url_list, list) and url_list:
first = url_list[0]
if isinstance(first, str) and first:
return first

if len(metadata) > 5:
youtube_data = metadata[5]
if isinstance(youtube_data, list) and youtube_data:
first = youtube_data[0]
if isinstance(first, str) and first:
return first

if metadata:
first = metadata[0]
if isinstance(first, str) and first.startswith("http"):
return first

return None


def _map_artifact_kind(artifact_type: int, variant: int | None) -> ArtifactType:
"""Convert internal artifact type and variant to user-facing ArtifactType.

Expand Down Expand Up @@ -582,25 +615,15 @@ def from_api_response(cls, data: list[Any], notebook_id: str | None = None) -> "
source_id = entry[0][0] if isinstance(entry[0], list) else entry[0]
title = entry[1] if len(entry) > 1 else None

# Try to extract URL if present
url = None
if len(entry) > 2 and isinstance(entry[2], list):
if len(entry[2]) > 7 and isinstance(entry[2][7], list):
url = entry[2][7][0] if entry[2][7] else None
url = _extract_source_url(entry[2] if len(entry) > 2 else None)

return cls(id=str(source_id), title=title, url=url, _type_code=None)

# Deeply nested: continue with URL and type code extraction
url = None
type_code = None
if len(entry) > 2 and isinstance(entry[2], list):
if len(entry[2]) > 7:
url_list = entry[2][7]
if isinstance(url_list, list) and len(url_list) > 0:
url = url_list[0]
if not url and len(entry[2]) > 0:
if isinstance(entry[2][0], str) and entry[2][0].startswith("http"):
url = entry[2][0]
url = _extract_source_url(entry[2])
# Extract type code at entry[2][4] if available
if len(entry[2]) > 4 and isinstance(entry[2][4], int):
type_code = entry[2][4]
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/cli/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,36 @@ def test_generate_mind_map(self, runner, mock_auth):

assert result.exit_code == 0

def test_generate_mind_map_with_language(self, runner, mock_auth):
with patch_client_for_module("generate") as mock_client_cls:
mock_client = create_mock_client()
mock_client.artifacts.generate_mind_map = AsyncMock(
return_value={"mind_map": {"name": "Root", "children": []}, "note_id": "n1"}
)
mock_client_cls.return_value = mock_client

with patch("notebooklm.cli.helpers.fetch_tokens", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = ("csrf", "session")
result = runner.invoke(
cli,
[
"generate",
"mind-map",
"--language",
"zh_Hans",
"-n",
"nb_123",
],
)

assert result.exit_code == 0
mock_client.artifacts.generate_mind_map.assert_awaited_once_with(
"nb_123",
source_ids=None,
language="zh_Hans",
instructions=None,
)


# =============================================================================
# GENERATE REPORT TESTS
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_source_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,30 @@ async def test_generate_mind_map_source_encoding(self, mock_core, mock_notes_api

assert source_ids_nested == [[["src_mm_1"]], [["src_mm_2"]]]

@pytest.mark.asyncio
async def test_generate_mind_map_includes_language_and_instructions(
self, mock_core, mock_notes_api
):
"""Test generate_mind_map encodes language and instructions in config."""
api = ArtifactsAPI(mock_core, mock_notes_api)

mock_core.rpc_call.return_value = [['{"name": "Mind Map", "children": []}']]

await api.generate_mind_map(
notebook_id="nb_123",
source_ids=["src_mm_1"],
language="zh_Hans",
instructions="Focus on chronology",
)

params = mock_core.rpc_call.call_args.args[1]

assert params[5] == [
"interactive_mindmap",
[["[CONTEXT]", "Focus on chronology"]],
"zh_Hans",
]

@pytest.mark.asyncio
async def test_suggest_reports_uses_get_suggested_reports(self, mock_core, mock_notes_api):
"""Test suggest_reports uses GET_SUGGESTED_REPORTS RPC."""
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_source_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,48 @@ async def mock_sleep(delay):
assert sleep_intervals[1] >= sleep_intervals[0] * 1.5


class TestSourceListParsing:
"""Tests for parsing source metadata in SourcesAPI.list()."""

@pytest.fixture
def sources_api(self):
core = MagicMock()
core.rpc_call = AsyncMock()
return SourcesAPI(core)

@pytest.mark.asyncio
async def test_list_extracts_youtube_url_from_youtube_metadata_slot(self, sources_api):
"""YouTube sources should read their URL from metadata[5][0]."""
sources_api._core.rpc_call.return_value = [
[
None,
[
[
["src_yt"],
"YouTube Video",
[
None,
None,
None,
None,
9,
["https://youtube.com/watch?v=abc", "abc", "Channel"],
None,
None,
],
[None, SourceStatus.READY],
]
]
]
]

sources = await sources_api.list("nb_123")

assert len(sources) == 1
assert sources[0].url == "https://youtube.com/watch?v=abc"
assert sources[0].kind == "youtube"


class TestWaitForSources:
"""Tests for wait_for_sources method."""

Expand Down
Loading