Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions src/interface/web/app/automations/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ interface AutomationsData {
schedule: string;
crontime: string;
next: string;
chat_model_id?: number | null;
}

import cronstrue from "cronstrue";
Expand Down Expand Up @@ -49,6 +50,7 @@ import { useSearchParams } from "next/navigation";
import Link from "next/link";
import { Popover, PopoverTrigger, PopoverContent } from "@/components/ui/popover";
import {
Brain,
CalendarCheck,
CalendarDot,
CalendarDots,
Expand All @@ -62,7 +64,7 @@ import {
Plus,
Trash,
} from "@phosphor-icons/react";
import { useAuthenticatedData, UserProfile } from "../common/auth";
import { ModelOptions, useAuthenticatedData, useChatModelOptions, UserProfile } from "../common/auth";
import LoginPrompt from "../components/loginPrompt/loginPrompt";
import { useToast } from "@/components/ui/use-toast";
import { ToastAction } from "@/components/ui/toast";
Expand Down Expand Up @@ -271,6 +273,7 @@ interface AutomationsCardProps {
setShowLoginPrompt: (showLoginPrompt: boolean) => void;
authenticatedData: UserProfile | null;
setToastMessage: (toastMessage: string) => void;
chatModelOptions?: ModelOptions[];
}

function AutomationsCard(props: AutomationsCardProps) {
Expand Down Expand Up @@ -335,6 +338,7 @@ function AutomationsCard(props: AutomationsCardProps) {
automation={updatedAutomationData || automation}
ipLocationData={props.locationData}
setToastMessage={props.setToastMessage}
chatModelOptions={props.chatModelOptions}
/>
)}
<ShareLink
Expand Down Expand Up @@ -386,7 +390,7 @@ function AutomationsCard(props: AutomationsCardProps) {
{updatedAutomationData?.scheduling_request || automation.scheduling_request}
</CardContent>
<CardFooter className="flex flex-col items-start md:flex-row md:justify-between md:items-center gap-2">
<div className="flex gap-2">
<div className="flex flex-wrap gap-2">
<div className="flex items-center rounded-lg p-1.5 border-blue-200 border dark:border-blue-500">
<CalendarCheck className="h-4 w-4 mr-2 text-blue-700 dark:text-blue-300" />
<div className="text-s text-blue-700 dark:text-blue-300">
Expand All @@ -399,6 +403,21 @@ function AutomationsCard(props: AutomationsCardProps) {
{intervalString}
</div>
</div>
{(() => {
const modelId =
updatedAutomationData?.chat_model_id ?? automation.chat_model_id;
const modelName = props.chatModelOptions?.find(
(m) => m.id === modelId,
)?.name;
return modelName ? (
<div className="flex items-center rounded-lg p-1.5 border-green-200 border dark:border-green-500">
<Brain className="h-4 w-4 mr-2 text-green-700 dark:text-green-300" />
<div className="text-s text-green-700 dark:text-green-300">
{modelName}
</div>
</div>
) : null;
})()}
</div>
{props.suggestedCard && props.setNewAutomationData && (
<AutomationComponentWrapper
Expand All @@ -413,6 +432,7 @@ function AutomationsCard(props: AutomationsCardProps) {
automation={automation}
ipLocationData={props.locationData}
setToastMessage={props.setToastMessage}
chatModelOptions={props.chatModelOptions}
/>
)}
</CardFooter>
Expand All @@ -428,6 +448,7 @@ interface SharedAutomationCardProps {
authenticatedData: UserProfile | null;
isMobileWidth: boolean;
setToastMessage: (toastMessage: string) => void;
chatModelOptions?: ModelOptions[];
}

function SharedAutomationCard(props: SharedAutomationCardProps) {
Expand Down Expand Up @@ -465,6 +486,7 @@ function SharedAutomationCard(props: SharedAutomationCardProps) {
automation={automation}
ipLocationData={props.locationData}
setToastMessage={props.setToastMessage}
chatModelOptions={props.chatModelOptions}
/>
) : null;
}
Expand All @@ -476,6 +498,7 @@ const EditAutomationSchema = z.object({
dayOfMonth: z.optional(z.string()),
timeRecurrence: z.string({ required_error: "Time Recurrence is required" }),
schedulingRequest: z.string({ required_error: "Query to Run is required" }),
chatModelId: z.optional(z.number().nullable()),
});

interface EditCardProps {
Expand All @@ -488,6 +511,7 @@ interface EditCardProps {
setShowLoginPrompt: (showLoginPrompt: boolean) => void;
authenticatedData: UserProfile | null;
setToastMessage: (toastMessage: string) => void;
chatModelOptions?: ModelOptions[];
}

function EditCard(props: EditCardProps) {
Expand All @@ -504,6 +528,7 @@ function EditCard(props: EditCardProps) {
: "12:00 PM",
dayOfMonth: automation?.crontime ? getDayOfMonthFromCron(automation.crontime) : "1",
schedulingRequest: automation?.scheduling_request,
chatModelId: automation?.chat_model_id ?? undefined,
},
});

Expand Down Expand Up @@ -532,6 +557,8 @@ function EditCard(props: EditCardProps) {
updateQueryUrl += `&country=${encodeURIComponent(props.locationData.country)}`;
if (props.locationData && props.locationData.timezone)
updateQueryUrl += `&timezone=${encodeURIComponent(props.locationData.timezone)}`;
if (values.chatModelId != null)
updateQueryUrl += `&chat_model_id=${encodeURIComponent(values.chatModelId)}`;

let method = props.createNew ? "POST" : "PUT";

Expand All @@ -547,6 +574,7 @@ function EditCard(props: EditCardProps) {
schedule: cronToHumanReadableString(data.crontime),
crontime: data.crontime,
next: data.next,
chat_model_id: data.chat_model_id,
});
})
.catch((error) => {
Expand Down Expand Up @@ -603,6 +631,7 @@ function EditCard(props: EditCardProps) {
create={props.createNew}
isLoggedIn={props.isLoggedIn}
setShowLoginPrompt={props.setShowLoginPrompt}
chatModelOptions={props.chatModelOptions}
/>
);
}
Expand All @@ -615,6 +644,7 @@ interface AutomationModificationFormProps {
setShowLoginPrompt: (showLoginPrompt: boolean) => void;
authenticatedData: UserProfile | null;
locationData: LocationData | null;
chatModelOptions?: ModelOptions[];
}

function AutomationModificationForm(props: AutomationModificationFormProps) {
Expand Down Expand Up @@ -860,6 +890,51 @@ function AutomationModificationForm(props: AutomationModificationFormProps) {
</FormItem>
)}
/>
{props.chatModelOptions && props.chatModelOptions.length > 0 && (
<FormField
control={props.form.control}
name="chatModelId"
render={({ field }) => (
<FormItem className="w-full space-y-1">
<FormLabel>Model</FormLabel>
<FormDescription>
Which AI model should this automation use?
</FormDescription>
<Select
onValueChange={(value) =>
field.onChange(
value === "default" ? undefined : Number(value),
)
}
defaultValue={
field.value != null ? String(field.value) : "default"
}
>
<FormControl>
<SelectTrigger className="w-[280px]">
<div className="flex items-center">
<Brain className="h-4 w-4 mr-2 inline" />
</div>
<SelectValue placeholder="Default model" />
</SelectTrigger>
</FormControl>
<SelectContent>
<SelectItem value="default">Default model</SelectItem>
{props.chatModelOptions.map((model) => (
<SelectItem
key={model.id}
value={String(model.id)}
>
{model.name}
</SelectItem>
))}
</SelectContent>
</Select>
<FormMessage />
</FormItem>
)}
/>
)}
<fieldset disabled={isSaving}>
{props.isLoggedIn ? (
isSaving ? (
Expand Down Expand Up @@ -925,6 +1000,7 @@ interface AutomationComponentWrapperProps {
ipLocationData: LocationData | null | undefined;
automation?: AutomationsData;
setToastMessage: (toastMessage: string) => void;
chatModelOptions?: ModelOptions[];
}

function AutomationComponentWrapper(props: AutomationComponentWrapperProps) {
Expand Down Expand Up @@ -953,6 +1029,7 @@ function AutomationComponentWrapper(props: AutomationComponentWrapperProps) {
setUpdatedAutomationData={props.setNewAutomationData}
locationData={props.ipLocationData}
setToastMessage={props.setToastMessage}
chatModelOptions={props.chatModelOptions}
/>
</DrawerContent>
</Drawer>
Expand Down Expand Up @@ -981,6 +1058,7 @@ function AutomationComponentWrapper(props: AutomationComponentWrapperProps) {
setUpdatedAutomationData={props.setNewAutomationData}
locationData={props.ipLocationData}
setToastMessage={props.setToastMessage}
chatModelOptions={props.chatModelOptions}
/>
</DialogContent>
</Dialog>
Expand All @@ -1001,6 +1079,8 @@ export default function Automations() {
revalidateOnFocus: false,
});

const { models: chatModelOptions } = useChatModelOptions();

const [isCreating, setIsCreating] = useState(false);
const [newAutomationData, setNewAutomationData] = useState<AutomationsData | null>(null);
const [allNewAutomations, setAllNewAutomations] = useState<AutomationsData[]>([]);
Expand Down Expand Up @@ -1126,6 +1206,7 @@ export default function Automations() {
isCreating={isCreating}
ipLocationData={locationData}
setToastMessage={setToastMessage}
chatModelOptions={chatModelOptions}
/>
) : (
<Button
Expand All @@ -1147,6 +1228,7 @@ export default function Automations() {
setShowLoginPrompt={setShowLoginPrompt}
setNewAutomationData={setNewAutomationData}
setToastMessage={setToastMessage}
chatModelOptions={chatModelOptions}
/>
</Suspense>
{isLoading && <InlineLoading message="booting up your automations" />}
Expand All @@ -1163,6 +1245,7 @@ export default function Automations() {
isLoggedIn={authenticatedData ? true : false}
setShowLoginPrompt={setShowLoginPrompt}
setToastMessage={setToastMessage}
chatModelOptions={chatModelOptions}
/>
))}
{authenticatedData &&
Expand All @@ -1176,6 +1259,7 @@ export default function Automations() {
isLoggedIn={authenticatedData ? true : false}
setShowLoginPrompt={setShowLoginPrompt}
setToastMessage={setToastMessage}
chatModelOptions={chatModelOptions}
/>
))}
</div>
Expand All @@ -1193,6 +1277,7 @@ export default function Automations() {
setShowLoginPrompt={setShowLoginPrompt}
suggestedCard={true}
setToastMessage={setToastMessage}
chatModelOptions={chatModelOptions}
/>
))}
</div>
Expand Down
14 changes: 13 additions & 1 deletion src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,11 +1716,22 @@ async def aget_conversation_starters(user: KhojUser, max_results=3):
return random.sample(all_questions, max_results)

@staticmethod
async def aget_valid_chat_model(user: KhojUser, conversation: Conversation, is_subscribed: bool):
async def aget_valid_chat_model(
user: KhojUser, conversation: Conversation, is_subscribed: bool, chat_model_id: int = None
):
"""
For paid users: Prefer any custom agent chat model > user default chat model > server default chat model.
For free users: Prefer conversation specific agent's chat model > user default chat model > server default chat model.
An explicit chat_model_id override (e.g. from automations) takes highest priority.
"""
if chat_model_id:
try:
chat_model = await ChatModel.objects.select_related("ai_model_api").aget(pk=chat_model_id)
if chat_model and chat_model.ai_model_api:
return chat_model
except ChatModel.DoesNotExist:
pass

agent: Agent = conversation.agent if await AgentAdapters.aget_default_agent() != conversation.agent else None
if agent and agent.chat_model and (agent.is_hidden or is_subscribed):
chat_model = await ChatModel.objects.select_related("ai_model_api").aget(
Expand Down Expand Up @@ -2209,6 +2220,7 @@ def get_automation_metadata(user: KhojUser, automation: Job):
"schedule": schedule,
"crontime": crontime,
"next": automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z"),
"chat_model_id": automation_metadata.get("chat_model_id"),
}

@staticmethod
Expand Down
19 changes: 17 additions & 2 deletions src/khoj/routers/api_automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from starlette.authentication import requires

from khoj.database.adapters import AutomationAdapters, ConversationAdapters
from khoj.database.models import KhojUser
from khoj.database.models import ChatModel, KhojUser
from khoj.processor.conversation.utils import clean_json
from khoj.routers.helpers import schedule_automation, schedule_query
from khoj.utils.helpers import is_none_or_empty
Expand Down Expand Up @@ -59,6 +59,7 @@ def post_automation(
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
chat_model_id: Optional[int] = None,
) -> Response:
user: KhojUser = request.user.object

Expand Down Expand Up @@ -95,6 +96,11 @@ def post_automation(
status_code=400,
)

# Validate chat_model_id if provided
if chat_model_id is not None:
if not ChatModel.objects.filter(id=chat_model_id).exists():
return Response(content="Invalid chat model", status_code=400)

# Create new Conversation Session associated with this new task
title = f"Automation: {subject}"
conversation = ConversationAdapters.create_conversation_session(user, request.user.client_app, title=title)
Expand All @@ -104,7 +110,8 @@ def post_automation(
# Use the query to run as the scheduling request if the scheduling request is unset
calling_url = str(request.url.replace(query=f"{request.url.query}"))
automation = schedule_automation(
query_to_run, subject, crontime, timezone, q, user, calling_url, str(conversation.id)
query_to_run, subject, crontime, timezone, q, user, calling_url, str(conversation.id),
chat_model_id=chat_model_id,
)
except Exception as e:
logger.error(f"Error creating automation {q} for {user.email}: {e}", exc_info=True)
Expand Down Expand Up @@ -158,6 +165,7 @@ def edit_job(
region: Optional[str] = None,
country: Optional[str] = None,
timezone: Optional[str] = None,
chat_model_id: Optional[int] = None,
) -> Response:
user: KhojUser = request.user.object

Expand Down Expand Up @@ -199,12 +207,18 @@ def edit_job(
status_code=400,
)

# Validate chat_model_id if provided
if chat_model_id is not None:
if not ChatModel.objects.filter(id=chat_model_id).exists():
return Response(content="Invalid chat model", status_code=400)

# Construct updated automation metadata
automation_metadata: dict[str, str] = json.loads(clean_json(automation.name))
automation_metadata["scheduling_request"] = q
automation_metadata["query_to_run"] = query_to_run
automation_metadata["subject"] = subject.strip()
automation_metadata["crontime"] = crontime
automation_metadata["chat_model_id"] = chat_model_id
conversation_id = automation_metadata.get("conversation_id")

if not conversation_id:
Expand All @@ -226,6 +240,7 @@ def edit_job(
"user": user,
"calling_url": str(request.url),
"conversation_id": conversation_id,
"chat_model_id": chat_model_id,
},
)

Expand Down
1 change: 1 addition & 0 deletions src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,7 @@ def collect_telemetry():
generated_asset_results,
is_subscribed,
tracer,
chat_model_id=body.chat_model_id,
)

full_response = ""
Expand Down
Loading