Skip to content

Add bastion host capability #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func MakeCertRequest() CertRequest {
type SigningRequest struct {
signedCert ssh.Certificate
requestID string
config ssh_ca_util.SignerConfig
config ssh_ca_util.RequesterConfig
}

func MakeSigningRequest(cert ssh.Certificate, requestID string, config ssh_ca_util.SignerConfig) SigningRequest {
func MakeSigningRequest(cert ssh.Certificate, requestID string, config ssh_ca_util.RequesterConfig) SigningRequest {
var request SigningRequest
request.signedCert = cert
request.requestID = requestID
Expand Down
8 changes: 6 additions & 2 deletions get_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func getCertFlags() []cli.Flag {
Value: configPath,
Usage: "Path to config.json",
},
cli.BoolTFlag{
cli.BoolFlag{
Name: "add-key",
Usage: "When set automatically call ssh-add",
},
Expand Down Expand Up @@ -65,7 +65,7 @@ func getCert(c *cli.Context) error {
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
if c.BoolT("add-key") {
if c.Bool("add-key") {
err = addCertToAgent(cert, sshDir)
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
Expand Down Expand Up @@ -94,6 +94,9 @@ func addCertToAgent(cert *ssh.Certificate, sshDir string) error {
}

func downloadCert(config ssh_ca_util.RequesterConfig, certRequestID string, sshDir string) (*ssh.Certificate, error) {
ssh_ca_util.StartTunnelIfNeeded(&config)
//fmt.Printf("get_cert downloadCert using signer url: %s", config.SignerUrl)

getResp, err := http.Get(config.SignerUrl + "cert/requests/" + certRequestID)
if err != nil {
return nil, fmt.Errorf("Didn't get a valid response: %s", err)
Expand All @@ -119,6 +122,7 @@ func downloadCert(config ssh_ca_util.RequesterConfig, certRequestID string, sshD
return nil, err
}
pubKeyPath = strings.Replace(pubKeyPath, ".pub", "-cert.pub", 1)
fmt.Printf("%s\n", getRespBuf)
err = ioutil.WriteFile(pubKeyPath, getRespBuf, 0644)
if err != nil {
fmt.Printf("Couldn't write certificate file to %s: %s\n", pubKeyPath, err)
Expand Down
4 changes: 3 additions & 1 deletion list_requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func listCertFlags() []cli.Flag {
if home == "" {
home = "/"
}
configPath := home + "/.ssh_ca/signer_config.json"
configPath := home + "/.ssh_ca/requester_config.json"

return []cli.Flag{
cli.StringFlag{
Expand Down Expand Up @@ -51,6 +51,8 @@ func listCerts(c *cli.Context) error {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
config := wrongTypeConfig.(ssh_ca_util.RequesterConfig)

ssh_ca_util.StartTunnelIfNeeded(&config)

getResp, err := http.Get(config.SignerUrl + "cert/requests")
if err != nil {
Expand Down
17 changes: 8 additions & 9 deletions request_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ func requestCertFlags() []cli.Flag {
Name: "quiet",
Usage: "Print only the request id on success",
},
cli.BoolTFlag{
Name: "add-key",
Usage: "When set automatically call ssh-add if cert was auto-signed by server",
cli.BoolFlag{
Name: "no-get-key",
Usage: "When set don't automatically download the key",
},
cli.StringFlag{
Name: "ssh-dir",
Expand All @@ -94,6 +94,8 @@ func requestCert(c *cli.Context) error {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
config := wrongTypeConfig.(ssh_ca_util.RequesterConfig)

ssh_ca_util.StartTunnelIfNeeded(&config)

reason := c.String("reason")
if reason == "" {
Expand Down Expand Up @@ -176,15 +178,12 @@ func requestCert(c *cli.Context) error {
appendage = " auto-signed"
}
fmt.Printf("Cert request id: %s%s\n", requestID, appendage)
if signed && c.BoolT("add-key") {
cert, err := downloadCert(config, requestID, sshDir)
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
err = addCertToAgent(cert, sshDir)
if signed && !c.Bool("no-get-key") {
_, err := downloadCert(config, requestID, sshDir)
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
// add cert to agent didn't seem to work and seemed unnecessary
}
}
} else {
Expand Down
8 changes: 5 additions & 3 deletions sign_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func signCertFlags() []cli.Flag {
if home == "" {
home = "/"
}
configPath := home + "/.ssh_ca/signer_config.json"
configPath := home + "/.ssh_ca/requester_config.json"

return []cli.Flag{
cli.StringFlag{
Expand All @@ -53,7 +53,7 @@ func signCertFlags() []cli.Flag {

func signCert(c *cli.Context) error {
configPath := c.String("config-file")
allConfig := make(map[string]ssh_ca_util.SignerConfig)
allConfig := make(map[string]ssh_ca_util.RequesterConfig)
err := ssh_ca_util.LoadConfig(configPath, &allConfig)
if err != nil {
return cli.NewExitError(fmt.Sprintf("Load Config failed: %s", err), 1)
Expand All @@ -71,7 +71,9 @@ func signCert(c *cli.Context) error {
if err != nil {
return cli.NewExitError(fmt.Sprintf("%s", err), 1)
}
config := wrongTypeConfig.(ssh_ca_util.SignerConfig)
config := wrongTypeConfig.(ssh_ca_util.RequesterConfig)

ssh_ca_util.StartTunnelIfNeeded(&config)

conn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
Expand Down
27 changes: 5 additions & 22 deletions util/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ type RequesterConfig struct {
PublicKeyPath string `json:",omitempty"`
PublicKeyFingerprint string `json:",omitempty"`
SignerUrl string
SshBastion string `json:",omitempty"`
KeyFingerprint string `json:",omitempty"`
}

type SignerdConfig struct {
Expand All @@ -25,19 +27,14 @@ type SignerdConfig struct {
CriticalOptions map[string]string
}

type SignerConfig struct {
KeyFingerprint string
SignerUrl string
}

func LoadConfig(configPath string, environmentConfigs interface{}) error {
buf, err := ioutil.ReadFile(configPath)
if err != nil {
return err
}

switch configType := environmentConfigs.(type) {
case *map[string]RequesterConfig, *map[string]SignerConfig, *map[string]SignerdConfig:
case *map[string]RequesterConfig, *map[string]SignerdConfig:
return json.Unmarshal(buf, &environmentConfigs)
default:
return fmt.Errorf("oops: %T\n", configType)
Expand All @@ -56,24 +53,10 @@ func GetConfigForEnv(environment string, environmentConfigs interface{}) (interf
// lame way of extracting first and only key from a map?
}
}

config, ok := configs[environment]
if !ok {
return nil, fmt.Errorf("Requested environment not found in config file.")
}
return config, nil
case *map[string]SignerConfig:
configs := *environmentConfigs.(*map[string]SignerConfig)
if len(configs) > 1 && environment == "" {
return nil, fmt.Errorf("You must tell me which environment to use.")
}
if len(configs) == 1 && environment == "" {
for environment = range configs {
// lame way of extracting first and only key from a map?
}
}
config, ok := configs[environment]
if !ok {
return nil, fmt.Errorf("Requested environment not found in config file.")
return nil, fmt.Errorf("Requested environment not found in config file1.")
}
return config, nil
}
Expand Down
190 changes: 190 additions & 0 deletions util/sshtunnel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// copied from https://gist.github.com/svett/5d695dcc4cc6ad5dd275

package ssh_ca_util

import (
// "log"
// "bufio"
// "time"
"os"
"fmt"
"io"
"strings"
"strconv"
"net"
"sync"
"net/url"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

type Endpoint struct {
Host string
Port int
}

func (endpoint *Endpoint) String() string {
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
}

type SSHtunnel struct {
Local *Endpoint
Server *Endpoint
Remote *Endpoint

Config *ssh.ClientConfig
}

func (tunnel *SSHtunnel) Start(out_port chan int) error {
listener, err := net.Listen("tcp", tunnel.Local.String())
if err != nil {
return err
}
//fmt.Println("Using local port:", listener.Addr().(*net.TCPAddr).Port)
//tunnel.LocalPort = listener.Addr().(*net.TCPAddr).Port
out_port <- listener.Addr().(*net.TCPAddr).Port
defer listener.Close()

for {
conn, err := listener.Accept()
if err != nil {
return err
}
go tunnel.forward(conn)
}
}

func (tunnel *SSHtunnel) forward(localConn net.Conn) {
serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
if err != nil {
fmt.Printf("Server dial error: %s\n", err)
return
}

remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
if err != nil {
fmt.Printf("Remote dial error: %s\n", err)
return
}

copyConn:=func(writer, reader net.Conn) {
_, err:= io.Copy(writer, reader)
if err != nil {
fmt.Printf("io.Copy error: %s", err)
}
}

go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
}

func SSHAgent() ssh.AuthMethod {
if sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers)
}
return nil
}

var (
jobIsRunning bool
JobIsrunningMu sync.Mutex
)

func StartTunnelIfNeeded(config *RequesterConfig) {
if len(config.SshBastion) > 0 {

JobIsrunningMu.Lock()
start := !jobIsRunning
jobIsRunning = true
JobIsrunningMu.Unlock()
if start {
if !strings.HasPrefix(config.SshBastion, "ssh://") {
fmt.Printf("Bastion host must start with ssh://. Exiting\n")
os.Exit(1)
}

bastion_parsed, err := url.Parse(config.SshBastion)
if err != nil {
fmt.Printf("url.Parse error for SshBastion: %s", err)
}

// Check to see if it's a nonstardard port
host_parts := strings.Split(bastion_parsed.Host, ":")
var ssh_port int
ssh_port = 22
if len(host_parts) == 2 {
var err error
ssh_port, err = strconv.Atoi(host_parts[1])
if err != nil {
fmt.Printf("strconv.Atoi error: %s", err)
}
}

// Get remote end information
remote_parsed, err := url.Parse(config.SignerUrl)
if err != nil {
fmt.Printf("url.Parse error on SignerUrl: %s", err)
}
remote_parts := strings.Split(remote_parsed.Host, ":")
if len(remote_parts) != 2 {
fmt.Printf("Missing port for SignerUrl. Exiting")
os.Exit(1)
}
remote_port, err := strconv.Atoi(remote_parts[1])
if err != nil {
fmt.Printf("strconv.Atoi error: %s", err)
}

//fmt.Printf("config stuff: %s, %d\n", host_parts[0], ssh_port)
//fmt.Printf("starting tunnel config...\n")
localEndpoint := &Endpoint{
Host: "localhost",
Port: 0,
}

serverEndpoint := &Endpoint{
Host: host_parts[0],
Port: ssh_port,
}

remoteEndpoint := &Endpoint{
Host: remote_parts[0],
Port: remote_port,
}

sshConfig := &ssh.ClientConfig{
User: bastion_parsed.User.Username(),
Auth: []ssh.AuthMethod{
SSHAgent(),
},
// TODO: fix this to actually check the trusted hosts
// https://utcc.utoronto.ca/~cks/space/blog/programming/GoSSHHostKeyCheckingNotes
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}

tunnel := &SSHtunnel{
Config: sshConfig,
Local: localEndpoint,
Server: serverEndpoint,
Remote: remoteEndpoint,
}

//fmt.Printf("starting tunnel...\n")

out_port_chan := make(chan int)
go tunnel.Start(out_port_chan)
var local_port int
local_port = <- out_port_chan
//fmt.Printf("Using local port: %d\n", local_port)

//fmt.Printf("doing normal stuff...\n")

config.SignerUrl = fmt.Sprintf("%s://localhost:%d/", remote_parsed.Scheme, local_port)
//fmt.Printf("sshtunnel using signer url: %s", config.SignerUrl)
// end new stuff
}
}
}