Skip to content

Read context length and warn if it is very large #211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Runtime/LLM.cs
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,13 @@ public void SetModel(string path)
if (!string.IsNullOrEmpty(model))
{
ModelEntry modelEntry = LLMManager.Get(model);
string template = modelEntry != null ? modelEntry.chatTemplate : ChatTemplate.FromGGUF(GetModelLoraPathRuntime(model));
SetTemplate(template);
if (modelEntry == null) modelEntry = new ModelEntry(GetModelLoraPathRuntime(model));
SetTemplate(modelEntry.chatTemplate);
Debug.Log(modelEntry.contextLength);
if (contextSize == 0 && modelEntry.contextLength > 32768)
{
LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM");
}
}
#if UNITY_EDITOR
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
Expand Down
8 changes: 6 additions & 2 deletions Runtime/LLMChatTemplates.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/// @file
/// @brief File implementing the chat templates.
using System;
using System.Collections.Generic;
using System.IO;
using UnityEngine;
Expand Down Expand Up @@ -113,9 +114,12 @@ public static string FromTemplate(string template)
/// <returns>template name</returns>
public static string FromGGUF(string path)
{
    // Open the GGUF file a single time and let the reader-based overload
    // resolve the template, so the file is not parsed twice.
    return FromGGUF(new GGUFReader(path), path);
}

public static string FromGGUF(GGUFReader reader, string path)
{
string name;
name = FromTemplate(reader.GetStringField("tokenizer.chat_template"));
if (name != null) return name;

Expand Down
25 changes: 22 additions & 3 deletions Runtime/LLMGGUF.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,35 @@ public ReaderField GetField(string key)
return null;
}

/// <summary>
/// Retrieves the raw bytes of a GGUF field.
/// </summary>
/// <param name="key"> GGUF field to retrieve </param>
/// <returns> Last part stored for the field, or null when the field is missing or has no parts </returns>
public byte[] GetGenericField(string key)
{
    ReaderField entry = GetField(key);
    if (entry == null) return null;

    int partCount = entry.parts.Count;
    if (partCount == 0) return null;

    // The value is taken from the final part of the field.
    return (byte[])entry.parts[partCount - 1];
}

/// <summary>
/// Retrieves a GGUF field and decodes it as a UTF-8 string.
/// </summary>
/// <param name="key"> GGUF field to retrieve </param>
/// <returns> Decoded GGUF value, or null when the field is missing or has no parts </returns>
public string GetStringField(string key)
{
    // Single lookup path shared with the other typed getters.
    byte[] value = GetGenericField(key);
    if (value == null) return null;
    return System.Text.Encoding.UTF8.GetString(value);
}

/// <summary>
/// Retrieves a GGUF field and reads it as a 32-bit integer.
/// </summary>
/// <param name="key"> GGUF field to retrieve </param>
/// <returns> Retrieved GGUF value, or -1 when the field is missing or holds fewer than 4 bytes </returns>
public int GetIntField(string key)
{
    byte[] value = GetGenericField(key);
    // BitConverter.ToInt32 throws when fewer than 4 bytes are available, so a
    // too-short payload is reported with the same -1 sentinel as a missing field.
    if (value == null || value.Length < sizeof(int)) return -1;
    return BitConverter.ToInt32(value, 0);
}

private byte[] ReadBytes(int offset, int count)
Expand Down
12 changes: 10 additions & 2 deletions Runtime/LLMManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@ public class ModelEntry
public string chatTemplate;
public string url;
public bool includeInBuild;

public int contextLength;

/// <summary>
/// Creates an entry for the model or LoRA file at the given path.
/// For non-LoRA files the GGUF metadata is read to fill in the chat template and the context length.
/// </summary>
/// <param name="path"> path of the model / LoRA file </param>
/// <param name="lora"> whether the file is a LoRA; GGUF metadata is not read for LoRAs </param>
/// <param name="label"> label of the entry; the filename is used when null </param>
/// <param name="url"> url stored with the entry </param>
public ModelEntry(string path, bool lora = false, string label = null, string url = null)
{
    filename = Path.GetFileName(path);
    this.label = label == null ? filename : label;
    this.lora = lora;
    this.path = Path.GetFullPath(path).Replace('\\', '/');
    this.url = url;
    includeInBuild = true;
    chatTemplate = null;
    contextLength = -1;
    if (!lora)
    {
        // Open the GGUF file once and reuse the reader for every metadata lookup.
        GGUFReader reader = new GGUFReader(this.path);
        chatTemplate = ChatTemplate.FromGGUF(reader, this.path);
        string arch = reader.GetStringField("general.architecture");
        // The context length key is prefixed with the architecture name;
        // contextLength stays -1 when the architecture or the key is not found.
        if (arch != null) contextLength = reader.GetIntField($"{arch}.context_length");
    }
}

public ModelEntry OnlyRequiredFields()
Expand Down
Loading