diff --git a/elia_chat/__main__.py b/elia_chat/__main__.py index 87227fb..d8e4972 100644 --- a/elia_chat/__main__.py +++ b/elia_chat/__main__.py @@ -4,19 +4,20 @@ import asyncio import pathlib -from textwrap import dedent import tomllib +from pathlib import Path +from textwrap import dedent from typing import Any import click from click_default_group import DefaultGroup - from rich.console import Console from elia_chat.app import Elia from elia_chat.config import LaunchConfig -from elia_chat.database.import_chatgpt import import_chatgpt_data from elia_chat.database.database import create_database, sqlite_file_name +from elia_chat.database.export_chatgpt import export_chatgpt_data +from elia_chat.database.import_chatgpt import import_chatgpt_data from elia_chat.locations import config_file console = Console() @@ -92,14 +93,16 @@ def reset() -> None: console.print( Padding( Text.from_markup( - dedent(f"""\ + dedent( + f"""\ [u b red]Warning![/] [b red]This will delete all messages and chats.[/] You may wish to create a backup of \ "[bold blue u]{str(sqlite_file_name.resolve().absolute())}[/]" before continuing. - """) + """ + ) ), pad=(1, 2), ) @@ -110,6 +113,29 @@ def reset() -> None: console.print(f"♻️ Database reset @ {sqlite_file_name}") +@cli.command("export") +@click.option( + "-c", + "--chat", + type=str, + default="", + help="Chat identifier, either id or name", +) +@click.argument( + "file", + type=click.Path(dir_okay=False, path_type=pathlib.Path, resolve_path=True), + required=False, +) +def export_db_to_file(chat: str, file: pathlib.Path) -> None: + """ + Export ChatGPT Conversations + + This command will export a chat with all messages into a json file. + """ + exported_file = asyncio.run(export_chatgpt_data(chat=chat, file=file)) + console.print(f"[green]Chat exported as json to {str(exported_file)!r}") + + @cli.command("import") @click.argument( "file", diff --git a/elia_chat/database/export_chatgpt.py b/elia_chat/database/export_chatgpt.py new file mode 100644 index 0000000..94423f3 --- /dev/null +++ b/elia_chat/database/export_chatgpt.py @@ -0,0 +1,46 @@ +import json +from pathlib import Path +from typing import Optional + +from rich.console import Console +from rich.live import Live +from rich.text import Text + +from elia_chat.database.models import ChatDao + + +async def export_chatgpt_data(chat: str, file: Optional[Path]) -> Path: + console = Console() + + def output_progress(exported_count: int, message_count: int) -> Text: + style = "green" if exported_count == message_count else "yellow" + return Text.from_markup( + f"Exported [b]{exported_count}[/] of [b]{message_count}[/] messages.", + style=style, + ) + + chat_id: int = 0 + try: + chat_id = int(chat) + except ValueError: + all_chats = await ChatDao.all() + chat_by_name = list(filter(lambda c: c.title == chat, all_chats))[0] + chat_id = chat_by_name.id + chat_dao = await ChatDao.from_id(chat_id) + messages = chat_dao.messages + message_count = 0 + export = {"title": chat_dao.title, "id": chat_dao.id, "messages": []} + with Live(output_progress(0, len(messages))) as live: + for message in messages: + export["messages"].append(message.as_dict()) + message_count = message_count + 1 + live.update(output_progress(message_count, len(messages))) + + console.print(json.dumps(export, indent=2)) + if file is None: + file = Path(f"{chat_dao.title}_export.json") + with open(file, "w+", encoding="utf-8") as f: + json.dump(export, f, indent=2) + console.print("[green]Exported chat to json.") + + return file diff --git a/elia_chat/database/models.py b/elia_chat/database/models.py index a260ba4..ac31789 100644 --- a/elia_chat/database/models.py +++ b/elia_chat/database/models.py @@ -1,7 +1,7 @@ from datetime import datetime -from typing import Any, Optional +from typing import Any, Dict, Optional -from sqlalchemy import Column, DateTime, func, JSON, desc +from sqlalchemy import JSON, Column, DateTime, desc, func from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import selectinload from sqlmodel import Field, Relationship, SQLModel, select @@ -47,6 +47,17 @@ class MessageDao(AsyncAttrs, SQLModel, table=True): model: str | None """The model that wrote this response. (Could switch models mid-chat, possibly)""" + def as_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "timestamp": self.timestamp.isoformat(), + "chat_id": self.chat_id, + "role": self.role, + "content": self.content, + "meta": self.meta, + "model": self.model, + } + class ChatDao(AsyncAttrs, SQLModel, table=True): __tablename__ = "chat"