Skip to content

Commit f010e4e

Browse files
committed
Use WMI to implement Smb API to reduce PowerShell overhead
1 parent f3aa922 commit f010e4e

File tree

4 files changed

+177
-36
lines changed

4 files changed

+177
-36
lines changed

pkg/cim/smb.go

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
//go:build windows
2+
// +build windows
3+
4+
package cim
5+
6+
import (
7+
"strings"
8+
9+
"github.com/microsoft/wmi/pkg/base/query"
10+
cim "github.com/microsoft/wmi/pkg/wmiinstance"
11+
)
12+
13+
// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/smb/msft-smbmapping
14+
const (
15+
SmbMappingStatusOK int32 = iota
16+
SmbMappingStatusPaused
17+
SmbMappingStatusDisconnected
18+
SmbMappingStatusNetworkError
19+
SmbMappingStatusConnecting
20+
SmbMappingStatusReconnecting
21+
SmbMappingStatusUnavailable
22+
23+
credentialDelimiter = ":"
24+
)
25+
26+
// escapeQueryParameter escapes a parameter for WMI Queries
27+
func escapeQueryParameter(s string) string {
28+
s = strings.ReplaceAll(s, "'", "''")
29+
s = strings.ReplaceAll(s, "\\", "\\\\")
30+
return s
31+
}
32+
33+
func escapeUserName(userName string) string {
34+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170
35+
userName = strings.ReplaceAll(userName, "\\", "\\\\")
36+
userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter)
37+
return userName
38+
}
39+
40+
// QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path.
41+
//
42+
// The equivalent WMI query is:
43+
//
44+
// SELECT [selectors] FROM MSFT_SmbGlobalMapping
45+
//
46+
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
47+
// for the WMI class definition.
48+
func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) {
49+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
50+
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
51+
if err != nil {
52+
return nil, err
53+
}
54+
55+
return instances[0], err
56+
}
57+
58+
// GetSmbGlobalMappingStatus returns the status of an SMB global mapping.
59+
func GetSmbGlobalMappingStatus(inst *cim.WmiInstance) (int32, error) {
60+
statusProp, err := inst.GetProperty("Status")
61+
if err != nil {
62+
return SmbMappingStatusUnavailable, err
63+
}
64+
65+
return statusProp.(int32), nil
66+
}
67+
68+
// RemoveSmbGlobalMappingByRemotePath removes an SMB global mapping matching to the remote path.
69+
//
70+
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
71+
// for the WMI class definition.
72+
func RemoveSmbGlobalMappingByRemotePath(remotePath string) error {
73+
smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath))
74+
instances, err := QueryInstances(WMINamespaceSmb, smbQuery)
75+
if err != nil {
76+
return err
77+
}
78+
79+
_, err = instances[0].InvokeMethod("Remove", true)
80+
return err
81+
}
82+
83+
// NewSmbGlobalMapping creates a new SMB global mapping to the remote path.
84+
//
85+
// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping
86+
// for the WMI class definition.
87+
func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) {
88+
params := map[string]interface{}{
89+
"RemotePath": remotePath,
90+
"RequirePrivacy": requirePrivacy,
91+
}
92+
if username != "" {
93+
// refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178
94+
// on how SMB credential is handled in PowerShell
95+
params["Credential"] = escapeUserName(username) + credentialDelimiter + password
96+
}
97+
98+
result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params)
99+
return result, err
100+
}

pkg/cim/wmi.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
const (
2121
WMINamespaceCimV2 = "Root\\CimV2"
2222
WMINamespaceStorage = "Root\\Microsoft\\Windows\\Storage"
23+
WMINamespaceSmb = "Root\\Microsoft\\Windows\\Smb"
2324
)
2425

2526
type InstanceHandler func(instance *cim.WmiInstance) (bool, error)

pkg/smb/hostapi/hostapi.go

Lines changed: 40 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import (
44
"fmt"
55
"strings"
66

7+
"github.com/kubernetes-csi/csi-proxy/v2/pkg/cim"
78
"github.com/kubernetes-csi/csi-proxy/v2/pkg/utils"
9+
"github.com/microsoft/wmi/pkg/base/query"
810
)
911

1012
type HostAPI interface {
@@ -22,62 +24,65 @@ func New() HostAPI {
2224
return smbAPI{}
2325
}
2426

27+
func remotePathForQuery(remotePath string) string {
28+
return strings.ReplaceAll(remotePath, "\\", "\\\\")
29+
}
30+
2531
func (smbAPI) IsSMBMapped(remotePath string) (bool, error) {
26-
cmdLine := `$(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath -ErrorAction Stop).Status `
27-
cmdEnv := fmt.Sprintf("smbremotepath=%s", remotePath)
28-
out, err := utils.RunPowershellCmd(cmdLine, cmdEnv)
29-
if err != nil {
30-
return false, fmt.Errorf("error checking SMB mapping. cmd %s, output: %s, err: %v", remotePath, string(out), err)
31-
}
32+
var isMapped bool
33+
err := cim.WithCOMThread(func() error {
34+
inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath)
35+
if err != nil {
36+
return err
37+
}
3238

33-
if len(out) == 0 || !strings.EqualFold(strings.TrimSpace(string(out)), "OK") {
34-
return false, nil
35-
}
36-
return true, nil
39+
status, err := cim.GetSmbGlobalMappingStatus(inst)
40+
if err != nil {
41+
return err
42+
}
43+
44+
isMapped = status == cim.SmbMappingStatusOK
45+
return nil
46+
})
47+
return isMapped, cim.IgnoreNotFound(err)
3748
}
3849

3950
// NewSMBLink - creates a directory symbolic link to the remote share.
4051
// The os.Symlink was having issue for cases where the destination was an SMB share - the container
41-
// runtime would complain stating "Access Denied". Because of this, we had to perform
42-
// this operation with powershell commandlet creating an directory softlink.
43-
// Since os.Symlink is currently being used in working code paths, no attempt is made in
44-
// alpha to merge the paths.
45-
// TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path.
52+
// runtime would complain stating "Access Denied".
4653
func (smbAPI) NewSMBLink(remotePath, localPath string) error {
4754
if !strings.HasSuffix(remotePath, "\\") {
4855
// Golang has issues resolving paths mapped to file shares if they do not end in a trailing \
4956
// so add one if needed.
5057
remotePath = remotePath + "\\"
5158
}
59+
longRemotePath := utils.EnsureLongPath(remotePath)
60+
longLocalPath := utils.EnsureLongPath(localPath)
5261

53-
cmdLine := `New-Item -ItemType SymbolicLink $Env:smblocalPath -Target $Env:smbremotepath`
54-
output, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("smbremotepath=%s", remotePath), fmt.Sprintf("smblocalpath=%s", localPath))
62+
err := utils.CreateSymlink(longLocalPath, longRemotePath, true)
5563
if err != nil {
56-
return fmt.Errorf("error linking %s to %s. output: %s, err: %v", remotePath, localPath, string(output), err)
64+
return fmt.Errorf("error linking %s to %s. err: %v", remotePath, localPath, err)
5765
}
5866

5967
return nil
6068
}
6169

6270
func (smbAPI) NewSMBGlobalMapping(remotePath, username, password string) error {
63-
// use PowerShell Environment Variables to store user input string to prevent command line injection
64-
// https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1
65-
cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString -String $Env:smbpassword -AsPlainText -Force` +
66-
`;$Credential = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList $Env:smbuser, $PWord` +
67-
`;New-SmbGlobalMapping -RemotePath $Env:smbremotepath -Credential $Credential -RequirePrivacy $true`)
68-
69-
if output, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("smbuser=%s", username),
70-
fmt.Sprintf("smbpassword=%s", password),
71-
fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil {
72-
return fmt.Errorf("NewSMBGlobalMapping failed. output: %q, err: %v", string(output), err)
73-
}
74-
return nil
71+
return cim.WithCOMThread(func() error {
72+
result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy)
73+
if err != nil {
74+
return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err)
75+
}
76+
return nil
77+
})
7578
}
7679

7780
func (smbAPI) RemoveSMBGlobalMapping(remotePath string) error {
78-
cmd := `Remove-SmbGlobalMapping -RemotePath $Env:smbremotepath -Force`
79-
if output, err := utils.RunPowershellCmd(cmd, fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil {
80-
return fmt.Errorf("UnmountSMBShare failed. output: %q, err: %v", string(output), err)
81-
}
82-
return nil
81+
return cim.WithCOMThread(func() error {
82+
err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath)
83+
if err != nil {
84+
return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err)
85+
}
86+
return nil
87+
})
8388
}

pkg/utils/utils.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package utils
22

33
import (
4-
"errors"
54
"fmt"
65
"os"
76
"os/exec"
87
"strings"
98

9+
"github.com/pkg/errors"
1010
"golang.org/x/sys/windows"
1111
"k8s.io/klog/v2"
1212
)
@@ -76,3 +76,38 @@ func IsPathSymlink(path string) (bool, error) {
7676
isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0
7777
return isSymlink, nil
7878
}
79+
80+
func CreateSymlink(link, target string, isDir bool) error {
81+
linkPtr, err := windows.UTF16PtrFromString(link)
82+
if err != nil {
83+
return err
84+
}
85+
targetPtr, err := windows.UTF16PtrFromString(target)
86+
if err != nil {
87+
return err
88+
}
89+
90+
var flags uint32
91+
if isDir {
92+
flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY
93+
}
94+
95+
err = windows.CreateSymbolicLink(
96+
linkPtr,
97+
targetPtr,
98+
flags,
99+
)
100+
return err
101+
}
102+
103+
// PathExists checks whether the given `path` exists.
104+
func PathExists(path string) (bool, error) {
105+
_, err := os.Lstat(path)
106+
if err == nil {
107+
return true, nil
108+
}
109+
if os.IsNotExist(err) {
110+
return false, nil
111+
}
112+
return false, err
113+
}

0 commit comments

Comments
 (0)