diff --git a/client/client.go b/client/client.go index 391bea0..27c4f85 100644 --- a/client/client.go +++ b/client/client.go @@ -32,10 +32,10 @@ func MakeCertRequest() CertRequest { type SigningRequest struct { signedCert ssh.Certificate requestID string - config ssh_ca_util.SignerConfig + config ssh_ca_util.RequesterConfig } -func MakeSigningRequest(cert ssh.Certificate, requestID string, config ssh_ca_util.SignerConfig) SigningRequest { +func MakeSigningRequest(cert ssh.Certificate, requestID string, config ssh_ca_util.RequesterConfig) SigningRequest { var request SigningRequest request.signedCert = cert request.requestID = requestID diff --git a/get_cert.go b/get_cert.go index 3a51cac..f0e74ce 100644 --- a/get_cert.go +++ b/get_cert.go @@ -32,7 +32,7 @@ func getCertFlags() []cli.Flag { Value: configPath, Usage: "Path to config.json", }, - cli.BoolTFlag{ + cli.BoolFlag{ Name: "add-key", Usage: "When set automatically call ssh-add", }, @@ -65,7 +65,7 @@ func getCert(c *cli.Context) error { if err != nil { return cli.NewExitError(fmt.Sprintf("%s", err), 1) } - if c.BoolT("add-key") { + if c.Bool("add-key") { err = addCertToAgent(cert, sshDir) if err != nil { return cli.NewExitError(fmt.Sprintf("%s", err), 1) @@ -94,6 +94,9 @@ func addCertToAgent(cert *ssh.Certificate, sshDir string) error { } func downloadCert(config ssh_ca_util.RequesterConfig, certRequestID string, sshDir string) (*ssh.Certificate, error) { + ssh_ca_util.StartTunnelIfNeeded(&config) + //fmt.Printf("get_cert downloadCert using signer url: %s", config.SignerUrl) + getResp, err := http.Get(config.SignerUrl + "cert/requests/" + certRequestID) if err != nil { return nil, fmt.Errorf("Didn't get a valid response: %s", err) @@ -119,6 +122,7 @@ func downloadCert(config ssh_ca_util.RequesterConfig, certRequestID string, sshD return nil, err } pubKeyPath = strings.Replace(pubKeyPath, ".pub", "-cert.pub", 1) + fmt.Printf("%s\n", getRespBuf) err = ioutil.WriteFile(pubKeyPath, getRespBuf, 0644) if err != nil { fmt.Printf("Couldn't write certificate file to %s: %s\n", pubKeyPath, err) diff --git a/list_requests.go b/list_requests.go index 5d6300d..590535d 100644 --- a/list_requests.go +++ b/list_requests.go @@ -18,7 +18,7 @@ func listCertFlags() []cli.Flag { if home == "" { home = "/" } - configPath := home + "/.ssh_ca/signer_config.json" + configPath := home + "/.ssh_ca/requester_config.json" return []cli.Flag{ cli.StringFlag{ @@ -51,6 +51,8 @@ func listCerts(c *cli.Context) error { return cli.NewExitError(fmt.Sprintf("%s", err), 1) } config := wrongTypeConfig.(ssh_ca_util.RequesterConfig) + + ssh_ca_util.StartTunnelIfNeeded(&config) getResp, err := http.Get(config.SignerUrl + "cert/requests") if err != nil { diff --git a/request_cert.go b/request_cert.go index dac0852..b781549 100644 --- a/request_cert.go +++ b/request_cert.go @@ -67,9 +67,9 @@ func requestCertFlags() []cli.Flag { Name: "quiet", Usage: "Print only the request id on success", }, - cli.BoolTFlag{ - Name: "add-key", - Usage: "When set automatically call ssh-add if cert was auto-signed by server", + cli.BoolFlag{ + Name: "no-get-key", + Usage: "When set don't automatically download the key", }, cli.StringFlag{ Name: "ssh-dir", @@ -94,6 +94,8 @@ func requestCert(c *cli.Context) error { return cli.NewExitError(fmt.Sprintf("%s", err), 1) } config := wrongTypeConfig.(ssh_ca_util.RequesterConfig) + + ssh_ca_util.StartTunnelIfNeeded(&config) reason := c.String("reason") if reason == "" { @@ -176,15 +178,12 @@ func requestCert(c *cli.Context) error { appendage = " auto-signed" } fmt.Printf("Cert request id: %s%s\n", requestID, appendage) - if signed && c.BoolT("add-key") { - cert, err := downloadCert(config, requestID, sshDir) - if err != nil { - return cli.NewExitError(fmt.Sprintf("%s", err), 1) - } - err = addCertToAgent(cert, sshDir) + if signed && !c.Bool("no-get-key") { + _, err := downloadCert(config, requestID, sshDir) if err != nil { return cli.NewExitError(fmt.Sprintf("%s", err), 1) } + // add cert to agent didn't seem to work and seemed unnecessary } } } else { diff --git a/sign_cert.go b/sign_cert.go index e88887b..b9039c3 100644 --- a/sign_cert.go +++ b/sign_cert.go @@ -30,7 +30,7 @@ func signCertFlags() []cli.Flag { if home == "" { home = "/" } - configPath := home + "/.ssh_ca/signer_config.json" + configPath := home + "/.ssh_ca/requester_config.json" return []cli.Flag{ cli.StringFlag{ @@ -53,7 +53,7 @@ func signCertFlags() []cli.Flag { func signCert(c *cli.Context) error { configPath := c.String("config-file") - allConfig := make(map[string]ssh_ca_util.SignerConfig) + allConfig := make(map[string]ssh_ca_util.RequesterConfig) err := ssh_ca_util.LoadConfig(configPath, &allConfig) if err != nil { return cli.NewExitError(fmt.Sprintf("Load Config failed: %s", err), 1) @@ -71,7 +71,9 @@ func signCert(c *cli.Context) error { if err != nil { return cli.NewExitError(fmt.Sprintf("%s", err), 1) } - config := wrongTypeConfig.(ssh_ca_util.SignerConfig) + config := wrongTypeConfig.(ssh_ca_util.RequesterConfig) + + ssh_ca_util.StartTunnelIfNeeded(&config) conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) if err != nil { diff --git a/util/config.go b/util/config.go index 46da76a..e2d119a 100644 --- a/util/config.go +++ b/util/config.go @@ -10,6 +10,8 @@ type RequesterConfig struct { PublicKeyPath string `json:",omitempty"` PublicKeyFingerprint string `json:",omitempty"` SignerUrl string + SshBastion string `json:",omitempty"` + KeyFingerprint string `json:",omitempty"` } type SignerdConfig struct { @@ -25,11 +27,6 @@ type SignerdConfig struct { CriticalOptions map[string]string } -type SignerConfig struct { - KeyFingerprint string - SignerUrl string -} - func LoadConfig(configPath string, environmentConfigs interface{}) error { buf, err := ioutil.ReadFile(configPath) if err != nil { @@ -37,7 +34,7 @@ func LoadConfig(configPath string, environmentConfigs interface{}) error { } switch configType := environmentConfigs.(type) { - case *map[string]RequesterConfig, *map[string]SignerConfig, *map[string]SignerdConfig: + case *map[string]RequesterConfig, *map[string]SignerdConfig: return json.Unmarshal(buf, &environmentConfigs) default: return fmt.Errorf("oops: %T\n", configType) @@ -56,24 +53,10 @@ func GetConfigForEnv(environment string, environmentConfigs interface{}) (interf // lame way of extracting first and only key from a map? } } + config, ok := configs[environment] if !ok { - return nil, fmt.Errorf("Requested environment not found in config file.") - } - return config, nil - case *map[string]SignerConfig: - configs := *environmentConfigs.(*map[string]SignerConfig) - if len(configs) > 1 && environment == "" { - return nil, fmt.Errorf("You must tell me which environment to use.") - } - if len(configs) == 1 && environment == "" { - for environment = range configs { - // lame way of extracting first and only key from a map? - } - } - config, ok := configs[environment] - if !ok { - return nil, fmt.Errorf("Requested environment not found in config file.") + return nil, fmt.Errorf("Requested environment not found in config file1.") } return config, nil } diff --git a/util/sshtunnel.go b/util/sshtunnel.go new file mode 100644 index 0000000..5ae6446 --- /dev/null +++ b/util/sshtunnel.go @@ -0,0 +1,190 @@ +// copied from https://gist.github.com/svett/5d695dcc4cc6ad5dd275 + +package ssh_ca_util + +import ( +// "log" +// "bufio" +// "time" + "os" + "fmt" + "io" + "strings" + "strconv" + "net" + "sync" + "net/url" + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" +) + +type Endpoint struct { + Host string + Port int +} + +func (endpoint *Endpoint) String() string { + return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) +} + +type SSHtunnel struct { + Local *Endpoint + Server *Endpoint + Remote *Endpoint + + Config *ssh.ClientConfig +} + +func (tunnel *SSHtunnel) Start(out_port chan int) error { + listener, err := net.Listen("tcp", tunnel.Local.String()) + if err != nil { + return err + } + //fmt.Println("Using local port:", listener.Addr().(*net.TCPAddr).Port) + //tunnel.LocalPort = listener.Addr().(*net.TCPAddr).Port + out_port <- listener.Addr().(*net.TCPAddr).Port + defer listener.Close() + + for { + conn, err := listener.Accept() + if err != nil { + return err + } + go tunnel.forward(conn) + } +} + +func (tunnel *SSHtunnel) forward(localConn net.Conn) { + serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config) + if err != nil { + fmt.Printf("Server dial error: %s\n", err) + return + } + + remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String()) + if err != nil { + fmt.Printf("Remote dial error: %s\n", err) + return + } + + copyConn:=func(writer, reader net.Conn) { + _, err:= io.Copy(writer, reader) + if err != nil { + fmt.Printf("io.Copy error: %s", err) + } + } + + go copyConn(localConn, remoteConn) + go copyConn(remoteConn, localConn) +} + +func SSHAgent() ssh.AuthMethod { + if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { + return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) + } + return nil +} + +var ( + jobIsRunning bool + JobIsrunningMu sync.Mutex +) + +func StartTunnelIfNeeded(config *RequesterConfig) { + if len(config.SshBastion) > 0 { + + JobIsrunningMu.Lock() + start := !jobIsRunning + jobIsRunning = true + JobIsrunningMu.Unlock() + if start { + if !strings.HasPrefix(config.SshBastion, "ssh://") { + fmt.Printf("Bastion host must start with ssh://. Exiting\n") + os.Exit(1) + } + + bastion_parsed, err := url.Parse(config.SshBastion) + if err != nil { + fmt.Printf("url.Parse error for SshBastion: %s", err) + } + + // Check to see if it's a nonstardard port + host_parts := strings.Split(bastion_parsed.Host, ":") + var ssh_port int + ssh_port = 22 + if len(host_parts) == 2 { + var err error + ssh_port, err = strconv.Atoi(host_parts[1]) + if err != nil { + fmt.Printf("strconv.Atoi error: %s", err) + } + } + + // Get remote end information + remote_parsed, err := url.Parse(config.SignerUrl) + if err != nil { + fmt.Printf("url.Parse error on SignerUrl: %s", err) + } + remote_parts := strings.Split(remote_parsed.Host, ":") + if len(remote_parts) != 2 { + fmt.Printf("Missing port for SignerUrl. Exiting") + os.Exit(1) + } + remote_port, err := strconv.Atoi(remote_parts[1]) + if err != nil { + fmt.Printf("strconv.Atoi error: %s", err) + } + + //fmt.Printf("config stuff: %s, %d\n", host_parts[0], ssh_port) + //fmt.Printf("starting tunnel config...\n") + localEndpoint := &Endpoint{ + Host: "localhost", + Port: 0, + } + + serverEndpoint := &Endpoint{ + Host: host_parts[0], + Port: ssh_port, + } + + remoteEndpoint := &Endpoint{ + Host: remote_parts[0], + Port: remote_port, + } + + sshConfig := &ssh.ClientConfig{ + User: bastion_parsed.User.Username(), + Auth: []ssh.AuthMethod{ + SSHAgent(), + }, + // TODO: fix this to actually check the trusted hosts + // https://utcc.utoronto.ca/~cks/space/blog/programming/GoSSHHostKeyCheckingNotes + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + + tunnel := &SSHtunnel{ + Config: sshConfig, + Local: localEndpoint, + Server: serverEndpoint, + Remote: remoteEndpoint, + } + + //fmt.Printf("starting tunnel...\n") + + out_port_chan := make(chan int) + go tunnel.Start(out_port_chan) + var local_port int + local_port = <- out_port_chan + //fmt.Printf("Using local port: %d\n", local_port) + + //fmt.Printf("doing normal stuff...\n") + + config.SignerUrl = fmt.Sprintf("%s://localhost:%d/", remote_parsed.Scheme, local_port) + //fmt.Printf("sshtunnel using signer url: %s", config.SignerUrl) + // end new stuff + } + } +} +