Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"log"
"time"

"github.com/doubleunion/accesscontrol/router"
rpio "github.com/stianeikeland/go-rpio/v4"
Expand All @@ -15,18 +14,5 @@ func main() {
}
defer rpio.Close()

// Run updateIPAndRestart every minute in a separate thread
go func() {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()

for range ticker.C {
err := router.UpdateIPAndRestart()
if err != nil {
log.Printf("Error in updateIPAndRestart: %v", err)
}
}
}()

router.RunRouter()
}
3 changes: 2 additions & 1 deletion door/door.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ import (
"bytes"
cryptoRand "crypto/rand"
"fmt"
rpio "github.com/stianeikeland/go-rpio/v4"
"log"
"sync"
"time"

rpio "github.com/stianeikeland/go-rpio/v4"
)

type Door struct {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/labstack/echo-jwt/v4 v4.2.0
github.com/labstack/echo/v4 v4.11.3
github.com/stianeikeland/go-rpio/v4 v4.6.0
golang.org/x/crypto v0.14.0
)

require (
Expand All @@ -16,7 +17,6 @@ require (
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/crypto v0.14.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
Expand Down
76 changes: 76 additions & 0 deletions router/local_ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package router

import (
"log"
"net"
"net/http"
"sync"
"time"
"slices"

"github.com/labstack/echo/v4"
)

var (
localInternetAddresses []string
localInternetAddressesMutex sync.RWMutex
)

func beginLocalIPMonitoring() {
updateLocalIPAddresses()

go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()

for range ticker.C {
updateLocalIPAddresses()
}
}()
}

func updateLocalIPAddresses() {
ips, err := net.LookupIP("doorcontrol.doubleunion.org")
if err != nil {
log.Printf("Error in updateLocalIPAddresses: %v", err)
}

var ipStrings []string
for _, ip := range ips {
ipStrings = append(ipStrings, ip.String())
}

localInternetAddressesMutex.Lock()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it possibly might be better to only take a read lock here until we know things have changed, since we're polling pretty regularly, but I'd guess in practice it mostly doesn't make a difference – we're not exactly dealing with lots of concurrent requests

defer localInternetAddressesMutex.Unlock()
hasChanged := !slices.Equal(ipStrings, localInternetAddresses)
localInternetAddresses = ipStrings

if hasChanged {
log.Printf("Updated local IP addresses to: %v", ipStrings)
}
}

func requireLocalNetworkMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Get the remote address from the request
remoteAddr := c.Request().RemoteAddr

// Parse the IP address
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to parse IP address")
}

localInternetAddressesMutex.RLock()
defer localInternetAddressesMutex.RUnlock()
for _, allowedIP := range localInternetAddresses {
if ip == allowedIP {
// Continue to the next middleware or route handler
return next(c)
}
}

// Fail the request, there was no matching IP
return jsonResponse(c, http.StatusForbidden, "requests not allowed from remote hosts")
}
}
97 changes: 1 addition & 96 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/exec"
"strings"
"time"

"github.com/doubleunion/accesscontrol/door"
Expand All @@ -21,20 +18,13 @@ import (
"golang.org/x/crypto/acme/autocert"
)

const serviceFilePath = "/etc/systemd/system/accesscontrol.service"
const ipQueryURL = "https://wtfismyip.com/text"

var localInternetAddress = os.Getenv("LOCAL_INTERNET_ADDRESS")

func RunRouter() {
signingKey := os.Getenv("ACCESS_CONTROL_SIGNING_KEY")
if signingKey == "" {
log.Fatal("signing key is missing")
}

if localInternetAddress == "" {
log.Fatal("local internet address is missing")
}
beginLocalIPMonitoring()

door := door.New()

Expand Down Expand Up @@ -96,88 +86,3 @@ func RunRouter() {
func jsonResponse(c echo.Context, code int, message string) error {
return c.JSON(code, map[string]string{"message": message})
}

func requireLocalNetworkMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
// Get the remote address from the request
remoteAddr := c.Request().RemoteAddr

// Parse the IP address
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to parse IP address")
}

if ip != localInternetAddress {
return jsonResponse(c, http.StatusForbidden, "requests not allowed from remote hosts")
}

// Continue to the next middleware or route handler
return next(c)
}
}

func UpdateIPAndRestart() error {
// Step 1: Query current IP address
resp, err := http.Get(ipQueryURL)
if err != nil {
return err
}
defer resp.Body.Close()

ipBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
currentIP := strings.TrimSpace(string(ipBytes))

// Step 2: Read the service file
content, err := os.ReadFile(serviceFilePath)
if err != nil {
return err
}

// Step 3: Check if the IP matches
lines := strings.Split(string(content), "\n")
var updatedContent []string
ipUpdated := false
for _, line := range lines {
if strings.HasPrefix(line, "Environment=LOCAL_INTERNET_ADDRESS=") {
fileIP := strings.TrimPrefix(line, "Environment=LOCAL_INTERNET_ADDRESS=")
if fileIP != currentIP {
line = "Environment=LOCAL_INTERNET_ADDRESS=" + currentIP
ipUpdated = true
}
}
updatedContent = append(updatedContent, line)
}

// Step 4: Update the file if necessary
if ipUpdated {
// first we have to output the new contents to a temporary file
// because we don't have access to the service file directly
tempFilePath := "/tmp/accesscontrol.service"
err = os.WriteFile(tempFilePath, []byte(strings.Join(updatedContent, "\n")), 0644)
if err != nil {
return err
}

// then we copy the temporary file to the service file path
// the path is owned by the process user so this is allowed by the OS without sudo
cmd := exec.Command("cp", tempFilePath, serviceFilePath)
err = cmd.Run()
if err != nil {
return err
}

// Step 5: Restart the Raspberry Pi
cmd = exec.Command("sudo", "shutdown", "-r", "now")
//log.Printf("Error in updateIPAndRestart: %v", err)
err = cmd.Run()
if err != nil {
return err
}
}

return nil
}