Skip to content

chore: Upgrade db to use new modusGraph [DRAFT] #899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runtime/actors/wasmagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func (a *wasmAgentActor) saveState(ctx context.Context) error {
Name: a.agentName,
Status: string(a.status),
Data: data,
UpdatedAt: time.Now().UTC().Format(utils.TimeFormat),
UpdatedAt: time.Now(),
}); err != nil {
return fmt.Errorf("error saving agent state to database: %w", err)
}
Expand Down
97 changes: 49 additions & 48 deletions runtime/db/agentstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,23 @@ import (

"github.com/hypermodeinc/modus/runtime/utils"

"github.com/hypermodeinc/modusgraph"
"github.com/jackc/pgx/v5"
)

type AgentState struct {
Gid uint64 `json:"gid,omitempty"`
Id string `json:"id" db:"constraint=unique"`
Name string `json:"name"`
Status string `json:"status"`
Data string `json:"data,omitempty"`
UpdatedAt string `json:"updated"`
Id string `json:"id" dgraph:"index(exact) unique upsert"`
Name string `json:"name" dgraph:"index(exact, term)"`
Status string `json:"status" dgraph:"index(exact)"`
Data string `json:"data,omitempty"`
UpdatedAt time.Time `json:"updated,omitzero" dgraph:"index(hour)"`

UID string `json:"uid"`
DType []string `json:"dgraph.type,omitempty"`
}

func WriteAgentState(ctx context.Context, state AgentState) error {
if useModusDB() {
return writeAgentStateToModusDB(ctx, state)
return writeAgentStateToModusDB(ctx, &state)
} else {
return writeAgentStateToPostgresDB(ctx, state)
}
Expand All @@ -53,81 +54,83 @@ func GetAgentState(ctx context.Context, id string) (*AgentState, error) {
}
}

func QueryActiveAgents(ctx context.Context) ([]AgentState, error) {
func QueryActiveAgents(ctx context.Context) ([]*AgentState, error) {
if useModusDB() {
return queryActiveAgentsFromModusDB(ctx)
} else {
return queryActiveAgentsFromPostgresDB(ctx)
}
}

func writeAgentStateToModusDB(ctx context.Context, state AgentState) error {
// writeAgentStateToModusDB writes an agent state to the modusGraph database.
func writeAgentStateToModusDB(ctx context.Context, state *AgentState) error {
span, ctx := utils.NewSentrySpanForCurrentFunc(ctx)
defer span.Finish()

gid, _, _, err := modusgraph.Upsert(ctx, GlobalModusDbEngine, state)
state.Gid = gid

return err
client, err := GetClient()
if err != nil {
return err
}
return client.Upsert(ctx, state)
}

// updateAgentStatusInModusDB updates the status of an agent in the modusGraph database.
func updateAgentStatusInModusDB(ctx context.Context, id string, status string) error {
span, ctx := utils.NewSentrySpanForCurrentFunc(ctx)
defer span.Finish()

// TODO: this should just be an update in a single operation

state, err := getAgentStateFromModusDB(ctx, id)
if err != nil {
return err
}

state.Status = status
state.UpdatedAt = time.Now().UTC().Format(utils.TimeFormat)
state.UpdatedAt = time.Now().UTC()

return writeAgentStateToModusDB(ctx, *state)
return writeAgentStateToModusDB(ctx, state)
}

// getAgentStateFromModusDB queries the modusGraph database for a specific agent state
// by ID. It returns the agent state if found, or an error if not found.
func getAgentStateFromModusDB(ctx context.Context, id string) (*AgentState, error) {
span, ctx := utils.NewSentrySpanForCurrentFunc(ctx)
defer span.Finish()

_, result, err := modusgraph.Get[AgentState](ctx, GlobalModusDbEngine, modusgraph.ConstrainedField{
Key: "id",
Value: id,
})
client, err := GetClient()
if err != nil {
return nil, fmt.Errorf("failed to query agent state: %w", err)
return nil, err
}

return &result, nil
state := &AgentState{}
err = client.Query(ctx, state).
Filter(`eq(id, $1)`, id).
Node()
if err != nil {
return nil, fmt.Errorf("failed to query agent state: %w", err)
}
return state, nil
}

func queryActiveAgentsFromModusDB(ctx context.Context) ([]AgentState, error) {
// queryActiveAgentsFromModusDB queries the modusGraph database for active agents
// (those with a status other than "terminated"). Agents are ordered by updated time.
// TODO: add pagination support
func queryActiveAgentsFromModusDB(ctx context.Context) ([]*AgentState, error) {
span, ctx := utils.NewSentrySpanForCurrentFunc(ctx)
defer span.Finish()

_, results, err := modusgraph.Query[AgentState](ctx, GlobalModusDbEngine, modusgraph.QueryParams{
Filter: &modusgraph.Filter{
Not: &modusgraph.Filter{
Field: "status",
String: modusgraph.StringPredicate{
Equals: "terminated",
},
},
},
// TODO: Sorting gives a dgraph error. Why?
// Sorting: &modusgraph.Sorting{
// OrderDescField: "updated",
// OrderDescFirst: true,
// },
})

client, err := GetClient()
if err != nil {
return nil, err
}
states := []*AgentState{}
err = client.Query(ctx, &states).
Filter(`not(eq(status, "terminated"))`).
OrderDesc("updated").
Nodes()
if err != nil {
return nil, fmt.Errorf("failed to query agent state: %w", err)
}

return results, nil
return states, nil
}

func writeAgentStateToPostgresDB(ctx context.Context, state AgentState) error {
Expand Down Expand Up @@ -179,7 +182,6 @@ func getAgentStateFromPostgresDB(ctx context.Context, id string) (*AgentState, e
if err := row.Scan(&a.Id, &a.Name, &a.Status, &a.Data, &ts); err != nil {
return fmt.Errorf("failed to get agent state: %w", err)
}
a.UpdatedAt = ts.UTC().Format(utils.TimeFormat)
return nil
})

Expand All @@ -190,14 +192,14 @@ func getAgentStateFromPostgresDB(ctx context.Context, id string) (*AgentState, e
return &a, nil
}

func queryActiveAgentsFromPostgresDB(ctx context.Context) ([]AgentState, error) {
func queryActiveAgentsFromPostgresDB(ctx context.Context) ([]*AgentState, error) {
span, ctx := utils.NewSentrySpanForCurrentFunc(ctx)
defer span.Finish()

const query = "SELECT id, name, status, data, updated FROM agents " +
"WHERE status != 'terminated' ORDER BY updated DESC"

results := make([]AgentState, 0)
results := make([]*AgentState, 0)
err := WithTx(ctx, func(tx pgx.Tx) error {
rows, err := tx.Query(ctx, query)
if err != nil {
Expand All @@ -210,8 +212,7 @@ func queryActiveAgentsFromPostgresDB(ctx context.Context) ([]AgentState, error)
if err := rows.Scan(&a.Id, &a.Name, &a.Status, &a.Data, &ts); err != nil {
return err
}
a.UpdatedAt = ts.UTC().Format(utils.TimeFormat)
results = append(results, a)
results = append(results, &a)
}
if err := rows.Err(); err != nil {
return err
Expand Down
88 changes: 88 additions & 0 deletions runtime/db/agentstate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package db_test

import (
"context"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/hypermodeinc/modus/runtime/app"
"github.com/hypermodeinc/modus/runtime/db"
)

func setupModusDBTest(t *testing.T) (context.Context, func()) {
os.Setenv("MODUS_USE_MODUSDB", "true")
ctx := context.Background()
tmpDir := t.TempDir()
cfg := app.NewAppConfig().WithAppPath(tmpDir)
app.SetConfig(cfg)
db.InitModusDb(ctx, "")
cleanup := func() {
db.CloseModusDb(ctx)
}
return ctx, cleanup
}

func TestAgentStateModusGraph(t *testing.T) {
ctx, cleanup := setupModusDBTest(t)
defer cleanup()

// 1. Write
agent := db.AgentState{
Id: "agent-123",
Name: "TestAgent",
Status: "active",
Data: "{\"foo\":\"bar\"}",
UpdatedAt: time.Now().UTC().Truncate(time.Second),
DType: []string{"AgentState"},
}
err := db.WriteAgentState(ctx, agent)
require.NoError(t, err, "WriteAgentState should succeed")

// 2. Get
got, err := db.GetAgentState(ctx, agent.Id)
require.NoError(t, err, "GetAgentState should succeed")
require.NotNil(t, got, "Returned AgentState should not be nil")
assert.Equal(t, agent.Id, got.Id, "AgentState ID should match")
assert.Equal(t, agent.Name, got.Name, "AgentState Name should match")
assert.Equal(t, agent.Status, got.Status, "AgentState Status should match")
assert.Equal(t, agent.Data, got.Data, "AgentState Data should match")
assert.Equal(t, agent.UpdatedAt, got.UpdatedAt, "AgentState UpdatedAt should match")

// 3. Update
err = db.UpdateAgentStatus(ctx, agent.Id, "terminated")
require.NoError(t, err, "UpdateAgentStatus should succeed")
got, err = db.GetAgentState(ctx, agent.Id)
require.NoError(t, err, "GetAgentState should succeed")
assert.Equal(t, "terminated", got.Status, "AgentState Status should match")
assert.Equal(t, agent.UpdatedAt, got.UpdatedAt, "AgentState UpdatedAt should match")

// 4. Query
agents, err := db.QueryActiveAgents(ctx)
require.NoError(t, err, "QueryActiveAgents should succeed")
found := false
for _, a := range agents {
if a.Id == agent.Id {
found = true
break
}
}
assert.False(t, found, "Agent should not be active after status update to terminated")

// Set status back to active and query again
err = db.UpdateAgentStatus(ctx, agent.Id, "active")
require.NoError(t, err, "UpdateAgentStatus should succeed")
agents, err = db.QueryActiveAgents(ctx)
require.NoError(t, err, "QueryActiveAgents should succeed")
found = false
for _, a := range agents {
if a.Id == agent.Id {
found = true
break
}
}
assert.True(t, found, "Agent should be active after status update to active")
}
68 changes: 68 additions & 0 deletions runtime/db/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package db

import (
"github.com/go-logr/logr"
"github.com/rs/zerolog"
)

type Zerologr struct {
logger zerolog.Logger
level int
}

// NewZerologr wraps a zerolog.Logger into a logr.Logger
func NewZerologr(l zerolog.Logger) logr.Logger {
return logr.New(&Zerologr{logger: l, level: 0})
}

// Implement the logr.LogSink interface

func (z *Zerologr) Init(info logr.RuntimeInfo) {
// No initialization needed
}

func (z *Zerologr) Enabled(level int) bool {
return level <= z.level
}

func (z *Zerologr) Info(level int, msg string, keysAndValues ...interface{}) {
if !z.Enabled(level) {
return
}
evt := z.logger.Info()
for i := 0; i < len(keysAndValues); i += 2 {
key, val := keysAndValues[i], keysAndValues[i+1]
evt.Interface(key.(string), val)
}
evt.Msg(msg)
}

func (z *Zerologr) Error(err error, msg string, keysAndValues ...interface{}) {
evt := z.logger.Error().Err(err)
for i := 0; i < len(keysAndValues); i += 2 {
key, val := keysAndValues[i], keysAndValues[i+1]
evt.Interface(key.(string), val)
}
evt.Msg(msg)
}

func (z *Zerologr) WithValues(keysAndValues ...interface{}) logr.LogSink {
return &Zerologr{
logger: z.logger.With().Fields(keysAndValues).Logger(),
level: z.level,
}
}

func (z *Zerologr) WithName(name string) logr.LogSink {
return &Zerologr{
logger: z.logger.With().Str("logger", name).Logger(),
level: z.level,
}
}

func (z *Zerologr) V(level int) logr.LogSink {
return &Zerologr{
logger: z.logger,
level: level,
}
}
Loading
Loading