From 7f28f3eab26829bd37475397afbe17f173b5259b Mon Sep 17 00:00:00 2001 From: Antonis Makropoulos Date: Mon, 19 Aug 2024 18:56:35 +0300 Subject: [PATCH] Read context length and warn if it is very large --- Runtime/LLM.cs | 9 +++++++-- Runtime/LLMChatTemplates.cs | 8 ++++++-- Runtime/LLMGGUF.cs | 25 ++++++++++++++++++++++--- Runtime/LLMManager.cs | 12 ++++++++++-- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/Runtime/LLM.cs b/Runtime/LLM.cs index a691cf29..41001fb8 100644 --- a/Runtime/LLM.cs +++ b/Runtime/LLM.cs @@ -173,8 +173,13 @@ public void SetModel(string path) if (!string.IsNullOrEmpty(model)) { ModelEntry modelEntry = LLMManager.Get(model); - string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPathRuntime(model)); - SetTemplate(template); + if (modelEntry == null) modelEntry = new ModelEntry(GetModelLoraPathRuntime(model)); + SetTemplate(modelEntry.chatTemplate); + Debug.Log(modelEntry.contextLength); + if (contextSize == 0 && modelEntry.contextLength > 32768) + { + LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM"); + } } #if UNITY_EDITOR if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this); diff --git a/Runtime/LLMChatTemplates.cs b/Runtime/LLMChatTemplates.cs index f6aa22cc..060b3599 100644 --- a/Runtime/LLMChatTemplates.cs +++ b/Runtime/LLMChatTemplates.cs @@ -1,5 +1,6 @@ /// @file /// @brief File implementing the chat templates. +using System; using System.Collections.Generic; using System.IO; using UnityEngine; @@ -113,9 +114,12 @@ public static string FromTemplate(string template) /// template name public static string FromGGUF(string path) { - GGUFReader reader = new GGUFReader(path); - string name; + return FromGGUF(new GGUFReader(path), path); + } + public static string FromGGUF(GGUFReader reader, string path) + { + string name; name = FromTemplate(reader.GetStringField("tokenizer.chat_template")); if (name != null) return name; diff --git a/Runtime/LLMGGUF.cs b/Runtime/LLMGGUF.cs index 4db81061..2ebeaf80 100644 --- a/Runtime/LLMGGUF.cs +++ b/Runtime/LLMGGUF.cs @@ -125,6 +125,13 @@ public ReaderField GetField(string key) return null; } + public byte[] GetGenericField(string key) + { + ReaderField field = GetField(key); + if (field == null || field.parts.Count == 0) return null; + return (byte[])field.parts[field.parts.Count - 1]; + } + /// /// Allows to retrieve a string GGUF field. /// @@ -132,9 +139,21 @@ public ReaderField GetField(string key) /// Retrieved GGUF value public string GetStringField(string key) { - ReaderField field = GetField(key); - if (field == null || field.parts.Count == 0) return null; - return System.Text.Encoding.UTF8.GetString((byte[])field.parts[field.parts.Count - 1]); + byte[] value = GetGenericField(key); + if (value == null) return null; + return System.Text.Encoding.UTF8.GetString(value); + } + + /// + /// Allows to retrieve an integer GGUF field. + /// + /// GGUF field to retrieve + /// Retrieved GGUF value + public int GetIntField(string key) + { + byte[] value = GetGenericField(key); + if (value == null) return -1; + return BitConverter.ToInt32(value, 0); } private byte[] ReadBytes(int offset, int count) diff --git a/Runtime/LLMManager.cs b/Runtime/LLMManager.cs index ff1101d8..3c7cbc24 100644 --- a/Runtime/LLMManager.cs +++ b/Runtime/LLMManager.cs @@ -17,7 +17,7 @@ public class ModelEntry public string chatTemplate; public string url; public bool includeInBuild; - + public int contextLength; public ModelEntry(string path, bool lora = false, string label = null, string url = null) { @@ -25,9 +25,17 @@ public ModelEntry(string path, bool lora = false, string label = null, string ur this.label = label == null ? filename : label; this.lora = lora; this.path = Path.GetFullPath(path).Replace('\\', '/'); - chatTemplate = lora ? null : ChatTemplate.FromGGUF(this.path); this.url = url; includeInBuild = true; + chatTemplate = null; + contextLength = -1; + if (!lora) + { + GGUFReader reader = new GGUFReader(this.path); + chatTemplate = ChatTemplate.FromGGUF(reader, this.path); + string arch = reader.GetStringField("general.architecture"); + if (arch != null) contextLength = reader.GetIntField($"{arch}.context_length"); + } } public ModelEntry OnlyRequiredFields()