Skip to content

Commit 22fae91

Browse files
ewilliams-cloudera, jkwatson, and baasitsharief
authored
Mob/main (#287)
* allow an optional env var for the postgres admin role
* don't always check for CAII enablement
* refactor: optimize streaming chat mutation handlers with useCallback and useMemo
* Revert "refactor: optimize streaming chat mutation handlers with useCallback and useMemo" — this reverts commit 691b10f
* potential solution for queueing chunks
* custom hook for buffering chunks
* fix summary reconciler
* bump mui packages
* remove comment in test
* rename file as no longer private
* fix imports
* remove import

---------

Co-authored-by: jwatson <[email protected]>
Co-authored-by: Baasit Sharief <[email protected]>
1 parent a2e81a9 commit 22fae91

File tree

16 files changed

+1166
-334
lines changed

16 files changed

+1166
-334
lines changed

backend/src/main/java/com/cloudera/cai/rag/configuration/JdbiConfiguration.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import com.cloudera.cai.util.db.RdbConfig;
4444
import com.cloudera.cai.util.db.migration.Migrator;
4545
import java.sql.SQLException;
46+
import java.util.Optional;
4647
import javax.sql.DataSource;
4748
import lombok.extern.slf4j.Slf4j;
4849
import org.jdbi.v3.core.HandleCallback;
@@ -105,7 +106,11 @@ private static DatabaseConfig createDatabaseConfig() {
105106
.build();
106107
if (rdbConfiguration.isPostgres()) {
107108
rdbConfiguration =
108-
rdbConfiguration.toBuilder().rdbUsername("postgres").rdbDatabaseName(null).build();
109+
rdbConfiguration.toBuilder()
110+
.rdbUsername(
111+
Optional.ofNullable(System.getenv("POSTGRES_ADMIN_ROLE")).orElse("postgres"))
112+
.rdbDatabaseName(null)
113+
.build();
109114
}
110115
return DatabaseConfig.builder().RdbConfiguration(rdbConfiguration).build();
111116
}

backend/src/main/java/com/cloudera/cai/rag/files/RagFileService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ private RagDocumentMetadata processFile(
151151
log.info("Saved document with id: {}", id);
152152

153153
ragFileIndexReconciler.submit(ragDocument.withId(id));
154-
ragFileSummaryReconciler.submit(ragDocument.withId(id));
154+
ragFileSummaryReconciler.resync();
155155

156156
return new RagDocumentMetadata(
157157
ragDocument.filename(), documentId, ragDocument.extension(), ragDocument.sizeInBytes());

backend/src/main/java/com/cloudera/cai/rag/files/RagFileSummaryReconciler.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ public void resync() {
9191
SELECT rdsd.* from rag_data_source_document rdsd
9292
JOIN rag_data_source rds ON rdsd.data_source_id = rds.id
9393
WHERE rdsd.summary_creation_timestamp IS NULL
94-
AND rdsd.VECTOR_UPLOAD_TIMESTAMP IS NOT NULL
9594
AND (rdsd.deleted IS NULL OR rdsd.deleted = :deleted)
9695
AND (rdsd.time_created > :yesterday OR rds.time_updated > :yesterday)
9796
AND rds.summarization_model IS NOT NULL AND rds.summarization_model != ''

llm-service/app/ai/indexing/summary_indexer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@
7575
from ..vector_stores.vector_store_factory import VectorStoreFactory
7676
from ...config import settings
7777
from ...services.metadata_apis import data_sources_metadata_api
78-
from ...services.models.providers import CAIIModelProvider, AzureModelProvider, OpenAiModelProvider
78+
from ...services.models.providers import ModelProvider
79+
from ...services.models import ModelSource
7980

8081
logger = logging.getLogger(__name__)
8182

@@ -132,11 +133,16 @@ def __index_configuration(
132133
embed_summaries: bool = True,
133134
) -> Dict[str, Any]:
134135
prompt_helper: Optional[PromptHelper] = None
135-
# if we're using CAII, let's be conservative and use a small context window to account for mistral's small context
136-
if CAIIModelProvider.is_enabled():
136+
model_source: ModelSource = (
137+
ModelProvider.get_provider_class().get_model_source()
138+
)
139+
if model_source == "CAII":
140+
# if we're using CAII, let's be conservative and use a small context window to account for mistral's small context
137141
prompt_helper = PromptHelper(context_window=3000)
138-
if AzureModelProvider.is_enabled() or OpenAiModelProvider.is_enabled():
139-
prompt_helper = PromptHelper(context_window=min(llm.metadata.context_window, 10000))
142+
if model_source == "Azure" or model_source == "OpenAI":
143+
prompt_helper = PromptHelper(
144+
context_window=min(llm.metadata.context_window, 10000)
145+
)
140146
return {
141147
"llm": llm,
142148
"response_synthesizer": get_response_synthesizer(

llm-service/app/services/amp_metadata/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
AzureModelProvider,
5959
BedrockModelProvider,
6060
)
61+
from app.services.utils import timed_lru_cache
6162

6263

6364
class AwsConfig(BaseModel):
@@ -286,7 +287,9 @@ def validate_model_config(environ: dict[str, str]) -> ValidationResult:
286287
return ValidationResult(valid=valid_model_config_exists, message=message)
287288

288289

289-
def validate(environ: dict[str, str]) -> ConfigValidationResults:
290+
@timed_lru_cache(maxsize=1, seconds=6000)
291+
def validate(frozen_env: frozenset[tuple[str, str]]) -> ConfigValidationResults:
292+
environ = {k: v for k, v in frozen_env}
290293
print("Validating environment variables...")
291294
storage_config = validate_storage_config(environ)
292295
model_config = validate_model_config(environ)
@@ -382,7 +385,7 @@ def build_configuration(
382385
opensearch_endpoint=env.get("OPENSEARCH_ENDPOINT"),
383386
opensearch_namespace=env.get("OPENSEARCH_NAMESPACE"),
384387
)
385-
validate_config = validate(env)
388+
validate_config = validate(frozenset(env.items()))
386389

387390
model_provider = (
388391
TypeAdapter(ModelProviderType).validate_python(env.get("MODEL_PROVIDER"))

llm-service/app/services/models/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38-
from ._model_source import ModelSource
3938
from .embedding import Embedding
4039
from .llm import LLM
41-
from .providers._model_provider import ModelProvider
40+
from .providers import ModelProvider
4241
from .reranking import Reranking
42+
from ._model_source import ModelSource
4343

44-
__all__ = ["Embedding", "LLM", "Reranking"]
44+
__all__ = ["Embedding", "LLM", "Reranking", "ModelSource"]
4545

4646

4747
def get_model_source() -> ModelSource:

llm-service/app/services/models/providers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@
3939
from .bedrock import BedrockModelProvider
4040
from .caii import CAIIModelProvider
4141
from .openai import OpenAiModelProvider
42+
from ._model_provider import ModelProvider
4243

4344
__all__ = [
4445
"AzureModelProvider",
4546
"BedrockModelProvider",
4647
"CAIIModelProvider",
4748
"OpenAiModelProvider",
48-
]
49+
"ModelProvider",
50+
]

ui/package.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
"@emotion/react": "^11.14.0",
2424
"@emotion/styled": "^11.14.0",
2525
"@microsoft/fetch-event-source": "^2.0.1",
26-
"@mui/material": "^6.4.3",
27-
"@mui/x-charts": "^7.26.0",
26+
"@mui/material": "^7.3.1",
27+
"@mui/x-charts": "^8.9.2",
2828
"@tanstack/react-query": "^5.79.0",
2929
"@tanstack/react-query-devtools": "^5.79.0",
30-
"@tanstack/react-router": "^1.120.13",
30+
"@tanstack/react-router": "^1.130.12",
3131
"antd": "^5.25.3",
3232
"date-fns": "^4.1.0",
3333
"lodash": "^4.17.21",

0 commit comments

Comments (0)