Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions vpn/ipc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package vpn

import (
"context"
"fmt"
"runtime"

"github.com/getlantern/radiance/common"
"github.com/getlantern/radiance/vpn/ipc"
"github.com/sagernet/sing-box/experimental/clashapi"
"github.com/sagernet/sing-box/experimental/libbox"
)

var platIfceProvider func() libbox.PlatformInterface

// closedSvc is a stub service used while the tunnel is down
type closedSvc struct{}

func (closedSvc) Ctx() context.Context { return context.Background() }
func (closedSvc) Status() string { return ipc.StatusClosed }
func (closedSvc) ClashServer() *clashapi.Server { return nil }
func (closedSvc) Close() error { return nil }

// InitIPC starts the long-lived IPC server and hooks it up to establishConnection
func InitIPC(basePath string, provider func() libbox.PlatformInterface) error {
if ipcServer != nil {
// already started
return nil
}
platIfceProvider = provider
if runtime.GOOS != "windows" && basePath != "" {
ipc.SetSocketPath(basePath)
}

ipcServer = ipc.NewServer(closedSvc{})
// start tunnel via IPC. How /service/start brings the tunnel up
ipcServer.SetStartFn(func(ctx context.Context, group, tag string) (ipc.Service, error) {
path := basePath
if path == "" {
path = common.DataPath()
}

_ = newSplitTunnel(path)

opts, err := buildOptions(group, path)
if err != nil {
return nil, fmt.Errorf("build options: %w", err)
}

var pi libbox.PlatformInterface
if platIfceProvider != nil {
pi = platIfceProvider()
}

if err := establishConnection(group, tag, opts, path, pi); err != nil {
return nil, err
}
return tInstance, nil
})
return ipcServer.Start(basePath)
}
95 changes: 44 additions & 51 deletions vpn/ipc/conn_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"fmt"
"net"
"syscall"
"time"

"github.com/Microsoft/go-winio"
Expand All @@ -20,9 +21,7 @@ const (
apiURL = "http://pipe"
connectTimeout = 10 * time.Second

// TODO: I don't know which one should be used
sddl = `O:BAG:BAD:PAI(A;OICI;GWGR;;;BU)(A;OICI;GWGR;;;SY)`
// sddl = `D:P(A;;GA;;;SY)(A;;GA;;;BA)(A;;GA;;;BU)`
sddl = `D:P(A;;GA;;;SY)(A;;GRGW;;;IU)(A;;GRGW;;;BA)`
)

// SetSocketPath not supported on Windows.
Expand Down Expand Up @@ -52,11 +51,24 @@ func listen(_ string) (net.Listener, error) {
return &winioListener{ln}, nil
}

// winioConn is an helper interface to access the underlying file descriptor of a winio.Conn. This
// is needed to call Windows API functions that require a handle.
// winioConn is an helper interface that exposes the standard syscall.Conn so we can
// access the underlying handle via RawConn.Control
type winioConn interface {
net.Conn
FD() uintptr
SyscallConn() (syscall.RawConn, error)
}

// withConnHandle runs the function with the connection’s handle pinned by the runtime
func withConnHandle(c winioConn, fn func(h windows.Handle) error) error {
rc, err := c.SyscallConn()
if err != nil {
return err
}
var callErr error
if err := rc.Control(func(fd uintptr) { callErr = fn(windows.Handle(fd)) }); err != nil {
return err
}
return callErr
}

type winioListener struct {
Expand Down Expand Up @@ -86,7 +98,7 @@ func (l *winioListener) Accept() (conn net.Conn, err error) {
}
defer pipeToken.Close()

procToken, err := getProcessToken(wc)
procToken, err := getServerProcessToken()
if err != nil {
return nil, fmt.Errorf("failed to get process token: %w", err)
}
Expand All @@ -103,68 +115,49 @@ func (l *winioListener) Accept() (conn net.Conn, err error) {
return wc, nil
}

func tokenUserSID(t windows.Token) (*windows.SID, error) {
u, err := t.GetTokenUser()
if err != nil {
return nil, fmt.Errorf("failed to get token user: %w", err)
}
return u.User.Sid, nil
}

// verifySameUser checks if two tokens belong to the same user.
func verifySameUser(t1, t2 windows.Token) (bool, error) {
u1, err := t1.GetTokenUser()
s1, err := tokenUserSID(t1)
if err != nil {
return false, fmt.Errorf("failed to get token user: %w", err)
}
u2, err := t2.GetTokenUser()
s2, err := tokenUserSID(t2)
if err != nil {
return false, fmt.Errorf("failed to get token user: %w", err)
}
return u1.User.Sid.Equals(u2.User.Sid), nil
return s1.Equals(s2), nil
}

// getPipeClientToken retrieves the impersonation token for the pipe client.
func getPipeClientToken(conn winioConn) (windows.Token, error) {
ph := windows.Handle(conn.FD())
if ph == 0 {
return 0, fmt.Errorf("invalid pipe handle")
}

err := impersonateNamedPipeClient(ph)
if err != nil {
return 0, fmt.Errorf("failed to impersonate client: %w", err)
}
defer windows.RevertToSelf()

var token windows.Token
err = windows.OpenThreadToken(windows.CurrentThread(), windows.TOKEN_DUPLICATE|windows.TOKEN_QUERY, true, &token)
if err != nil {
return 0, fmt.Errorf("failed to open thread token: %w", err)
if err := withConnHandle(conn, func(h windows.Handle) error {
err := impersonateNamedPipeClient(h)
if err != nil {
return fmt.Errorf("failed to impersonate client: %w", err)
}
defer windows.RevertToSelf()

return windows.OpenThreadToken(windows.CurrentThread(), windows.TOKEN_DUPLICATE|windows.TOKEN_QUERY, true, &token)
}); err != nil {
return 0, err
}
return token, nil
}

// getProcessToken retrieves the process token for the pipe client.
func getProcessToken(pc winioConn) (windows.Token, error) {
pid, err := getPipeClientPID(pc)
if err != nil {
return 0, fmt.Errorf("failed to get client process id: %w", err)
}
h, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, pid)
if err != nil {
return 0, fmt.Errorf("failed to open process token: %w", err)
}
defer windows.CloseHandle(h)

// getServerProcessToken retrieves the process token for the pipe client.
func getServerProcessToken() (windows.Token, error) {
var token windows.Token
if err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token); err != nil {
return 0, fmt.Errorf("failed to open process token: %w", err)
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
return 0, fmt.Errorf("failed to open service process token: %w", err)
}
return token, nil
}

func getPipeClientPID(pc winioConn) (uint32, error) {
ph := windows.Handle(pc.FD())
if ph == 0 {
return 0, fmt.Errorf("invalid pipe handle")
}
var pid uint32
err := windows.GetNamedPipeClientProcessId(ph, &pid)
if err != nil {
return 0, fmt.Errorf("failed to get client process id: %w", err)
}
return pid, nil
}
2 changes: 2 additions & 0 deletions vpn/ipc/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package ipc
const (
statusEndpoint = "/status"
metricsEndpoint = "/metrics"
startServiceEndpoint = "/service/start"
stopServiceEndpoint = "/service/stop"
closeServiceEndpoint = "/service/close"
groupsEndpoint = "/groups"
selectEndpoint = "/outbound/select"
Expand Down
76 changes: 71 additions & 5 deletions vpn/ipc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ package ipc

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"time"

"github.com/getlantern/radiance/traces"
"github.com/go-chi/chi/v5"
"github.com/sagernet/sing-box/experimental/clashapi"

"github.com/getlantern/radiance/traces"
)

var (
Expand All @@ -30,13 +30,16 @@ type Service interface {
Close() error
}

type StartFn func(ctx context.Context, group, tag string) (Service, error)

// Server represents the IPC server that communicates over a Unix domain socket for Unix-like
// systems, and a named pipe for Windows.
type Server struct {
svr *http.Server
service Service

router chi.Router
router chi.Router
startFn StartFn
}

// NewServer creates a new Server instance with the provided Service.
Expand All @@ -55,6 +58,8 @@ func NewServer(service Service) *Server {
s.router.Post(selectEndpoint, s.selectHandler)
s.router.Get(clashModeEndpoint, s.clashModeHandler)
s.router.Post(clashModeEndpoint, s.clashModeHandler)
s.router.Post(startServiceEndpoint, s.startServiceHandler)
s.router.Post(stopServiceEndpoint, s.stopServiceHandler)
s.router.Post(closeServiceEndpoint, s.closeServiceHandler)
s.router.Post(closeConnectionsEndpoint, s.closeConnectionHandler)
return s
Expand Down Expand Up @@ -89,12 +94,71 @@ func (s *Server) Close() error {
return s.svr.Close()
}

// CloseService sends a request to shutdown the service. This will also close the IPC server.
// CloseService sends a request to shutdown the service
func CloseService(ctx context.Context) error {
_, err := sendRequest[empty](ctx, "POST", closeServiceEndpoint, nil)
return err
}

// StartService sends a request to start the service
func StartService(ctx context.Context, group, tag string) error {
_, err := sendRequest[empty](ctx, "POST", startServiceEndpoint, selection{GroupTag: group, OutboundTag: tag})
return err
}

// StopService sends a request to stop the service (IPC server stays up)
func StopService(ctx context.Context) error {
_, err := sendRequest[empty](ctx, "POST", stopServiceEndpoint, nil)
return err
}

// SetStartFn registers a function that will be called when the start endpoint is hit
func (s *Server) SetStartFn(fn StartFn) { s.startFn = fn }

// SetService updates the service attached to the server.
// Typically called when starting or replacing the VPN tunnel
func (s *Server) SetService(svc Service) { s.service = svc }

func (s *Server) startServiceHandler(w http.ResponseWriter, r *http.Request) {
if s.startFn == nil {
http.Error(w, "start not supported", http.StatusNotImplemented)
return
}
// check if service is already running
if s.service != nil && s.service.Status() == StatusRunning {
w.WriteHeader(http.StatusOK)
return
}
var p selection
if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
svc, err := s.startFn(r.Context(), p.GroupTag, p.OutboundTag)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
s.SetService(svc)
w.WriteHeader(http.StatusOK)
}

func (s *Server) stopServiceHandler(w http.ResponseWriter, r *http.Request) {
svc := s.service
s.service = &closedService{}

if svc != nil {
if err := svc.Close(); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
w.WriteHeader(http.StatusOK)
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}

func (s *Server) closeServiceHandler(w http.ResponseWriter, r *http.Request) {
service := s.service
s.service = &closedService{}
Expand All @@ -108,7 +172,9 @@ func (s *Server) closeServiceHandler(w http.ResponseWriter, r *http.Request) {
}

go func() {
if err := s.Close(); err != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.svr.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
traces.RecordError(context.Background(), err)
}
}()
Expand Down
Loading