-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.go
More file actions
259 lines (214 loc) · 8.72 KB
/
main.go
File metadata and controls
259 lines (214 loc) · 8.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
// Package main provides the entry point for latent, a terminal UI application
// for visualizing text embeddings. It connects to Ollama for generating embeddings
// and Qdrant for vector storage, then projects high-dimensional vectors to 2D
// using PCA for interactive visualization.
package main
import (
"context"
"flag"
"fmt"
"os"
"github.com/alDuncanson/latent/dataimport"
"github.com/alDuncanson/latent/huggingface"
"github.com/alDuncanson/latent/ollama"
"github.com/alDuncanson/latent/preload"
"github.com/alDuncanson/latent/qdrant"
"github.com/alDuncanson/latent/tui"
tea "github.com/charmbracelet/bubbletea"
"github.com/google/uuid"
)
// version identifies this build. Release builds overwrite it through -ldflags;
// a plain local build reports "dev".
var version = "dev"

// Ollama settings: where to reach the embedding service and which model turns
// text into vectors.
const (
	// ollamaServiceURL is the HTTP base URL of the Ollama server.
	ollamaServiceURL = "http://localhost:11434"
	// embeddingModelName is the Ollama model used for every embedding request.
	embeddingModelName = "nomic-embed-text"
	// embeddingVectorDimensions is the vector length nomic-embed-text emits;
	// it must stay in sync with embeddingModelName.
	embeddingVectorDimensions = 768
)

// Qdrant settings: where vectors are persisted.
const (
	// qdrantServiceAddress is the gRPC host:port of the Qdrant server.
	qdrantServiceAddress = "localhost:6334"
	// vectorCollectionName is the Qdrant collection that holds the embeddings.
	vectorCollectionName = "embeddings"
)
// main is the entry point for latent. It delegates all work to run and turns
// the returned status into the process exit code.
//
// The work lives in run rather than here because os.Exit terminates the
// process immediately without executing deferred calls. Keeping os.Exit out
// of the function that defers qdrantVectorClient.Close guarantees the Qdrant
// connection is closed on every exit path, including failures.
func main() {
	os.Exit(run())
}

// run parses command-line flags, connects to Ollama and Qdrant, performs any
// requested data loading, and then starts the terminal UI.
//
// With -version it prints the version and returns 0 without contacting any
// service. Otherwise, data loading happens in this order before the UI starts:
// the -preload demo words, then the optional positional dataset file, then the
// -hf-dataset import. Failures are reported on stderr. The return value is the
// process exit status: 0 on success, 1 on any failure.
func run() int {
	// flag.PrintDefaults already appends `(default "...")` for non-zero
	// defaults, so the usage strings must not repeat the default values.
	showVersionFlag := flag.Bool("version", false, "print version and exit")
	preloadDemoDataFlag := flag.Bool("preload", false, "seed with demo word list")
	hfDatasetFlag := flag.String("hf-dataset", "", "Hugging Face dataset to import (e.g., cornell-movie-review-data/rotten_tomatoes)")
	hfSplitFlag := flag.String("hf-split", "train", "dataset split to use")
	hfColumnFlag := flag.String("hf-column", "text", "column containing text to embed")
	hfMaxRowsFlag := flag.Int("hf-max-rows", 100, "maximum rows to fetch from Hugging Face")
	flag.Parse()

	// Handle the version flag before touching any backend service.
	if *showVersionFlag {
		fmt.Println(version)
		return 0
	}

	// The first positional argument, if any, is a dataset file to import.
	var datasetPath string
	if flag.NArg() > 0 {
		datasetPath = flag.Arg(0)
	}

	// Client for generating text embeddings.
	ollamaEmbeddingClient := ollama.NewClient(ollamaServiceURL, embeddingModelName)

	// Client for vector storage and retrieval.
	qdrantVectorClient, connectionError := qdrant.NewClient(
		qdrantServiceAddress,
		vectorCollectionName,
		embeddingVectorDimensions,
	)
	if connectionError != nil {
		fmt.Fprintf(os.Stderr, "Failed to connect to Qdrant: %v\n", connectionError)
		fmt.Fprintln(os.Stderr, "Make sure Qdrant is running: docker run -p 6333:6333 -p 6334:6334 qdrant/qdrant")
		return 1
	}
	// Runs on every return below, because no path calls os.Exit directly.
	defer qdrantVectorClient.Close()

	if *preloadDemoDataFlag {
		if preloadError := runPreloadDemoWords(ollamaEmbeddingClient, qdrantVectorClient); preloadError != nil {
			fmt.Fprintf(os.Stderr, "Preload failed: %v\n", preloadError)
			return 1
		}
	}

	if datasetPath != "" {
		if importError := runImportDataset(ollamaEmbeddingClient, qdrantVectorClient, datasetPath); importError != nil {
			fmt.Fprintf(os.Stderr, "Import failed: %v\n", importError)
			return 1
		}
	}

	if *hfDatasetFlag != "" {
		importError := runImportHuggingFace(
			ollamaEmbeddingClient,
			qdrantVectorClient,
			*hfDatasetFlag,
			*hfSplitFlag,
			*hfColumnFlag,
			*hfMaxRowsFlag,
		)
		if importError != nil {
			fmt.Fprintf(os.Stderr, "Hugging Face import failed: %v\n", importError)
			return 1
		}
	}

	// Hand control to the Bubble Tea program on the alternate screen buffer.
	terminalUserInterfaceModel := tui.NewModel(ollamaEmbeddingClient, qdrantVectorClient, version)
	bubbleTeaProgram := tea.NewProgram(terminalUserInterfaceModel, tea.WithAltScreen())
	if _, programRunError := bubbleTeaProgram.Run(); programRunError != nil {
		fmt.Fprintf(os.Stderr, "Error running program: %v\n", programRunError)
		return 1
	}
	return 0
}
// runPreloadDemoWords seeds the Qdrant collection with the demo word list
// returned by preload.Words. Each word is embedded with Ollama and upserted
// into Qdrant under a freshly generated random UUID.
//
// Progress is written to stdout as a single line that is redrawn in place
// for every word. It returns the first embedding or upsert error, wrapped with
// the word that failed; words stored before the failure are not rolled back.
func runPreloadDemoWords(ollamaEmbeddingClient *ollama.Client, qdrantVectorClient *qdrant.Client) error {
	// "\r" moves the cursor to column 0 and "\x1b[K" (ANSI "erase in line")
	// clears the rest of the line, so a shorter entry never leaves trailing
	// characters from a longer previous one on screen.
	const redrawLine = "\r\x1b[K"

	demoWordList := preload.Words()
	backgroundContext := context.Background()

	fmt.Printf("Preloading %d words...\n", len(demoWordList))
	for wordIndex, currentWord := range demoWordList {
		embeddingVector, embeddingError := ollamaEmbeddingClient.Embed(currentWord)
		if embeddingError != nil {
			return fmt.Errorf("embed %q: %w", currentWord, embeddingError)
		}

		uniquePointIdentifier := uuid.New().String()
		upsertError := qdrantVectorClient.Upsert(backgroundContext, uniquePointIdentifier, currentWord, embeddingVector)
		if upsertError != nil {
			return fmt.Errorf("upsert %q: %w", currentWord, upsertError)
		}

		fmt.Printf(redrawLine+"[%d/%d] %s", wordIndex+1, len(demoWordList), currentWord)
	}
	fmt.Println("\nDone.")
	return nil
}
// runImportDataset loads texts from a local CSV or JSON dataset file via
// dataimport.LoadTexts, embeds each text with Ollama, and upserts it into
// Qdrant under a freshly generated random UUID.
//
// Progress is written to stdout as a single line that is redrawn in place.
// It returns an error if the file cannot be loaded, yields no texts, or if
// embedding or storing any text fails. Texts stored before a failure are not
// rolled back.
func runImportDataset(ollamaEmbeddingClient *ollama.Client, qdrantVectorClient *qdrant.Client, datasetPath string) error {
	// "\r" returns to column 0 and "\x1b[K" (ANSI "erase in line") clears the
	// remainder, so shorter entries don't leave stale characters behind.
	const redrawLine = "\r\x1b[K"

	texts, loadError := dataimport.LoadTexts(datasetPath)
	if loadError != nil {
		return fmt.Errorf("loading dataset: %w", loadError)
	}
	if len(texts) == 0 {
		return fmt.Errorf("no texts found in dataset %q", datasetPath)
	}

	backgroundContext := context.Background()
	fmt.Printf("Importing %d texts from %s...\n", len(texts), datasetPath)
	for textIndex, currentText := range texts {
		embeddingVector, embeddingError := ollamaEmbeddingClient.Embed(currentText)
		if embeddingError != nil {
			// Rows can be arbitrarily long; a short excerpt keeps the error readable.
			return fmt.Errorf("embed %q: %w", truncateForProgress(currentText, 20), embeddingError)
		}

		uniquePointIdentifier := uuid.New().String()
		upsertError := qdrantVectorClient.Upsert(backgroundContext, uniquePointIdentifier, currentText, embeddingVector)
		if upsertError != nil {
			return fmt.Errorf("upsert %q: %w", truncateForProgress(currentText, 20), upsertError)
		}

		fmt.Printf(redrawLine+"[%d/%d] %s", textIndex+1, len(texts), truncateForProgress(currentText, 40))
	}
	fmt.Println("\nDone.")
	return nil
}
// truncateForProgress returns a single-line excerpt of text at most maxLen
// runes long, suitable for a progress line that is redrawn in place.
//
// Carriage returns and newlines are replaced with spaces, because a raw line
// break would break the in-place redraw. Lengths are counted in runes, not
// bytes, so multi-byte UTF-8 characters are never split. When text is too
// long, it is cut and suffixed with "..." (which counts toward maxLen). If
// maxLen is too small to fit the ellipsis, the text is cut to maxLen runes
// with no suffix. A maxLen of zero or less yields "".
func truncateForProgress(text string, maxLen int) string {
	const ellipsis = "..."

	runes := []rune(text)
	for i, r := range runes {
		if r == '\n' || r == '\r' {
			runes[i] = ' '
		}
	}

	switch {
	case len(runes) <= maxLen:
		return string(runes)
	case maxLen <= 0:
		return ""
	case maxLen < len(ellipsis):
		return string(runes[:maxLen])
	default:
		return string(runes[:maxLen-len(ellipsis)]) + ellipsis
	}
}
// runImportHuggingFace fetches up to maxRows texts from the given column of a
// Hugging Face dataset split, embeds each with Ollama, and upserts it into
// Qdrant under a freshly generated random UUID.
//
// FetchTexts needs a config name as well as a split, so the dataset's splits
// are listed first to find the config that owns the requested split. If the
// requested split does not exist, the first listed split (and its config) is
// used instead and a notice is printed.
//
// Progress is written to stdout as a single line that is redrawn in place.
// It returns an error if the splits or texts cannot be fetched, if nothing is
// found, or if embedding or storing any text fails. Texts stored before a
// failure are not rolled back.
func runImportHuggingFace(ollamaEmbeddingClient *ollama.Client, qdrantVectorClient *qdrant.Client, dataset, split, column string, maxRows int) error {
	// "\r" returns to column 0 and "\x1b[K" (ANSI "erase in line") clears the
	// remainder, so shorter entries don't leave stale characters behind.
	const redrawLine = "\r\x1b[K"

	hfClient := huggingface.NewClient()

	splits, splitsError := hfClient.GetSplits(dataset)
	if splitsError != nil {
		return fmt.Errorf("fetching splits: %w", splitsError)
	}
	if len(splits.Splits) == 0 {
		return fmt.Errorf("no splits found for dataset %s", dataset)
	}

	// Track "found" explicitly: an empty config string is not a reliable
	// signal that the split was missing.
	var config string
	splitFound := false
	for _, s := range splits.Splits {
		if s.Split == split {
			config = s.Config
			splitFound = true
			break
		}
	}
	if !splitFound {
		requestedSplit := split
		config = splits.Splits[0].Config
		split = splits.Splits[0].Split
		fmt.Printf("Split %q not found, using %s/%s\n", requestedSplit, config, split)
	}

	fmt.Printf("Fetching from Hugging Face: %s (config=%s, split=%s, column=%s)\n", dataset, config, split, column)
	texts, fetchError := hfClient.FetchTexts(dataset, config, split, column, maxRows)
	if fetchError != nil {
		return fmt.Errorf("fetching texts: %w", fetchError)
	}
	if len(texts) == 0 {
		return fmt.Errorf("no texts found in column %q", column)
	}

	backgroundContext := context.Background()
	fmt.Printf("Importing %d texts...\n", len(texts))
	for textIndex, currentText := range texts {
		embeddingVector, embeddingError := ollamaEmbeddingClient.Embed(currentText)
		if embeddingError != nil {
			return fmt.Errorf("embed %q: %w", truncateForProgress(currentText, 20), embeddingError)
		}

		uniquePointIdentifier := uuid.New().String()
		upsertError := qdrantVectorClient.Upsert(backgroundContext, uniquePointIdentifier, currentText, embeddingVector)
		if upsertError != nil {
			return fmt.Errorf("upsert %q: %w", truncateForProgress(currentText, 20), upsertError)
		}

		fmt.Printf(redrawLine+"[%d/%d] %s", textIndex+1, len(texts), truncateForProgress(currentText, 40))
	}
	fmt.Println("\nDone.")
	return nil
}