Skip to content

Commit 83d2add

Browse files
feat: enable user to choose from multiple candidate commit messages
1 parent ddd7a1c commit 83d2add

File tree

1 file changed

+168
-57
lines changed

1 file changed

+168
-57
lines changed

main.go

Lines changed: 168 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
// Command gitsloth generates a Conventional Commit message from staged Git changes
22
// using the OpenAI API, asks for user confirmation, and creates the commit.
33
//
4+
// Usage:
5+
//
6+
// gitsloth [-a | --all] Generate one commit message and confirm
7+
// gitsloth list [-n | --num <count>] Generate N commit messages and pick one
8+
//
49
// It requires:
510
// - Being inside a Git repository
611
// - OPENAI_API_KEY environment variable set
@@ -22,14 +27,27 @@ import (
2227
"time"
2328
)
2429

25-
// main is the entry point of the CLI tool. It validates the environment,
26-
// generates a commit message from the staged diff, asks for confirmation,
27-
// and creates the commit.
30+
// main is the entry point of the CLI tool. It dispatches to either the
31+
// default single-message flow or the list subcommand based on os.Args.
2832
func main() {
33+
// Detect subcommand before flag parsing so that each subcommand
34+
// can own its own FlagSet without polluting the global one.
35+
if len(os.Args) > 1 && os.Args[1] == "list" {
36+
runList(os.Args[2:])
37+
return
38+
}
39+
40+
runDefault(os.Args[1:])
41+
}
42+
43+
// runDefault is the original single-message flow.
44+
// Flags: -a / --all (stage all changes before committing)
45+
func runDefault(args []string) {
46+
fs := flag.NewFlagSet("gitsloth", flag.ExitOnError)
2947
var all bool
30-
flag.BoolVar(&all, "all", false, "stage all changes before commiting")
31-
flag.BoolVar(&all, "a", false, "stage all changes before commiting (shorthand)")
32-
flag.Parse()
48+
fs.BoolVar(&all, "all", false, "stage all changes before committing")
49+
fs.BoolVar(&all, "a", false, "stage all changes before committing (shorthand)")
50+
fs.Parse(args) //nolint:errcheck // ExitOnError handles this
3351

3452
if !isGitRepoHere() {
3553
fmt.Println("Not inside a Git repository (.git not found here)")
@@ -43,35 +61,84 @@ func main() {
4361
}
4462
}
4563

46-
// Build structured Git context instead of relying on raw diff only.
4764
ctx, err := buildGitContext()
4865
if err != nil {
4966
fmt.Println("Failed to build git context:", err)
5067
os.Exit(1)
5168
}
5269

53-
// Ensure there are actual staged changes before proceeding.
5470
if strings.TrimSpace(ctx.Diff) == "" {
5571
fmt.Println("No changes to commit")
5672
os.Exit(0)
5773
}
5874

59-
message, err := generateCommitMessage(*ctx)
75+
messages, err := generateCommitMessages(*ctx, 1)
6076
if err != nil {
61-
fmt.Println("Failed to generate the commit message", err)
77+
fmt.Println("Failed to generate the commit message:", err)
6278
os.Exit(1)
63-
} else if message == "" {
79+
}
80+
if len(messages) == 0 || messages[0] == "" {
6481
fmt.Println("Commit message is empty")
6582
os.Exit(1)
6683
}
6784

85+
message := messages[0]
86+
6887
if !askForConfirmation(message) {
6988
fmt.Println("Commit aborted")
7089
os.Exit(0)
7190
}
7291

73-
err = createCommit(message)
92+
if err := createCommit(message); err != nil {
93+
fmt.Println(err)
94+
os.Exit(1)
95+
}
96+
}
97+
98+
// runList is the list subcommand: it generates N candidate commit messages,
99+
// lets the user pick one interactively, then creates the commit.
100+
// Flags: -n / --num (number of messages to generate, default 5)
101+
func runList(args []string) {
102+
fs := flag.NewFlagSet("gitsloth list", flag.ExitOnError)
103+
var num int
104+
fs.IntVar(&num, "num", 5, "number of commit messages to generate")
105+
fs.IntVar(&num, "n", 5, "number of commit messages to generate (shorthand)")
106+
fs.Parse(args) //nolint:errcheck // ExitOnError handles this
107+
108+
if num < 1 {
109+
fmt.Println("--num must be at least 1")
110+
os.Exit(1)
111+
}
112+
113+
if !isGitRepoHere() {
114+
fmt.Println("Not inside a Git repository (.git not found here)")
115+
os.Exit(1)
116+
}
117+
118+
ctx, err := buildGitContext()
74119
if err != nil {
120+
fmt.Println("Failed to build git context:", err)
121+
os.Exit(1)
122+
}
123+
124+
if strings.TrimSpace(ctx.Diff) == "" {
125+
fmt.Println("No changes to commit")
126+
os.Exit(0)
127+
}
128+
129+
messages, err := generateCommitMessages(*ctx, num)
130+
if err != nil {
131+
fmt.Println("Failed to generate commit messages:", err)
132+
os.Exit(1)
133+
}
134+
if len(messages) == 0 {
135+
fmt.Println("No commit messages were generated")
136+
os.Exit(1)
137+
}
138+
139+
chosen := chooseAnOption(messages)
140+
141+
if err := createCommit(chosen); err != nil {
75142
fmt.Println(err)
76143
os.Exit(1)
77144
}
@@ -214,27 +281,37 @@ feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert
214281
- Use imperative mood (e.g., "add", "fix", not "added", "fixes")
215282
`
216283

217-
// generateCommitMessage uses the OpenAI HTTP API to generate a
218-
// Conventional Commit message based on the provided Git context.
284+
// generateCommitMessages uses the OpenAI HTTP API to generate one or more
285+
// Conventional Commit messages based on the provided Git context.
286+
//
287+
// When count is 1 the model returns a plain string response.
288+
// When count > 1 it uses JSON mode, asking the model to return a JSON object
289+
// {"messages": ["...", "..."]} which is decoded directly — no text parsing needed.
219290
//
220-
// It starts a spinner while the request is in progress and ensures
221-
// the spinner is stopped before returning.
291+
// It starts a spinner while the request is in progress.
222292
//
223293
// Requirements:
224294
// - OPENAI_API_KEY environment variable must be set
225-
//
226-
// The returned message is cleaned of formatting artifacts (e.g., code fences).
227-
func generateCommitMessage(ctx GitContext) (string, error) {
295+
func generateCommitMessages(ctx GitContext, count int) ([]string, error) {
228296
apiKey := os.Getenv("OPENAI_API_KEY")
229297
if apiKey == "" {
230-
return "", fmt.Errorf("OPENAI_API_KEY not set")
298+
return nil, fmt.Errorf("OPENAI_API_KEY not set")
231299
}
232300

233-
stop := startSpinner(" Generating commit message...")
301+
spinnerMsg := " Generating commit message..."
302+
if count > 1 {
303+
spinnerMsg = fmt.Sprintf(" Generating %d commit messages...", count)
304+
}
305+
stop := startSpinner(spinnerMsg)
234306

235-
// Build a structured prompt using multiple signals instead of raw diff only.
236-
prompt := fmt.Sprintf(`
237-
You are an expert software engineer that writes precise commit messages.
307+
var (
308+
systemPrompt string
309+
userPrompt string
310+
)
311+
312+
if count == 1 {
313+
systemPrompt = "You write excellent commit messages."
314+
userPrompt = fmt.Sprintf(`You are an expert software engineer that writes precise commit messages.
238315
239316
Follow the Conventional Commits specification.
240317
@@ -251,54 +328,71 @@ Diff:
251328
252329
Task:
253330
Generate ONE properly formatted commit message.
254-
Return ONLY the commit message.
255-
`,
256-
ConventionalCommitRules,
257-
ctx.Branch,
258-
ctx.Status,
259-
ctx.Diff,
260-
)
331+
Return ONLY the commit message, with no preamble or explanation.
332+
`, ConventionalCommitRules, ctx.Branch, ctx.Status, ctx.Diff)
333+
} else {
334+
// JSON mode: the system prompt must mention JSON so the model honours it.
335+
systemPrompt = `You write excellent commit messages. You always respond with valid JSON.`
336+
userPrompt = fmt.Sprintf(`You are an expert software engineer that writes precise commit messages.
337+
338+
Follow the Conventional Commits specification.
339+
340+
%s
341+
342+
Branch:
343+
%s
344+
345+
Git status:
346+
%s
347+
348+
Diff:
349+
%s
350+
351+
Task:
352+
Generate exactly %d distinct, properly formatted commit messages that explore different angles or phrasings of the same change.
353+
Return a JSON object with a single key "messages" whose value is an array of exactly %d strings.
354+
Example format: {"messages": ["feat: add foo", "feat(bar): introduce foo support", ...]}
355+
`, ConventionalCommitRules, ctx.Branch, ctx.Status, ctx.Diff, count, count)
356+
}
261357

262358
body := map[string]any{
263359
"model": "gpt-4o-mini",
264360
"messages": []map[string]string{
265-
{"role": "system", "content": "You write excellent commit messages."},
266-
{"role": "user", "content": prompt},
361+
{"role": "system", "content": systemPrompt},
362+
{"role": "user", "content": userPrompt},
267363
},
268-
"temperature": 0.2,
364+
"temperature": 0.6,
365+
}
366+
if count > 1 {
367+
body["response_format"] = map[string]string{"type": "json_object"}
269368
}
270369

271370
jsonBody, err := json.Marshal(body)
272371
if err != nil {
273372
stop()
274-
return "", err
373+
return nil, err
275374
}
276375

277-
req, err := http.NewRequest(
278-
"POST",
279-
"https://api.openai.com/v1/chat/completions",
280-
bytes.NewBuffer(jsonBody),
281-
)
376+
req, err := http.NewRequest("POST", "https://api.openai.com/v1/chat/completions", bytes.NewBuffer(jsonBody))
282377
if err != nil {
283378
stop()
284-
return "", err
379+
return nil, err
285380
}
286381

287382
req.Header.Set("Authorization", "Bearer "+apiKey)
288383
req.Header.Set("Content-Type", "application/json")
289384

290-
client := &http.Client{}
291-
resp, err := client.Do(req)
385+
resp, err := (&http.Client{}).Do(req)
292386
if err != nil {
293387
stop()
294-
return "", err
388+
return nil, err
295389
}
296390
defer resp.Body.Close()
297391

298392
if resp.StatusCode != http.StatusOK {
299-
respBody, _ := io.ReadAll(resp.Body)
393+
body, _ := io.ReadAll(resp.Body)
300394
stop()
301-
return "", fmt.Errorf("API error: %s", string(respBody))
395+
return nil, fmt.Errorf("API error: %s", string(body))
302396
}
303397

304398
stop()
@@ -313,37 +407,54 @@ Return ONLY the commit message.
313407

314408
var result chatCompletionResponse
315409
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
316-
return "", err
410+
return nil, err
317411
}
318-
319412
if len(result.Choices) == 0 {
320-
return "", fmt.Errorf("no response choices returned")
413+
return nil, fmt.Errorf("no response choices returned")
321414
}
322415

323-
message := result.Choices[0].Message.Content
324-
message = strings.ReplaceAll(message, "```", "")
325-
message = strings.TrimSpace(message)
416+
content := strings.TrimSpace(result.Choices[0].Message.Content)
417+
418+
// Single-message path: return the raw text as-is.
419+
if count == 1 {
420+
content = strings.ReplaceAll(content, "```", "")
421+
return []string{strings.TrimSpace(content)}, nil
422+
}
423+
424+
// Multi-message path: decode the guaranteed JSON object.
425+
var payload struct {
426+
Messages []string `json:"messages"`
427+
}
428+
if err := json.Unmarshal([]byte(content), &payload); err != nil {
429+
return nil, fmt.Errorf("failed to decode JSON response: %w\nraw content: %s", err, content)
430+
}
431+
if len(payload.Messages) == 0 {
432+
return nil, fmt.Errorf("model returned an empty messages array")
433+
}
326434

327-
return message, nil
435+
return payload.Messages, nil
328436
}
329437

438+
// chooseAnOption presents a numbered list of options and prompts the user to
439+
// pick one by number. It loops until a valid choice is entered and returns
440+
// the selected string.
330441
func chooseAnOption(options []string) string {
331442
reader := bufio.NewReader(os.Stdin)
332443
for {
333444
fmt.Println("Proposed commit messages:")
334445
for i, opt := range options {
335-
fmt.Printf("%d) %s\n", i+1, opt)
446+
fmt.Printf(" %d) %s\n", i+1, opt)
336447
}
337448
fmt.Print("> ")
338449
input, _ := reader.ReadString('\n')
339450
input = strings.TrimSpace(input)
340451
choice, err := strconv.Atoi(input)
341452
if err != nil || choice < 1 || choice > len(options) {
342-
fmt.Println("Invalid input, retry.")
453+
fmt.Printf("Invalid input — enter a number between 1 and %d.\n", len(options))
343454
continue
344455
}
345456
selected := options[choice-1]
346-
fmt.Println("Selected message:", selected)
457+
fmt.Println("Selected:", selected)
347458
return selected
348459
}
349460
}
@@ -375,7 +486,7 @@ func createCommit(message string) error {
375486
return fmt.Errorf("commit failed: %s", string(output))
376487
}
377488

378-
fmt.Println("Commit created succesfully")
489+
fmt.Println("Commit created successfully")
379490
fmt.Println(string(output))
380491
return nil
381492
}

0 commit comments

Comments
 (0)