Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions test/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"runtime"
"testing"
"time"

"github.com/elementsproject/peerswap/clightning"
"github.com/elementsproject/peerswap/peerswaprpc"
Expand All @@ -25,13 +26,34 @@ const (
)

// makeTestDataDir creates a temporary directory for test data with proper cleanup.
// It uses os.MkdirTemp() instead of t.TempDir() to avoid problems with long unix
// It uses os.MkdirTemp() instead of t.TempDir() to avoid problems with long unix
// socket paths. See https://github.com/golang/go/issues/62614.
func makeTestDataDir(t *testing.T) string {
tempDir, err := os.MkdirTemp("", "cln-test-")
require.NoError(t, err, "os.MkdirTemp failed")
t.Cleanup(func() { os.RemoveAll(tempDir) })
return tempDir
// 1. Check for custom test directory from environment
if baseDir := os.Getenv("PEERSWAP_TEST_DIR"); baseDir != "" {
testDir := filepath.Join(baseDir, fmt.Sprintf("t%d", time.Now().UnixNano()))
err := os.MkdirAll(testDir, 0755)
require.NoError(t, err, "failed to create test dir in PEERSWAP_TEST_DIR")
t.Cleanup(func() { os.RemoveAll(testDir) })
return testDir
}

// 2. Try to use /tmp/ps/ for shorter paths
shortBase := "/tmp/ps"
if err := os.MkdirAll(shortBase, 0755); err == nil {
// Use process ID and timestamp for uniqueness
testDir := filepath.Join(shortBase, fmt.Sprintf("%d-%d", os.Getpid(), time.Now().UnixNano()%1000000))
if err := os.MkdirAll(testDir, 0755); err == nil {
t.Cleanup(func() { os.RemoveAll(testDir) })
return testDir
}
}

// 3. Fallback to standard temp directory with short prefix
tempDir, err := os.MkdirTemp("", "ps-")
require.NoError(t, err, "os.MkdirTemp failed")
t.Cleanup(func() { os.RemoveAll(tempDir) })
return tempDir
}

type fundingNode string
Expand Down
63 changes: 40 additions & 23 deletions testframework/clightning.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ func NewCLightningNode(testDir string, bitcoin *BitcoinNode, id int) (*CLightnin
return nil, fmt.Errorf("GetFreePort() %w", err)
}

rngDirExtension, err := GenerateRandomString(5)
if err != nil {
return nil, fmt.Errorf("GenerateRandomString(5) %w", err)
}

dataDir := filepath.Join(testDir, fmt.Sprintf("clightning-%s", rngDirExtension))
// Use node ID for directory name instead of random string (shorter and more predictable)
dataDir := filepath.Join(testDir, fmt.Sprintf("c%d", id))
networkDir := filepath.Join(dataDir, "regtest")

err = os.MkdirAll(networkDir, os.ModeDir|os.ModePerm)
Expand Down Expand Up @@ -87,28 +83,33 @@ func NewCLightningNode(testDir string, bitcoin *BitcoinNode, id int) (*CLightnin
fmt.Sprintf("--allow-deprecated-apis=true"),
}

// socketPath := filepath.Join(networkDir, "lightning-rpc")
// Check socket path length before proceeding
socketPath := filepath.Join(networkDir, "lightning-rpc")
if len(socketPath) > 104 { // Unix domain socket path limit
return nil, fmt.Errorf("socket path too long (%d chars): %s. Unix domain sockets are limited to 104-108 characters. Try setting TMPDIR to a shorter path.", len(socketPath), socketPath)
}

proxy, err := NewCLightningProxy("lightning-rpc", networkDir)
if err != nil {
return nil, fmt.Errorf("NewCLightningProxy() %w", err)
}

// Create seed file
regex, _ := regexp.Compile("[^/]+")
found := regex.FindAll([]byte(dataDir), -1)
all := []byte{}
for _, v := range found {
all = append(all, v...)
// Create seed file with a deterministic but unique seed
// Use dataDir path and node ID to generate a 32-byte seed
seedSource := fmt.Sprintf("%s-node-%d-seed-padding", dataDir, id)
// Ensure we have at least 32 bytes
for len(seedSource) < 32 {
seedSource += "0"
}
seed := regex.Find(all)[len(all)-32:]
seed := []byte(seedSource)[:32]
seedFile := filepath.Join(networkDir, "hsm_secret")
err = os.WriteFile(seedFile, seed, os.ModePerm)
if err != nil {
return nil, fmt.Errorf("WriteFile() %w", err)
}

return &CLightningNode{
DaemonProcess: NewDaemonProcess(cmdLine, fmt.Sprintf("clightning-%d", id)),
DaemonProcess: NewDaemonProcess(cmdLine, fmt.Sprintf("cln-%d", id)),
CLightningProxy: proxy,
DataDir: dataDir,
Port: port,
Expand All @@ -118,18 +119,13 @@ func NewCLightningNode(testDir string, bitcoin *BitcoinNode, id int) (*CLightnin

func (n *CLightningNode) Run(waitForReady, waitForBitcoinSynced bool) error {
n.DaemonProcess.Run()
if waitForReady {
err := n.WaitForLog("Server started with public key", TIMEOUT)
if err != nil {
return fmt.Errorf("CLightningNode.Run() %w", err)
}
}

// Establish RPC connection first
var counter int
var err error
for {
if counter > 10 {
return fmt.Errorf("to many retries: %w", err)
if counter > 20 {
return fmt.Errorf("too many retries establishing RPC connection: %w", err)
}

err = n.StartProxy()
Expand All @@ -142,6 +138,27 @@ func (n *CLightningNode) Run(waitForReady, waitForBitcoinSynced bool) error {
break
}

if waitForReady {
// Wait for CLN to be ready via RPC
err = WaitFor(func() bool {
info, err := n.Rpc.GetInfo()
if err != nil {
// RPC might not be fully ready yet
return false
}
// CLN is ready when it has a valid node ID
if info.Id != "" {
log.Printf("CLN node ready with ID: %s", info.Id)
return true
}
return false
}, TIMEOUT)

if err != nil {
return fmt.Errorf("CLightningNode.Run() startup detection failed: %w", err)
}
}

// Cache info
n.Info, err = n.Rpc.GetInfo()
if err != nil {
Expand Down
18 changes: 17 additions & 1 deletion testframework/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,23 @@ func (d *DaemonProcess) WaitForLog(regex string, timeout time.Duration) error {
for {
select {
case <-timer.C:
return fmt.Errorf("timeout reached while waiting for `%s` in logs", regex)
lastLogs := d.StdOut.Tail(20, ".*")

stderrContent := d.StdErr.String()
errMsg := fmt.Sprintf("timeout reached while waiting for `%s` in logs", regex)

if lastLogs != "" {
errMsg += fmt.Sprintf("\n\n=== Last 20 lines of stdout ===\n%s", lastLogs)
}

if stderrContent != "" {
stderrTail := d.StdErr.Tail(10, ".*")
if stderrTail != "" {
errMsg += fmt.Sprintf("\n\n=== Last 10 lines of stderr ===\n%s", stderrTail)
}
}

return fmt.Errorf(errMsg)
default:
ok, err := d.HasLog(regex)
if err != nil {
Expand Down
Loading