diff --git a/db-connector.go b/db-connector.go index 90a93f11..9d42fd02 100755 --- a/db-connector.go +++ b/db-connector.go @@ -4218,6 +4218,10 @@ func GetOrg(ctx context.Context, id string) (*Org, error) { } func init() { + // Skip import path check in test mode + if os.Getenv("SHUFFLE_TEST_MODE") == "true" { + return + } isValid := checkImportPath() if !isValid { @@ -4899,14 +4903,13 @@ func DeleteKey(ctx context.Context, entity string, value string) error { // Index = Username func SetApikey(ctx context.Context, Userdata User) error { + log.Printf("[AUDIT] Setting API key %s", Userdata.ApiKey) - // Non indexed User data newapiUser := new(Userapi) - newapiUser.ApiKey = Userdata.ApiKey newapiUser.Username = strings.ToLower(Userdata.Username) + newapiUser.ApiKey = Userdata.ApiKey nameKey := "apikey" - // New struct, to not add body, author etc if project.DbType == "opensearch" { data, err := json.Marshal(Userdata) if err != nil { @@ -5105,10 +5108,7 @@ func GetOpenApiDatastore(ctx context.Context, id string) (ParsedOpenApi, error) return *api, nil } -// Index = Username func SetSession(ctx context.Context, user User, value string) error { - //parsedKey := strings.ToLower(user.Username) - // Non indexed User data parsedKey := user.Id user.Session = value @@ -6221,7 +6221,7 @@ func fixUserOrg(ctx context.Context, user *User) *User { if !strings.Contains(err.Error(), "doesn't exist") { log.Printf("[WARNING] Error getting org %s in fixUserOrg: %s", orgId, err) } - + return } @@ -9738,17 +9738,34 @@ func GetSessionNew(ctx context.Context, sessionId string) (User, error) { } } - // Query for the specific API-key in users + sessionsToSearch := []string{sessionId} + encryptedSession, encErr := HandleKeyEncryption([]byte(sessionId), "session", true) + if encErr == nil { + sessionsToSearch = append([]string{string(encryptedSession)}, sessionsToSearch...) + } else { + log.Printf("[WARNING] Failed encrypting session: %s", encErr) + } + nameKey := "Users" var users []User if project.DbType == "opensearch" { + shouldClauses := make([]map[string]interface{}, len(sessionsToSearch)) + for i, sess := range sessionsToSearch { + shouldClauses[i] = map[string]interface{}{ + "match": map[string]interface{}{ + "session": sess, + }, + } + } + var buf bytes.Buffer query := map[string]interface{}{ "from": 0, "size": 1000, "query": map[string]interface{}{ - "match": map[string]interface{}{ - "session": sessionId, + "bool": map[string]interface{}{ + "should": shouldClauses, + "minimum_should_match": 1, }, }, } @@ -9770,7 +9787,7 @@ func GetSessionNew(ctx context.Context, sessionId string) (User, error) { return User{}, nil } - log.Printf("[ERROR] Error getting response from Opensearch (get api keys): %s", err) + log.Printf("[ERROR] Error getting response from Opensearch (get session): %s", err) return User{}, err } @@ -9813,27 +9830,38 @@ func GetSessionNew(ctx context.Context, sessionId string) (User, error) { users = []User{} for _, hit := range wrapped.Hits.Hits { - if hit.Source.Session != sessionId { + // Check if session matches any of our search keys + matched := false + for _, sess := range sessionsToSearch { + if hit.Source.Session == sess { + matched = true + break + } + } + if !matched { continue } - users = append(users, hit.Source) } } else { - //log.Printf("[DEBUG] Searching for session %s", sessionId) - q := datastore.NewQuery(nameKey).Filter("session =", sessionId).Limit(1) - _, err := project.Dbclient.GetAll(ctx, q, &users) - if err != nil && len(users) == 0 { - if !strings.Contains(err.Error(), `cannot load field`) { - log.Printf("[WARNING] Error getting session: %s", err) - return User{}, err + // Datastore: try encrypted first, then plain (no IN filter support) + for _, sess := range sessionsToSearch { + q := datastore.NewQuery(nameKey).Filter("session =", sess).Limit(1) + _, err := project.Dbclient.GetAll(ctx, q, &users) + if err != nil && len(users) == 0 { + if !strings.Contains(err.Error(), `cannot load field`) { + continue + } + } + if len(users) > 0 { + break } } } if len(users) == 0 { - return User{}, errors.New("No users found for this apikey (1)") + return User{}, errors.New("No users found for this session") } if project.CacheDb { @@ -9853,17 +9881,34 @@ func GetSessionNew(ctx context.Context, sessionId string) (User, error) { } func GetApikey(ctx context.Context, apikey string) (User, error) { - // Query for the specific API-key in users + // Build list of keys to search: encrypted (new) + plain (backwards compat) + keysToSearch := []string{apikey} + encryptedKey, encErr := HandleKeyEncryption([]byte(apikey), "apikey", true) + if encErr == nil { + keysToSearch = append([]string{string(encryptedKey)}, keysToSearch...) + } + nameKey := "Users" var users []User if project.DbType == "opensearch" { + // Build OR query for both encrypted and plain apikey + shouldClauses := make([]map[string]interface{}, len(keysToSearch)) + for i, key := range keysToSearch { + shouldClauses[i] = map[string]interface{}{ + "match": map[string]interface{}{ + "apikey": key, + }, + } + } + var buf bytes.Buffer query := map[string]interface{}{ "from": 0, "size": 1000, "query": map[string]interface{}{ - "match": map[string]interface{}{ - "apikey": apikey, + "bool": map[string]interface{}{ + "should": shouldClauses, + "minimum_should_match": 1, }, }, } @@ -9928,20 +9973,32 @@ func GetApikey(ctx context.Context, apikey string) (User, error) { users = []User{} for _, hit := range wrapped.Hits.Hits { - if hit.Source.ApiKey != apikey { + // Check if apikey matches any of our search keys + matched := false + for _, key := range keysToSearch { + if hit.Source.ApiKey == key { + matched = true + break + } + } + if !matched { continue } - users = append(users, hit.Source) } } else { - q := datastore.NewQuery(nameKey).Filter("apikey =", apikey).Limit(1) - _, err := project.Dbclient.GetAll(ctx, q, &users) - if err != nil && len(users) == 0 { - if !strings.Contains(err.Error(), `cannot load field`) { - log.Printf("[WARNING] Error getting apikey: %s", err) - return User{}, err + // Datastore: try encrypted first, then plain (no IN filter support) + for _, key := range keysToSearch { + q := datastore.NewQuery(nameKey).Filter("apikey =", key).Limit(1) + _, err := project.Dbclient.GetAll(ctx, q, &users) + if err != nil && len(users) == 0 { + if !strings.Contains(err.Error(), `cannot load field`) { + continue + } + } + if len(users) > 0 { + break } } } @@ -13965,7 +14022,7 @@ func GetDatastoreKey(ctx context.Context, id string, category string) (*CacheKey category = strings.ReplaceAll(strings.ToLower(category), " ", "_") if len(category) > 0 && category != "default" { - // FIXME: If they key itself is 'test_protected' and category + // FIXME: If they key itself is 'test_protected' and category // is 'protected' this breaks... Keeping it for now. if !strings.HasSuffix(id, fmt.Sprintf("_%s", category)) { id = fmt.Sprintf("%s_%s", id, category) @@ -14225,18 +14282,18 @@ func RunInit(dbclient datastore.Client, storageClient storage.Client, gceProject } else { //log.Printf("\n\n[INFO] Should check for SSO during setup - finding main org\n\n") /* - orgs, err := GetAllOrgs(ctx) - if err == nil { - for _, org := range orgs { - if len(org.ManagerOrgs) == 0 && len(org.SSOConfig.SSOEntrypoint) > 0 { - log.Printf("[INFO] Set initial SSO url for logins to %s", org.SSOConfig.SSOEntrypoint) - SSOUrl = org.SSOConfig.SSOEntrypoint - break + orgs, err := GetAllOrgs(ctx) + if err == nil { + for _, org := range orgs { + if len(org.ManagerOrgs) == 0 && len(org.SSOConfig.SSOEntrypoint) > 0 { + log.Printf("[INFO] Set initial SSO url for logins to %s", org.SSOConfig.SSOEntrypoint) + SSOUrl = org.SSOConfig.SSOEntrypoint + break + } } + } else { + log.Printf("[WARNING] Error loading orgs: %s", err) } - } else { - log.Printf("[WARNING] Error loading orgs: %s", err) - } */ } } else { diff --git a/shared.go b/shared.go index 903f7105..29ee64e6 100755 --- a/shared.go +++ b/shared.go @@ -111,7 +111,6 @@ func HandleCors(resp http.ResponseWriter, request *http.Request) bool { "https://au.shuffler.io", - "https://jp.shuffler.io", "https://br.shuffler.io", "https://in.shuffler.io", @@ -498,35 +497,40 @@ func HandleSet2fa(resp http.ResponseWriter, request *http.Request) { if len(user.Session) != 0 { log.Printf("[INFO] User session exists - resetting session") + sessionValue := user.Session + decryptedSession, err := HandleKeyDecryption([]byte(sessionValue), "session") + if err == nil { + sessionValue = string(decryptedSession) + } + expiration := time.Now().Add(8 * time.Hour) - newCookie := ConstructSessionCookie(user.Session, expiration) + newCookie := ConstructSessionCookie(sessionValue, expiration) http.SetCookie(resp, newCookie) newCookie.Name = "__session" http.SetCookie(resp, newCookie) - //log.Printf("SESSION LENGTH MORE THAN 0 IN LOGIN: %s", user.Session) returnValue.Cookies = append(returnValue.Cookies, SessionCookie{ Key: "session_token", - Value: user.Session, + Value: sessionValue, Expiration: expiration.Unix(), }) returnValue.Cookies = append(returnValue.Cookies, SessionCookie{ Key: "__session", - Value: user.Session, + Value: sessionValue, Expiration: expiration.Unix(), }) - loginData = fmt.Sprintf(`{"success": true, "cookies": [{"key": "session_token", "value": "%s", "expiration": %d}]}`, user.Session, expiration.Unix()) + loginData = fmt.Sprintf(`{"success": true, "cookies": [{"key": "session_token", "value": "%s", "expiration": %d}]}`, sessionValue, expiration.Unix()) newData, err := json.Marshal(returnValue) if err == nil { loginData = string(newData) } - err = SetSession(ctx, user, user.Session) + err = SetSession(ctx, user, sessionValue) if err != nil { log.Printf("[WARNING] Error adding session to database: %s", err) } else { @@ -3341,6 +3345,8 @@ func HandleApiAuthentication(resp http.ResponseWriter, request *http.Request) (U newApikey = newApikey[0:248] } + log.Printf("[DEBUG] Checking cache without encryption") + cache, err := GetCache(ctx, newApikey+org_id) if err == nil { cacheData := []byte(cache.([]uint8)) @@ -3379,6 +3385,17 @@ func HandleApiAuthentication(resp http.ResponseWriter, request *http.Request) (U return User{}, errors.New("Couldn't find the user") } + // Encrypt API key if matched on plain (userdata.ApiKey == incoming plain key) + uuidRegex := regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`) + if userdata.ApiKey == apikeyCheck[1] && uuidRegex.MatchString(apikeyCheck[1]) { + encryptedKey, err := HandleKeyEncryption([]byte(apikeyCheck[1]), "apikey", true) + if err == nil { + userdata.ApiKey = string(encryptedKey) + SetApikey(ctx, userdata) + SetUser(ctx, &userdata, false) + } + } + // Caching both bad and good apikeys :) if len(org_id) > 0 && userdata.ActiveOrg.Id != org_id { found := false @@ -3543,6 +3560,32 @@ func HandleApiAuthentication(resp http.ResponseWriter, request *http.Request) (U user.SessionLogin = true + uuidRegex := regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`) + + // Encrypt API key if it's plain UUID + if uuidRegex.MatchString(user.ApiKey) { + log.Printf("[AUDIT] API key is a UUID: %s", user.ApiKey) + + encryptedKey, err := HandleKeyEncryption([]byte(user.ApiKey), "apikey", true) + if err == nil { + user.ApiKey = string(encryptedKey) + SetApikey(ctx, user) + SetUser(ctx, &user, false) + } + } + + // Encrypt session if matched on plain + if user.Session == sessionToken && uuidRegex.MatchString(sessionToken) { + log.Printf("[AUDIT] Encrypting session") + encryptedSession, err := HandleKeyEncryption([]byte(sessionToken), "session", true) + if err == nil { + user.Session = string(encryptedSession) + SetSession(ctx, user, user.Session) + } else { + log.Printf("[ERROR] Failed to encrypt session: %v", err) + } + } + // Means session exists, but return user, nil } @@ -9229,11 +9272,20 @@ func HandleSettings(resp http.ResponseWriter, request *http.Request) { return } + apikey := userInfo.ApiKey + uuidRegex := regexp.MustCompile(`^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$`) + if !uuidRegex.MatchString(apikey) { + decrypted, err := HandleKeyDecryption([]byte(apikey), "apikey") + if err == nil { + apikey = string(decrypted) + } + } + newObject := SettingsReturn{ Success: true, Username: userInfo.Username, Verified: userInfo.Verified, - Apikey: userInfo.ApiKey, + Apikey: apikey, Image: userInfo.PublicProfile.GithubAvatar, } @@ -18813,7 +18865,7 @@ func create32Hash(key string) ([]byte, error) { return []byte(hex.EncodeToString(hasher.Sum(nil))), nil } -func HandleKeyEncryption(data []byte, passphrase string) ([]byte, error) { +func HandleKeyEncryption(data []byte, passphrase string, deterministic ...bool) ([]byte, error) { key, err := create32Hash(passphrase) if err != nil { log.Printf("[WARNING] Skipped hashing in encrypt: %s", err) @@ -18832,10 +18884,18 @@ func HandleKeyEncryption(data []byte, passphrase string) ([]byte, error) { return []byte{}, err } - nonce := make([]byte, gcm.NonceSize()) - if _, err = io.ReadFull(rand.Reader, nonce); err != nil { - log.Printf("[WARNING] Error reading GCM nonce: %s", err) - return []byte{}, err + var nonce []byte + if len(deterministic) > 0 && deterministic[0] { + // Deterministic mode: derive nonce from data + passphrase for repeatable encryption + nonceSource := md5.Sum(append(data, []byte(passphrase)...)) + nonce = nonceSource[:gcm.NonceSize()] + } else { + // Random nonce (default behavior) + nonce = make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + log.Printf("[WARNING] Error reading GCM nonce: %s", err) + return []byte{}, err + } } ciphertext := gcm.Seal(nonce, nonce, data, nil) diff --git a/shared_test.go b/shared_test.go index 812db779..8ba1d546 100644 --- a/shared_test.go +++ b/shared_test.go @@ -1,7 +1,7 @@ package shuffle import ( - "net/http" + "os" "testing" ) @@ -28,16 +28,187 @@ func TestIsLoop(t *testing.T) { } } -// Simple Test for HandleInternalProxy(client) -// set env SHUFFLE_INTERNAL_HTTP_PROXY to test the function. -func TestHandleInternalProxy(t *testing.T) { - client := &http.Client{} - result := HandleInternalProxy(client) +// tests that deterministic mode produces consistent output +func TestHandleKeyEncryptionDeterministic(t *testing.T) { + os.Setenv("SHUFFLE_ENCRYPTION_MODIFIER", "test-modifier-12345") + defer os.Unsetenv("SHUFFLE_ENCRYPTION_MODIFIER") - if result.Transport.(*http.Transport).Proxy != nil { - proxyURL, _ := result.Transport.(*http.Transport).Proxy(nil) - t.Logf("Proxy URL set: %v", proxyURL) - } else { - t.Log("No proxy set") + testData := []byte("test-api-key-12345") + passphrase := "apikey" + + encrypted1, err := HandleKeyEncryption(testData, passphrase, true) + if err != nil { + t.Fatalf("First encryption failed: %v", err) + } + + encrypted2, err := HandleKeyEncryption(testData, passphrase, true) + if err != nil { + t.Fatalf("Second encryption failed: %v", err) + } + + if string(encrypted1) != string(encrypted2) { + t.Errorf("Deterministic encryption produced different outputs:\n First: %s\n Second: %s", encrypted1, encrypted2) + } +} + +// tests that default mode produces different output each time +func TestHandleKeyEncryptionRandomNonce(t *testing.T) { + os.Setenv("SHUFFLE_ENCRYPTION_MODIFIER", "test-modifier-12345") + defer os.Unsetenv("SHUFFLE_ENCRYPTION_MODIFIER") + + testData := []byte("test-api-key-12345") + passphrase := "apikey" + + encrypted1, err := HandleKeyEncryption(testData, passphrase) + if err != nil { + t.Fatalf("First encryption failed: %v", err) + } + + encrypted2, err := HandleKeyEncryption(testData, passphrase) + if err != nil { + t.Fatalf("Second encryption failed: %v", err) + } + + if string(encrypted1) == string(encrypted2) { + t.Errorf("Random nonce encryption should produce different outputs, but got same") + } +} + +// tests that encrypted data can be decrypted correctly +func TestHandleKeyDecryption(t *testing.T) { + os.Setenv("SHUFFLE_ENCRYPTION_MODIFIER", "test-modifier-12345") + defer os.Unsetenv("SHUFFLE_ENCRYPTION_MODIFIER") + + testCases := []struct { + name string + data string + passphrase string + deterministic bool + }{ + {"API key deterministic", "abc123-api-key-uuid", "apikey", true}, + {"Session deterministic", "xyz789-session-uuid", "session", true}, + {"API key random nonce", "abc123-api-key-uuid", "apikey", false}, + {"Session random nonce", "xyz789-session-uuid", "session", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var encrypted []byte + var err error + + if tc.deterministic { + encrypted, err = HandleKeyEncryption([]byte(tc.data), tc.passphrase, true) + } else { + encrypted, err = HandleKeyEncryption([]byte(tc.data), tc.passphrase) + } + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + decrypted, err := HandleKeyDecryption(encrypted, tc.passphrase) + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + if string(decrypted) != tc.data { + t.Errorf("Decrypted data mismatch:\n Expected: %s\n Got: %s", tc.data, decrypted) + } + }) } } + +// tests that encryption fails without SHUFFLE_ENCRYPTION_MODIFIER +func TestHandleKeyEncryptionNoModifier(t *testing.T) { + os.Unsetenv("SHUFFLE_ENCRYPTION_MODIFIER") + + _, err := HandleKeyEncryption([]byte("test-data"), "passphrase", true) + if err == nil { + t.Error("Expected error when SHUFFLE_ENCRYPTION_MODIFIER is not set, but got none") + } +} + +// simulates storing and retrieving an API key +func TestApiKeyEncryptionRoundTrip(t *testing.T) { + os.Setenv("SHUFFLE_ENCRYPTION_MODIFIER", "test-modifier-12345") + defer os.Unsetenv("SHUFFLE_ENCRYPTION_MODIFIER") + + plainApiKey := "550e8400-e29b-41d4-a716-446655440000" + + encryptedKey, err := HandleKeyEncryption([]byte(plainApiKey), "apikey", true) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + encryptedIncoming, err := HandleKeyEncryption([]byte(plainApiKey), "apikey", true) + if err != nil { + t.Fatalf("Encryption of incoming key failed: %v", err) + } + + if string(encryptedKey) != string(encryptedIncoming) { + t.Errorf("Encrypted keys should match for same input:\n Stored: %s\n Incoming: %s", encryptedKey, encryptedIncoming) + } + + decrypted, err := HandleKeyDecryption(encryptedKey, "apikey") + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + if string(decrypted) != plainApiKey { + t.Errorf("Decrypted API key mismatch:\n Expected: %s\n Got: %s", plainApiKey, decrypted) + } +} + +// simulates storing and retrieving a session +func TestSessionEncryptionRoundTrip(t *testing.T) { + os.Setenv("SHUFFLE_ENCRYPTION_MODIFIER", "test-modifier-12345") + defer os.Unsetenv("SHUFFLE_ENCRYPTION_MODIFIER") + + plainSession := "6ba7b810-9dad-11d1-80b4-00c04fd430c8" + + encryptedSession, err := HandleKeyEncryption([]byte(plainSession), "session", true) + if err != nil { + t.Fatalf("Encryption failed: %v", err) + } + + encryptedIncoming, err := HandleKeyEncryption([]byte(plainSession), "session", true) + if err != nil { + t.Fatalf("Encryption of incoming session failed: %v", err) + } + + if string(encryptedSession) != string(encryptedIncoming) { + t.Errorf("Encrypted sessions should match for same input:\n Stored: %s\n Incoming: %s", encryptedSession, encryptedIncoming) + } + + decrypted, err := HandleKeyDecryption(encryptedSession, "session") + if err != nil { + t.Fatalf("Decryption failed: %v", err) + } + + if string(decrypted) != plainSession { + t.Errorf("Decrypted session mismatch:\n Expected: %s\n Got: %s", plainSession, decrypted) + } +} + +// tests that old plain text keys still work +func TestBackwardsCompatibility(t *testing.T) { + os.Setenv("SHUFFLE_ENCRYPTION_MODIFIER", "test-modifier-12345") + defer os.Unsetenv("SHUFFLE_ENCRYPTION_MODIFIER") + + oldPlainApiKey := "old-plain-api-key-uuid" + incomingKey := "old-plain-api-key-uuid" + + encryptedIncoming, _ := HandleKeyEncryption([]byte(incomingKey), "apikey", true) + + // Encrypted version won't match plain text + if string(encryptedIncoming) == oldPlainApiKey { + t.Error("Encrypted key should not match plain text key") + } + + // Plain text comparison works (backwards compat) + if incomingKey != oldPlainApiKey { + t.Error("Plain text comparison should work for backwards compatibility") + } + + t.Logf("Encrypted format: %s", encryptedIncoming) + t.Logf("Plain format: %s", oldPlainApiKey) +}