Skip to content

Commit 39fe023

Browse files
JasonWeillpre-commit-ci[bot]dlqqq
authored
Validate JSON for request schema (#261)
* Corrects capitalization for SageMaker endpoint * WIP: Pass expected format in model for field * Validates JSON using JSON.parse * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP: Validate JSON in magics * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix reference to error * Update packages/jupyter-ai-magics/jupyter_ai_magics/magics.py Co-authored-by: david qiu <[email protected]> * Avoids redundant parameter --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: david qiu <[email protected]>
1 parent ef2f341 commit 39fe023

File tree

4 files changed

+49
-19
lines changed

4 files changed

+49
-19
lines changed

packages/jupyter-ai-magics/jupyter_ai_magics/magics.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,15 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
530530
provider_params["request_schema"] = args.request_schema
531531
provider_params["response_path"] = args.response_path
532532

533+
# Validate that the request schema is well-formed JSON
534+
try:
535+
json.loads(args.request_schema)
536+
except json.JSONDecodeError as e:
537+
raise ValueError(
538+
"request-schema must be valid JSON. "
539+
f"Error at line {e.lineno}, column {e.colno}: {e.msg}"
540+
) from None
541+
533542
provider = Provider(**provider_params)
534543

535544
# generate output from model via provider

packages/jupyter-ai-magics/jupyter_ai_magics/providers.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,19 @@ class AwsAuthStrategy(BaseModel):
5454
]
5555

5656

57-
class TextField(BaseModel):
58-
type: Literal["text"] = "text"
57+
class Field(BaseModel):
5958
key: str
6059
label: str
60+
# "text" accepts any text
61+
format: Literal["json", "jsonpath", "text"]
62+
63+
64+
class TextField(Field):
65+
type: Literal["text"] = "text"
6166

6267

63-
class MultilineTextField(BaseModel):
68+
class MultilineTextField(Field):
6469
type: Literal["text-multiline"] = "text-multiline"
65-
key: str
66-
label: str
6770

6871

6972
Field = Union[TextField, MultilineTextField]
@@ -393,7 +396,7 @@ def transform_output(self, output: bytes) -> str:
393396

394397
class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
395398
id = "sagemaker-endpoint"
396-
name = "Sagemaker Endpoint"
399+
name = "SageMaker endpoint"
397400
models = ["*"]
398401
model_id_key = "endpoint_name"
399402
# This all needs to be on one line of markdown, for use in a table
@@ -408,18 +411,9 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
408411
auth_strategy = AwsAuthStrategy()
409412
registry = True
410413
fields = [
411-
TextField(
412-
key="region_name",
413-
label="Region name",
414-
),
415-
MultilineTextField(
416-
key="request_schema",
417-
label="Request schema",
418-
),
419-
TextField(
420-
key="response_path",
421-
label="Response path",
422-
),
414+
TextField(key="region_name", label="Region name", format="text"),
415+
MultilineTextField(key="request_schema", label="Request schema", format="json"),
416+
TextField(key="response_path", label="Response path", format="jsonpath"),
423417
]
424418

425419
def __init__(self, *args, **kwargs):

packages/jupyter-ai/src/components/settings/model-fields.tsx

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import React from 'react';
1+
import React, { useState } from 'react';
22
import { AiService } from '../../handler';
33
import { TextField } from '@mui/material';
44

@@ -13,9 +13,30 @@ export type ModelFieldProps = {
1313
};
1414

1515
export function ModelField(props: ModelFieldProps): JSX.Element {
16+
const [errorMessage, setErrorMessage] = useState<string | null>(null);
17+
1618
function handleChange(
1719
e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>
1820
) {
21+
// Perform validation based on the field format
22+
switch (props.field.format) {
23+
case 'json':
24+
try {
25+
// JSON.parse does not allow single quotes or trailing commas
26+
JSON.parse(e.target.value);
27+
setErrorMessage(null);
28+
} catch (exc) {
29+
setErrorMessage('You must specify a value in JSON format.');
30+
}
31+
break;
32+
case 'jsonpath':
33+
// TODO: Do JSONPath validation
34+
break;
35+
default:
36+
// No validation performed
37+
break;
38+
}
39+
1940
props.setConfig({
2041
...props.config,
2142
fields: {
@@ -34,6 +55,8 @@ export function ModelField(props: ModelFieldProps): JSX.Element {
3455
label={props.field.label}
3556
value={props.config.fields[props.gmid]?.[props.field.key]}
3657
onChange={handleChange}
58+
error={!!errorMessage}
59+
helperText={errorMessage ?? undefined}
3760
fullWidth
3861
/>
3962
);
@@ -47,6 +70,8 @@ export function ModelField(props: ModelFieldProps): JSX.Element {
4770
onChange={handleChange}
4871
fullWidth
4972
multiline
73+
error={!!errorMessage}
74+
helperText={errorMessage ?? undefined}
5075
minRows={2}
5176
/>
5277
);

packages/jupyter-ai/src/handler.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,14 @@ export namespace AiService {
135135
type: 'text';
136136
key: string;
137137
label: string;
138+
format: string;
138139
};
139140

140141
export type MultilineTextField = {
141142
type: 'text-multiline';
142143
key: string;
143144
label: string;
145+
format: string;
144146
};
145147

146148
export type Field = TextField | MultilineTextField;

0 commit comments

Comments
 (0)