@@ -20,70 +20,63 @@ const (
2020 date = "unknown"
2121)
2222
23- var options struct {
24- Signer signer.Options `group:"Vault SSH key signing Options"`
25- OpenSSH openssh.Options `group:"OpenSSH ssh(1) Options" hidden:"yes"`
26- Version func () `long:"version" group:"Help" description:"show version"`
27- }
23+ func init () {
24+ var options struct {
25+ Signer signer.Options `group:"Vault SSH key signing Options"`
26+ OpenSSH openssh.Options `group:"OpenSSH ssh(1) Options" hidden:"yes"`
27+ Version func () `long:"version" description:"Show version"`
28+ }
2829
29- func main () {
3030 options .Version = func () {
3131 fmt .Printf ("vault-ssh-plus v%s (%s), %s\n " , version , commit , date )
3232 os .Exit (0 )
3333 }
34+
3435 parser := flags .NewParser (& options , flags .Default )
3536 parser .Usage = "[options] destination [command]"
3637 if _ , err := parser .ParseArgs (os .Args [1 :]); err != nil {
3738 if flagsErr , ok := err .(* flags.Error ); ok && flagsErr .Type == flags .ErrHelp {
3839 os .Exit (0 )
3940 } else {
40- fmt .Println (err )
4141 os .Exit (1 )
4242 }
4343 }
44+ }
45+
46+ func main () {
47+ var (
48+ vaultClient signer.Client
49+ sshClient openssh.Client
50+ err error
51+ )
4452
45- vaultClient , unparsedArgs , err := signer .ParseArgs (os .Args [1 :])
53+ unparsedArgs , err := signer .ParseArgs (& vaultClient , os .Args [1 :])
4654 if err != nil {
4755 log .Fatal ("[ERROR] parsing vault options: " , err )
4856 }
4957
50- sshClient , _ , err : = openssh .ParseArgs (unparsedArgs )
58+ _ , err = openssh .ParseArgs (& sshClient , unparsedArgs )
5159 if err != nil {
5260 log .Fatal ("[ERROR] parsing ssh options: " , err )
5361 }
5462
5563 if sshClient .Options .LoginName == "" {
56- currentUser , _ := user .Current ()
57- sshClient .Options .LoginName = currentUser .Username
64+ sshClient .Options .LoginName = getDefaultUser (& vaultClient , & sshClient )
5865 }
5966
6067 controlConnection := sshClient .ControlConnection ()
6168
6269 if ! controlConnection && sshClient .Options .ControlCommand != "exit" {
63- if ! vaultClient .Options .Extensions .PortForwarding &&
64- (sshClient .Options .ProxyJump != "" ||
65- sshClient .Options .DynamicForward != nil ||
66- sshClient .Options .LocalForward != nil ||
67- sshClient .Options .RemoteForward != nil ) {
68- vaultClient .Options .Extensions .PortForwarding = true
69- }
70-
71- if ! vaultClient .Options .Extensions .NoPTY &&
72- (sshClient .Options .NoPTY ||
73- (sshClient .Options .ForcePTY == nil && len (sshClient .Options .Positional .RemoteCommand ) > 0 )) {
74- vaultClient .Options .Extensions .NoPTY = true
75- }
76-
77- if ! vaultClient .Options .Extensions .X11Forwarding && sshClient .Options .ForwardX11 {
78- vaultClient .Options .Extensions .X11Forwarding = true
79- }
70+ updateRequestExtensions (& vaultClient .Options .Extensions , & sshClient .Options )
8071
8172 signedKey , err := vaultClient .GetSignedKey (sshClient .Options .LoginName )
8273 if err != nil {
8374 log .Fatal ("[ERROR] failed to get signed key: " , err )
8475 }
8576
86- sshClient .SetSignedKey (signedKey )
77+ if err := sshClient .SetSignedKey (signedKey ); err != nil {
78+ log .Fatal ("[ERROR] invalid certificate: " , err )
79+ }
8780
8881 signedKeyFile , err := sshClient .WriteSignedKeyFile (
8982 filepath .Dir (vaultClient .Options .PublicKey ),
@@ -110,10 +103,50 @@ func main() {
110103
111104func setupExitHandler (fn string ) {
112105 s := make (chan os.Signal )
113- signal .Notify (s , os .Interrupt , syscall .SIGTERM , syscall .SIGQUIT )
106+ signal .Notify (s , os .Interrupt , os . Kill , syscall .SIGTERM , syscall .SIGQUIT )
114107 go func () {
115108 <- s
116109 _ = os .Remove (fn )
117110 os .Exit (0 )
118111 }()
119112}
113+
114+ func getDefaultUser (vaultClient * signer.Client , sshClient * openssh.Client ) string {
115+ var loginName string
116+
117+ // if the role only allows a single, fixed user, use it
118+ allowedUser := vaultClient .GetAllowedUser ()
119+ if allowedUser != "" {
120+ loginName = allowedUser
121+ sshClient .PrependArgs ([]string {"-l" , allowedUser })
122+ }
123+
124+ if loginName == "" {
125+ currentUser , _ := user .Current ()
126+ loginName = currentUser .Username
127+ }
128+
129+ return loginName
130+ }
131+
132+ func updateRequestExtensions (requestExtensions * signer.Extensions , sshOptions * openssh.Options ) {
133+ if ! requestExtensions .AgentForwarding && sshOptions .ForwardAgent {
134+ requestExtensions .AgentForwarding = true
135+ } else if requestExtensions .AgentForwarding && sshOptions .NoForwardAgent {
136+ requestExtensions .AgentForwarding = false
137+ }
138+
139+ if ! requestExtensions .PortForwarding &&
140+ (sshOptions .ProxyJump != "" || sshOptions .DynamicForward != nil || sshOptions .LocalForward != nil || sshOptions .RemoteForward != nil ) {
141+ requestExtensions .PortForwarding = true
142+ }
143+
144+ if ! requestExtensions .NoPTY &&
145+ (sshOptions .NoPTY || (sshOptions .ForcePTY == nil && len (sshOptions .Positional .RemoteCommand ) > 0 )) {
146+ requestExtensions .NoPTY = true
147+ }
148+
149+ if ! requestExtensions .X11Forwarding && sshOptions .ForwardX11 {
150+ requestExtensions .X11Forwarding = true
151+ }
152+ }
0 commit comments