diff --git a/chatblade/cli.py b/chatblade/cli.py index ddda7a8..be6b609 100644 --- a/chatblade/cli.py +++ b/chatblade/cli.py @@ -22,7 +22,7 @@ def fetch_and_cache(messages, params): else: response_msg = chat.query_chat_gpt(messages, params) messages.append(response_msg) - storage.to_cache(messages) + storage.to_cache(messages, params) return messages @@ -49,7 +49,7 @@ def handle_input(query, params): utils.debug(title="cli input", query=query, params=params) if params.last: - messages = storage.messages_from_cache() + messages = storage.messages_from_cache(params) if query: # continue conversation messages.append(chat.Message("user", query)) elif params.prompt_file: @@ -63,6 +63,9 @@ def handle_input(query, params): printer.warn("no query or option given. nothing to do...") exit() + if type(params.directory) == type(None): + del params["directory"] + if params.tokens: token_prices = chat.get_tokens_and_costs(messages) printer.print_tokens(messages, token_prices, params) diff --git a/chatblade/parser.py b/chatblade/parser.py index f5c5e33..969a9db 100644 --- a/chatblade/parser.py +++ b/chatblade/parser.py @@ -94,6 +94,12 @@ def parse(args): help="Stream the incoming text to the terminal", action="store_true", ) + parser.add_argument( + "--directory", + metavar="d", + type=str, + help="Set the chatblade cache location (default ~/.cache/chatblade, ~/Library/Caches/chatblade on osx)", + ) # ------ Display Options parser.add_argument( diff --git a/chatblade/storage.py b/chatblade/storage.py index 47a6ad9..58b5ea4 100644 --- a/chatblade/storage.py +++ b/chatblade/storage.py @@ -13,13 +13,19 @@ APP_NAME = "chatblade" -def get_cache_file_path(): +def get_cache_file_path(args): """ if ~/.cache is availabe always use ~/.cache/chatblade as the cachefile otherwise fallback to the platform recommended location and create the directory e.g. ~/Library/Caches/chatblade on osx """ - cache_path = os.path.expanduser("~/.cache") + if "directory" not in args: + cache_path = os.path.expanduser("~/.cache") + else: + cache_path = os.path.expanduser(args.directory) + if not os.path.exists(cache_path): + os.makedirs(cache_path) + if not os.path.exists(cache_path): cache_path = platformdirs.user_cache_dir(APP_NAME) if not os.path.exists(cache_path): @@ -28,15 +34,15 @@ def get_cache_file_path(): return os.path.join(cache_path, APP_NAME) -def to_cache(messages): +def to_cache(messages, args): """cache the current messages state""" - with open(get_cache_file_path(), "wb") as f: + with open(get_cache_file_path(args), "wb") as f: pickle.dump(messages, f) -def messages_from_cache(): +def messages_from_cache(args): """load messages from last state or ChatbladeError if not exists""" - file_path = get_cache_file_path() + file_path = get_cache_file_path(args) if not os.path.exists(file_path): raise errors.ChatbladeError("No last state cached from which to begin") else: