Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sshconfig missing bugfix - addresses issue #1105 #1109

Merged
merged 15 commits into from
Oct 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 62 additions & 50 deletions libvirt/uri/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"log"
"net"
"os"
"os/user"
"path/filepath"
"strings"

Expand Down Expand Up @@ -39,19 +38,22 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi
// 2. load override as specified in ssh config
// 3. load default ssh keyfile path
sshKeyPaths := []string{}

sshKeyPath := q.Get("keyfile")
if sshKeyPath != "" {
sshKeyPaths = append(sshKeyPaths, sshKeyPath)
}

keyPaths, err := sshcfg.GetAll(target, "IdentityFile")
if err != nil {
log.Printf("[WARN] unable to get IdentityFile values - ignoring")
} else {
sshKeyPaths = append(sshKeyPaths, keyPaths...)
if sshcfg != nil {
keyPaths, err := sshcfg.GetAll(target, "IdentityFile")
if err != nil {
log.Printf("[WARN] unable to get IdentityFile values - ignoring")
} else {
sshKeyPaths = append(sshKeyPaths, keyPaths...)
}
}

if len(keyPaths) == 0 {
if len(sshKeyPaths) == 0 {
log.Printf("[DEBUG] found no ssh keys, using default keypath")
sshKeyPaths = []string{defaultSSHKeyPath}
}
Expand Down Expand Up @@ -116,14 +118,17 @@ func (u *ConnectionURI) parseAuthMethods(target string, sshcfg *ssh_config.Confi
// construct the whole ssh connection, which can consist of multiple hops if using proxy jumps,
// the ssh configuration file is loaded once and passed along to each host connection.
func (u *ConnectionURI) dialSSH() (net.Conn, error) {
var sshcfg* ssh_config.Config = nil

sshConfigFile, err := os.Open(os.ExpandEnv(defaultSSHConfigFile))
if err != nil {
log.Printf("[WARN] Failed to open ssh config file: %v", err)
}
} else {
sshcfg, err = ssh_config.Decode(sshConfigFile)
if err != nil {
log.Printf("[WARN] Failed to parse ssh config file: '%v' - sshconfig will be ignored.", err)
}

sshcfg, err := ssh_config.Decode(sshConfigFile)
if err != nil {
log.Printf("[WARN] Failed to parse ssh config file: %v", err)
}

// configuration loaded, build tunnel
Expand Down Expand Up @@ -164,11 +169,11 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
log.Printf("[DEBUG] ssh Port is overridden to: '%s'", port)
}

hostName, err := sshcfg.Get(target, "HostName")
if err == nil {
if hostName == "" {
hostName = target
} else {
hostName := target
if sshcfg != nil {
host, err := sshcfg.Get(target, "HostName")
if err == nil && host != "" {
hostName = host
log.Printf("[DEBUG] HostName is overridden to: '%s'", hostName)
}
}
Expand All @@ -182,18 +187,22 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
if knownHostsVerify == "ignore" {
skipVerify = true
} else {
strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking")
if err != nil && strictCheck == "yes" {
skipVerify = false
if sshcfg != nil {
strictCheck, err := sshcfg.Get(target, "StrictHostKeyChecking")
if err != nil && strictCheck == "yes" {
skipVerify = false
}
}
}

if knownHostsPath == "" {
knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile")
if err == nil && knownHosts != "" {
knownHostsPath = knownHosts
} else {
knownHostsPath = defaultSSHKnownHostsPath
knownHostsPath = defaultSSHKnownHostsPath

if sshcfg != nil {
knownHosts, err := sshcfg.Get(target, "UserKnownHostsFile")
if err == nil && knownHosts != "" {
knownHostsPath = knownHosts
}
}
}

Expand Down Expand Up @@ -226,10 +235,12 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
return err
}

keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms")
if err == nil && keyAlgs != "" {
log.Printf("Got host key algorithms '%s'", keyAlgs)
hostKeyAlgorithms = strings.Split(keyAlgs, ",")
if sshcfg != nil {
keyAlgs, err := sshcfg.Get(target, "HostKeyAlgorithms")
if err == nil && keyAlgs != "" {
log.Printf("[DEBUG] HostKeyAlgorithms is overridden to '%s'", keyAlgs)
hostKeyAlgorithms = strings.Split(keyAlgs, ",")
}
}

}
Expand All @@ -240,46 +251,47 @@ func (u *ConnectionURI) dialHost(target string, sshcfg *ssh_config.Config, depth
HostKeyAlgorithms: hostKeyAlgorithms,
Timeout: dialTimeout,
}
var bastion *ssh.Client = nil
var bastion_proxy string = ""

proxy, err := sshcfg.Get(target, "ProxyCommand")
if err == nil && proxy != "" {
log.Printf("[WARNING] unsupported ssh ProxyCommand '%v'", proxy)
if sshcfg != nil {
command, err := sshcfg.Get(target, "ProxyCommand")
if err == nil && command != "" {
log.Printf("[WARNING] unsupported ssh ProxyCommand '%v' - ignoring", command)
}
}

proxy, err = sshcfg.Get(target, "ProxyJump")
var bastion *ssh.Client
if err == nil && proxy != "" {
log.Printf("[DEBUG] found ProxyJump '%v'", proxy)

// this is a proxy jump: we recurse into that proxy
bastion, err = u.dialHost(proxy, sshcfg, depth+1)
if err != nil {
return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err)
if sshcfg != nil {
proxy, err := sshcfg.Get(target, "ProxyJump")
if err == nil && proxy != "" {
log.Printf("[DEBUG] found ProxyJump '%v'", proxy)
// this is a proxy jump: we recurse into that proxy
bastion, err = u.dialHost(proxy, sshcfg, depth+1)
bastion_proxy = proxy
if err != nil {
return nil, fmt.Errorf("failed to connect to bastion host '%v': %w", proxy, err)
}
}
}

if cfg.User == "" {
// cfg.User value defaults to u.User.Username()
if sshcfg != nil {
sshu, err := sshcfg.Get(target, "User")
log.Printf("[DEBUG] SSH User for target '%v' is '%v'", target, sshu)
if err != nil {
log.Printf("[DEBUG] ssh user: using current login")
u, err := user.Current()
if err != nil {
return nil, fmt.Errorf("unable to get username: %w", err)
}
sshu = u.Username
log.Printf("[DEBUG] ssh user for target '%v' is overridden to '%v'", target, sshu)
cfg.User = sshu
}
cfg.User = sshu
}


cfg.Auth = u.parseAuthMethods(target, sshcfg)
if len(cfg.Auth) < 1 {
return nil, fmt.Errorf("could not configure SSH authentication methods")
}

if bastion != nil {
// if this is a proxied connection, we want to dial through the bastion host
log.Printf("[INFO] SSH connecting to '%v' (%v) through bastion host '%v'", target, hostName, proxy)
log.Printf("[INFO] SSH connecting to '%v' (%v) through bastion host '%v'", target, hostName, bastion_proxy)
// Dial a connection to the service host, from the bastion
conn, err := bastion.Dial("tcp", net.JoinHostPort(hostName, port))
if err != nil {
Expand Down
Loading