Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions internal/libvirt/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"strings"
"net/url"

"github.com/digitalocean/go-libvirt"
"github.com/dmacvicar/terraform-provider-libvirt/v2/internal/libvirt/dialers"
Expand All @@ -20,12 +21,17 @@ type Client struct {

// NewClient creates a new libvirt client from a connection URI
func NewClient(ctx context.Context, uri string) (*Client, error) {
parsedURI, err := url.Parse(uri)
if err != nil {
return nil, fmt.Errorf("invalid libvirt URI: %w", err)
}

tflog.Debug(ctx, "Creating new libvirt client", map[string]any{
"uri": uri,
})

// Create the appropriate dialer based on the URI
dialer, err := dialers.NewDialerFromURI(uri)
dialer, err := dialers.NewDialerFromURI(parsedURI)
if err != nil {
return nil, fmt.Errorf("failed to create dialer: %w", err)
}
Expand All @@ -40,10 +46,16 @@ func NewClient(ctx context.Context, uri string) (*Client, error) {
return nil, fmt.Errorf("failed to dial libvirt: %w", err)
}

internalURI := url.URL{
Path: parsedURI.Path,
Scheme: strings.Split(parsedURI.Scheme, "+")[0],
}
tflog.Debug(ctx, "", map[string]any{"internalURI": internalURI.String()})

// Create libvirt client
//nolint:staticcheck // NewWithDialer is too complex for our use case
l := libvirt.New(conn)
if err := l.Connect(); err != nil {
if err := l.ConnectToURI(libvirt.ConnectURI(internalURI.String())); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("failed to connect to libvirt: %w", err)
}
Expand Down
23 changes: 9 additions & 14 deletions internal/libvirt/dialers/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,11 @@ const (
// NewDialerFromURI creates the appropriate Dialer based on the libvirt URI.
// It uses upstream go-libvirt dialers for most transports and custom dialers
// for special cases like SSHCmd.
func NewDialerFromURI(uriStr string) (Dialer, error) {
parsedURI, err := url.Parse(uriStr)
if err != nil {
return nil, fmt.Errorf("invalid libvirt URI: %w", err)
}

func NewDialerFromURI(uri *url.URL) (Dialer, error) {
// Parse the scheme to extract driver and transport
// Format: driver[+transport]://[host]/path
// Examples: qemu:///system, qemu+ssh://host/system, qemu+sshcmd://host/system
schemeParts := strings.Split(parsedURI.Scheme, "+")
schemeParts := strings.Split(uri.Scheme, "+")
driver := schemeParts[0]
transport := ""
if len(schemeParts) > 1 {
Expand All @@ -43,27 +38,27 @@ func NewDialerFromURI(uriStr string) (Dialer, error) {
}

// Local connection (no transport specified and no host)
if transport == "" && parsedURI.Host == "" {
return newLocalDialer(parsedURI)
if transport == "" && uri.Host == "" {
return newLocalDialer(uri)
}

// Remote connections
switch transport {
case "ssh":
// Use Go SSH library (upstream dialer)
return newGoSSHDialer(parsedURI)
return newGoSSHDialer(uri)
case "sshcmd":
// Use native SSH command (custom dialer)
return NewSSHCmd(parsedURI), nil
return NewSSHCmd(uri), nil
case "tcp":
// Plain TCP connection (upstream dialer)
return newRemoteDialer(parsedURI)
return newRemoteDialer(uri)
case "tls":
// TLS connection (upstream dialer)
return newTLSDialer(parsedURI)
return newTLSDialer(uri)
case "":
// No transport but has host - assume SSH
return newGoSSHDialer(parsedURI)
return newGoSSHDialer(uri)
default:
return nil, fmt.Errorf("unsupported transport: %s", transport)
}
Expand Down