Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ func (m *home) handleKeyPress(msg tea.KeyMsg) (mod tea.Model, cmd tea.Cmd) {
Title: "",
Path: ".",
Program: m.program,
BaseRef: "HEAD",
})
if err != nil {
return m, m.handleError(err)
Expand All @@ -496,6 +497,7 @@ func (m *home) handleKeyPress(msg tea.KeyMsg) (mod tea.Model, cmd tea.Cmd) {
Title: "",
Path: ".",
Program: m.program,
BaseRef: "HEAD",
})
if err != nil {
return m, m.handleError(err)
Expand All @@ -506,6 +508,47 @@ func (m *home) handleKeyPress(msg tea.KeyMsg) (mod tea.Model, cmd tea.Cmd) {
m.state = stateNew
m.menu.SetState(ui.StateNewInstance)

return m, nil
case keys.KeyNewFromMain:
if m.list.NumInstances() >= GlobalInstanceLimit {
return m, m.handleError(
fmt.Errorf("you can't create more than %d instances", GlobalInstanceLimit))
}
instance, err := session.NewInstance(session.InstanceOptions{
Title: "",
Path: ".",
Program: m.program,
BaseRef: "main",
})
if err != nil {
return m, m.handleError(err)
}

m.newInstanceFinalizer = m.list.AddInstance(instance)
m.list.SetSelectedInstance(m.list.NumInstances() - 1)
m.state = stateNew
m.menu.SetState(ui.StateNewInstance)
return m, nil
case keys.KeyPromptFromMain:
if m.list.NumInstances() >= GlobalInstanceLimit {
return m, m.handleError(
fmt.Errorf("you can't create more than %d instances", GlobalInstanceLimit))
}
instance, err := session.NewInstance(session.InstanceOptions{
Title: "",
Path: ".",
Program: m.program,
BaseRef: "main",
})
if err != nil {
return m, m.handleError(err)
}

m.newInstanceFinalizer = m.list.AddInstance(instance)
m.list.SetSelectedInstance(m.list.NumInstances() - 1)
m.state = stateNew
m.menu.SetState(ui.StateNewInstance)
m.promptAfterName = true
return m, nil
case keys.KeyUp:
m.list.Up()
Expand Down
6 changes: 4 additions & 2 deletions app/help.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ func (h helpTypeGeneral) toContent() string {
"A terminal UI that manages multiple Claude Code (and other local agents) in separate workspaces.",
"",
headerStyle.Render("Managing:"),
keyStyle.Render("n")+descStyle.Render(" - Create a new session"),
keyStyle.Render("N")+descStyle.Render(" - Create a new session with a prompt"),
keyStyle.Render("n")+descStyle.Render(" - Create a new session (from current HEAD)"),
keyStyle.Render("m")+descStyle.Render(" - Create a new session from main/master branch"),
keyStyle.Render("N")+descStyle.Render(" - Create a new session with a prompt (from HEAD)"),
keyStyle.Render("M")+descStyle.Render(" - Create a new session with a prompt (from main)"),
keyStyle.Render("D")+descStyle.Render(" - Kill (delete) the selected session"),
keyStyle.Render("↑/j, ↓/k")+descStyle.Render(" - Navigate between sessions"),
keyStyle.Render("↵/o")+descStyle.Render(" - Attach to the selected session"),
Expand Down
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type Config struct {
DaemonPollInterval int `json:"daemon_poll_interval"`
// BranchPrefix is the prefix used for git branches created by the application.
BranchPrefix string `json:"branch_prefix"`
// DefaultBranch is the name of the default branch (e.g., "main", "master"). If empty, auto-detects.
DefaultBranch string `json:"default_branch,omitempty"`
}

// DefaultConfig returns the default configuration
Expand All @@ -58,6 +60,7 @@ func DefaultConfig() *Config {
}
return fmt.Sprintf("%s/", strings.ToLower(user.Username))
}(),
DefaultBranch: "", // Empty means auto-detect
}
}

Expand Down
14 changes: 14 additions & 0 deletions keys/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ const (
KeyPrompt // New key for entering a prompt
KeyHelp // Key for showing help screen

// Main branch creation keys
KeyNewFromMain // New key for creating instance from main branch
KeyPromptFromMain // New key for creating instance with prompt from main branch

// Diff keybindings
KeyShiftUp
KeyShiftDown
Expand All @@ -39,9 +43,11 @@ var GlobalKeyStringsMap = map[string]KeyName{
"shift+up": KeyShiftUp,
"shift+down": KeyShiftDown,
"N": KeyPrompt,
"M": KeyPromptFromMain,
"enter": KeyEnter,
"o": KeyEnter,
"n": KeyNew,
"m": KeyNewFromMain,
"D": KeyKill,
"q": KeyQuit,
"tab": KeyTab,
Expand Down Expand Up @@ -97,6 +103,14 @@ var GlobalkeyBindings = map[KeyName]key.Binding{
key.WithKeys("N"),
key.WithHelp("N", "new with prompt"),
),
KeyNewFromMain: key.NewBinding(
key.WithKeys("m"),
key.WithHelp("m", "new from main"),
),
KeyPromptFromMain: key.NewBinding(
key.WithKeys("M"),
key.WithHelp("M", "new from main with prompt"),
),
KeyCheckout: key.NewBinding(
key.WithKeys("c"),
key.WithHelp("c", "checkout"),
Expand Down
7 changes: 7 additions & 0 deletions session/git/worktree.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ type GitWorktree struct {
branchName string
// Base commit hash for the worktree
baseCommitSHA string
// BaseRef indicates whether to use main branch or current HEAD
baseRef string
}

func NewGitWorktreeFromStorage(repoPath string, worktreePath string, sessionName string, branchName string, baseCommitSHA string) *GitWorktree {
Expand Down Expand Up @@ -100,3 +102,8 @@ func (g *GitWorktree) GetRepoName() string {
func (g *GitWorktree) GetBaseCommitSHA() string {
return g.baseCommitSHA
}

// SetBaseRef sets the base reference for the worktree (e.g., "main" or "HEAD")
func (g *GitWorktree) SetBaseRef(baseRef string) {
g.baseRef = baseRef
}
45 changes: 45 additions & 0 deletions session/git/worktree_git.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package git

import (
"claude-squad/config"
"claude-squad/log"
"fmt"
"os/exec"
Expand Down Expand Up @@ -135,3 +136,47 @@ func (g *GitWorktree) OpenBranchURL() error {
}
return nil
}

// GetDefaultBranch detects the default branch of the repository
// It tries in order: config override, origin/HEAD, origin/main, origin/master
func GetDefaultBranch(repoPath string) (string, error) {
// Check config for override
cfg := config.LoadConfig()
if cfg.DefaultBranch != "" {
return cfg.DefaultBranch, nil
}

// Create a temporary GitWorktree just for running commands
g := &GitWorktree{repoPath: repoPath}

// Try to get the default branch from origin/HEAD
output, err := g.runGitCommand(repoPath, "symbolic-ref", "refs/remotes/origin/HEAD")
if err == nil {
// Output will be something like "refs/remotes/origin/main"
branch := strings.TrimSpace(string(output))
branch = strings.TrimPrefix(branch, "refs/remotes/origin/")
if branch != "" {
return branch, nil
}
}

// Try common default branch names
branches := []string{"main", "master"}
for _, branch := range branches {
// Check if the remote branch exists
_, err := g.runGitCommand(repoPath, "rev-parse", "--verify", fmt.Sprintf("origin/%s", branch))
if err == nil {
return branch, nil
}
}

// If no remote branches found, try local branches
for _, branch := range branches {
_, err := g.runGitCommand(repoPath, "rev-parse", "--verify", branch)
if err == nil {
return branch, nil
}
}

return "", fmt.Errorf("could not determine default branch")
}
93 changes: 93 additions & 0 deletions session/git/worktree_git_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package git

import (
"os"
"os/exec"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetDefaultBranch(t *testing.T) {
// Create a temporary directory for test repos
tempDir := t.TempDir()

t.Run("detects main branch", func(t *testing.T) {
repoPath := filepath.Join(tempDir, "repo-with-main")
setupTestRepo(t, repoPath, "main")

branch, err := GetDefaultBranch(repoPath)
assert.NoError(t, err)
assert.Equal(t, "main", branch)
})

t.Run("detects master branch", func(t *testing.T) {
repoPath := filepath.Join(tempDir, "repo-with-master")
setupTestRepo(t, repoPath, "master")

branch, err := GetDefaultBranch(repoPath)
assert.NoError(t, err)
assert.Equal(t, "master", branch)
})

t.Run("prefers main over master", func(t *testing.T) {
repoPath := filepath.Join(tempDir, "repo-with-both")
setupTestRepo(t, repoPath, "main")
// Also create a master branch
cmd := exec.Command("git", "checkout", "-b", "master")
cmd.Dir = repoPath
require.NoError(t, cmd.Run())
cmd = exec.Command("git", "checkout", "main")
cmd.Dir = repoPath
require.NoError(t, cmd.Run())

branch, err := GetDefaultBranch(repoPath)
assert.NoError(t, err)
assert.Equal(t, "main", branch)
})

t.Run("handles missing default branch", func(t *testing.T) {
repoPath := filepath.Join(tempDir, "repo-with-custom")
setupTestRepo(t, repoPath, "develop")

_, err := GetDefaultBranch(repoPath)
assert.Error(t, err)
assert.Contains(t, err.Error(), "could not determine default branch")
})
}

// setupTestRepo creates a test git repository with the specified default branch
func setupTestRepo(t *testing.T, repoPath string, defaultBranch string) {
t.Helper()

// Create directory
require.NoError(t, os.MkdirAll(repoPath, 0755))

// Initialize repo with explicit initial branch
cmd := exec.Command("git", "init", "-b", defaultBranch)
cmd.Dir = repoPath
require.NoError(t, cmd.Run())

// Set user for commits
cmd = exec.Command("git", "config", "user.email", "[email protected]")
cmd.Dir = repoPath
require.NoError(t, cmd.Run())

cmd = exec.Command("git", "config", "user.name", "Test User")
cmd.Dir = repoPath
require.NoError(t, cmd.Run())

// Create initial commit
testFile := filepath.Join(repoPath, "README.md")
require.NoError(t, os.WriteFile(testFile, []byte("# Test Repo"), 0644))

cmd = exec.Command("git", "add", ".")
cmd.Dir = repoPath
require.NoError(t, cmd.Run())

cmd = exec.Command("git", "commit", "-m", "Initial commit")
cmd.Dir = repoPath
require.NoError(t, cmd.Run())
}
51 changes: 37 additions & 14 deletions session/git/worktree_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (g *GitWorktree) setupFromExistingBranch() error {
return nil
}

// setupNewWorktree creates a new worktree from HEAD
// setupNewWorktree creates a new worktree from HEAD or main branch
func (g *GitWorktree) setupNewWorktree() error {
// Ensure worktrees directory exists
worktreesDir := filepath.Join(g.repoPath, "worktrees")
Expand All @@ -94,24 +94,47 @@ func (g *GitWorktree) setupNewWorktree() error {
return fmt.Errorf("failed to cleanup existing branch: %w", err)
}

output, err := g.runGitCommand(g.repoPath, "rev-parse", "HEAD")
if err != nil {
if strings.Contains(err.Error(), "fatal: ambiguous argument 'HEAD'") ||
strings.Contains(err.Error(), "fatal: not a valid object name") ||
strings.Contains(err.Error(), "fatal: HEAD: not a valid object name") {
return fmt.Errorf("this appears to be a brand new repository: please create an initial commit before creating an instance")
var targetCommit string

// Determine which commit to use based on baseRef
if g.baseRef == "main" {
// Get the default branch
defaultBranch, err := GetDefaultBranch(g.repoPath)
if err != nil {
return fmt.Errorf("failed to determine default branch: %w", err)
}
return fmt.Errorf("failed to get HEAD commit hash: %w", err)

// Get the commit hash for the default branch
output, err := g.runGitCommand(g.repoPath, "rev-parse", defaultBranch)
if err != nil {
// Try with origin prefix
output, err = g.runGitCommand(g.repoPath, "rev-parse", fmt.Sprintf("origin/%s", defaultBranch))
if err != nil {
return fmt.Errorf("failed to get commit hash for branch %s: %w", defaultBranch, err)
}
}
targetCommit = strings.TrimSpace(string(output))
} else {
// Use current HEAD (existing behavior)
output, err := g.runGitCommand(g.repoPath, "rev-parse", "HEAD")
if err != nil {
if strings.Contains(err.Error(), "fatal: ambiguous argument 'HEAD'") ||
strings.Contains(err.Error(), "fatal: not a valid object name") ||
strings.Contains(err.Error(), "fatal: HEAD: not a valid object name") {
return fmt.Errorf("this appears to be a brand new repository: please create an initial commit before creating an instance")
}
return fmt.Errorf("failed to get HEAD commit hash: %w", err)
}
targetCommit = strings.TrimSpace(string(output))
}
headCommit := strings.TrimSpace(string(output))
g.baseCommitSHA = headCommit

// Create a new worktree from the HEAD commit
g.baseCommitSHA = targetCommit

// Create a new worktree from the target commit
// Otherwise, we'll inherit uncommitted changes from the previous worktree.
// This way, we can start the worktree with a clean slate.
// TODO: we might want to give an option to use main/master instead of the current branch.
if _, err := g.runGitCommand(g.repoPath, "worktree", "add", "-b", g.branchName, g.worktreePath, headCommit); err != nil {
return fmt.Errorf("failed to create worktree from commit %s: %w", headCommit, err)
if _, err := g.runGitCommand(g.repoPath, "worktree", "add", "-b", g.branchName, g.worktreePath, targetCommit); err != nil {
return fmt.Errorf("failed to create worktree from commit %s: %w", targetCommit, err)
}

return nil
Expand Down
Loading
Loading