diff --git a/docs/pipes/llm/index.md b/docs/pipes/llm/index.md
index 53ad30809..377f7fa60 100644
--- a/docs/pipes/llm/index.md
+++ b/docs/pipes/llm/index.md
@@ -10,5 +10,6 @@ to perform various information extraction tasks.
 | Component                  | Description                                               |
 |----------------------------|-----------------------------------------------------------|
 | `eds.llm_markup_extractor` | Extract structured information using LLMs through markup. |
+| `eds.llm_span_qualifier`   | Predict attributes of spans using LLMs.                   |
diff --git a/docs/pipes/llm/llm-span-qualifier.md b/docs/pipes/llm/llm-span-qualifier.md
new file mode 100644
index 000000000..455757c12
--- /dev/null
+++ b/docs/pipes/llm/llm-span-qualifier.md
@@ -0,0 +1,8 @@
+# LLM Span Qualifier {: #edsnlp.pipes.llm.llm_span_qualifier.factory.create_component }
+
+::: edsnlp.pipes.llm.llm_span_qualifier.factory.create_component
+    options:
+        heading_level: 3
+        show_bases: false
+        show_source: false
+        only_class_level: true
diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py
index 7a1baa332..93a502775 100644
--- a/edsnlp/pipes/__init__.py
+++ b/edsnlp/pipes/__init__.py
@@ -84,4 +84,5 @@ from .trainable.embeddings.text_cnn.factory import create_component as text_cnn
     from .misc.split import Split as split
     from .misc.explode import Explode as explode
-    from .llm.llm_markup_extractor import LlmMarkupExtractor as llm_markup_extractor
+    from .llm.llm_markup_extractor.factory import create_component as llm_markup_extractor
+    from .llm.llm_span_qualifier.factory import create_component as llm_span_qualifier
diff --git a/edsnlp/pipes/llm/llm_span_qualifier/__init__.py b/edsnlp/pipes/llm/llm_span_qualifier/__init__.py
new file mode 100644
index 000000000..7a3024be3
--- /dev/null
+++ b/edsnlp/pipes/llm/llm_span_qualifier/__init__.py
@@ -0,0 +1 @@
+from .llm_span_qualifier import LlmSpanQualifier
diff --git a/edsnlp/pipes/llm/llm_span_qualifier/factory.py b/edsnlp/pipes/llm/llm_span_qualifier/factory.py
new file mode 100644
index 000000000..00a8f5e5e
--- /dev/null
+++ b/edsnlp/pipes/llm/llm_span_qualifier/factory.py
@@ -0,0 +1,8 @@
+from edsnlp import registry
+
+from .llm_span_qualifier import LlmSpanQualifier
+
+create_component = registry.factory.register(
+    "eds.llm_span_qualifier",
+    assigns=["doc.ents", "doc.spans"],
+)(LlmSpanQualifier)
diff --git a/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py b/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py
new file mode 100644
index 000000000..ea94e390e
--- /dev/null
+++ b/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py
@@ -0,0 +1,758 @@
+import json
+import os
+import warnings
+from typing import (
+    Any,
+    Callable,
+    Coroutine,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
+
+from pydantic import BaseModel
+from spacy.tokens import Doc, Span
+from typing_extensions import Annotated, Literal
+
+from edsnlp.core import PipelineProtocol
+from edsnlp.pipes.base import BaseSpanAttributeClassifierComponent
+from edsnlp.utils.bindings import BINDING_GETTERS, BINDING_SETTERS, AttributesArg
+from edsnlp.utils.span_getters import ContextWindow, SpanGetterArg, get_spans
+
+from ..async_worker import AsyncRequestWorker
+
+
+class LlmSpanQualifier(BaseSpanAttributeClassifierComponent):
+    r'''
+    The `eds.llm_span_qualifier` component qualifies spans using a
+    Large Language Model (LLM) that returns structured JSON attributes.
+ + This component takes existing spans, wraps them with `` markers inside a + context window and prompts an LLM to answer with a JSON object that matches the + configured schema. The response is validated and written back on the span + extensions. + + In practice, along with a system prompt that constrains the allowed attributes + and optional few-shot examples provided as previous user / assistant messages, + the component sends snippets such as: + ``` + Biopsies du 12/02/2025 : adénocarcinome. + ``` + + and expects a minimal JSON answer, for example: + ```json + {"biopsy_procedure": "yes"} + ``` + which is then parsed and assigned to the span attributes. + + + !!! warning "Experimental" + + This component is experimental. The API and behavior may change in future + versions. Make sure to pin your `edsnlp` version if you use it in a project. + + !!! note "Dependencies" + + This component requires several dependencies. Run the following command to + install them: + ```bash { data-md-color-scheme="slate" } + pip install openai bm25s Stemmer + ``` + We recommend even to add them to your `pyproject.toml` or `requirements.txt`. + + Examples + -------- + If your data is sensitive, we recommend you to use a self-hosted + model with an OpenAI-compatible API, such as + [vLLM](https://github.com/vllm-project/vllm). + + Start a server with the model of your choice: + + ```bash { data-md-color-scheme="slate" } + python -m vllm.entrypoints.openai.api_server \ + --model mistral-small-24b-instruct-2501 \ + --port 8080 \ + --enable-prefix-caching + ``` + + You can then use the `llm_span_qualifier` component as follows: + + + + === "Yes/no bool classification" + + ```python { .no-check } + from typing import Annotated + from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema + import edsnlp, edsnlp.pipes as eds + + BiopsySchema = Annotated[ + bool, + BeforeValidator(lambda v: str(v).lower() in {"yes", "y", "true"}), + PlainSerializer(lambda v: "yes" if v else "no", when_used="json"), + ] + + PROMPT = """ + You are a span classifier. The user sends text where the target is + marked with .... Answer ONLY with a JSON value: "yes" or + "no" indicating whether the span is a biopsy date. + """.strip() + + nlp = edsnlp.blank("eds") + nlp.add_pipe(eds.sentences()) + nlp.add_pipe(eds.dates(span_setter="ents")) + + # EDS-NLP util to create documents from Markdown or XML markup. + # This has nothing to do with the LLM component itself. The following + # will create docs with entities labelled "date", store them in doc.ents, + # and set their span._.biopsy_procedure attribute. + examples = list(edsnlp.data.from_iterable( + [ + "IRM du 10/02/2025. Biopsies du 12/02/2025 : adénocarcinome.", + "Chirurgie le 24/12/2021. Colectomie. Consultation du 26/12/2021.", + ], + converter="markup", + preset="xml", + ).map(nlp.pipes.sentences)) + + doc_to_xml = edsnlp.data.converters.DocToMarkupConverter(preset="xml") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="mistral-small-24b-instruct-2501", + prompt=PROMPT, + span_getter="ents", + context_getter="sent", + context_formatter=doc_to_xml, + attributes=["biopsy_procedure"], + output_schema=BiopsySchema, + examples=examples, + max_few_shot_examples=2, + max_concurrent_requests=4, + seed=0, + ) + ) + + text = """ + RCP Prostate – 20/02/2025 + Biopsies du 12/02/2025 : adénocarcinome Gleason 4+4=8. + Simulation scanner le 25/02/2025. 
+ """ + doc = nlp(text) + for d in doc.ents: + print(d.text, "→ biopsy_procedure:", d._.biopsy_procedure) + # Out: 20/02/2025 → biopsy_procedure: False + # Out: 12/02/2025 → biopsy_procedure: True + # Out: 25/02/2025 → biopsy_procedure: False + ``` + + === "Multi-attribute classification" + + ```python { .no-check } + from typing import Annotated, Optional + import datetime + from pydantic import BaseModel, Field + import edsnlp, edsnlp.pipes as eds + + # Pydantic schema used to validate the LLM response, serialize the + # few-shot example answers constrain the model output. + class CovidMentionSchema(BaseModel): + negation: bool = Field(..., description="Is the span negated or not") + date: Optional[datetime.date] = Field( + None, description="Date associated with the span, if any" + ) + + PROMPT = """ + You are a span classifier. For every piece of markup-annotated text the + user provides, you predict the attributes of the annotated spans. + You must follow these rules strictly: + - Be consistent, similar queries must lead to similar answers. + - Do not add any comment or explanation, just provide the answer. + Example with a negation and a date: + User: "Le 1er mai 2024, le patient a été testé covid négatif" + Assistant: "{"negation": true, "date": "2024-05-01"}" + For each span, provide a JSON with a "negation" boolean attribute, set to + true if the span is negated, false otherwise. If a date is associated with + the span, provide it as a "date" attribute in ISO format (YYYY-MM-DD). + """.strip() + + nlp = edsnlp.blank("eds") + nlp.add_pipe(eds.sentences()) + nlp.add_pipe(eds.covid()) + + # EDS-NLP util to create documents from Markdown or XML markup. + # This has nothing to do with the LLM component itself. + examples = list(edsnlp.data.from_iterable( + [ + "Covid positif le 1er mai 2024.", + "Pas de covid", + # ... add more examples if you can + ], + converter="markup", preset="xml", + ).map(nlp.pipes.sentences)) + + doc_to_xml = edsnlp.data.converters.DocToMarkupConverter(preset="xml") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="https://api.openai.com/v1", + model="gpt-5-mini", + prompt=PROMPT, + span_getter="ents", + context_getter="words[-10:10]", + context_formatter=doc_to_xml, + output_schema=CovidMentionSchema, + examples=examples, + max_few_shot_examples=1, + max_concurrent_requests=4, + seed=0, + ) + ) + doc = nlp("Pas d'indication de covid le 3 mai 2024.") + (ent,) = doc.ents + print(ent.text, "→ negation:", ent._.negation, "date:", ent._.date) + # Out: covid → negation: True date: 2024-05-03 + ``` + + + + You can also control the prompt more finely by providing a callable instead of a + string. For example, to put few-shot examples in the system message and keep the + span context as the user payload: + + ```python { .no-check } + # Use this for the `prompt` argument instead of PROMPT above + def prompt(context_text, examples): + messages = [] + system_content = ( + "You are a span classifier.\n" + "Answer with JSON using the keys: biopsy_procedure.\n" + "Here are some examples:\n" + ) + for ex_context, ex_json in examples: + system_content += f"- Context: {ex_context}\n" + system_content += f" JSON: {ex_json}\n" + messages.append({"role": "system", "content": system_content}) + messages.append({"role": "user", "content": context_text}) + return messages + ``` + + Parameters + ---------- + nlp : PipelineProtocol + Pipeline object. + name : str + Component name. + api_url : str + Base URL of the OpenAI-compatible API. 
+    model : str
+        Model identifier exposed by the API.
+    prompt : Union[str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]]
+        The prompt is the main way to control the model's behavior.
+        It can be either:
+
+        - A string, which will be used as a system prompt.
+          Few-shot examples (if any) will be provided as user/assistant
+          messages before the actual user query.
+        - A callable that takes two arguments and returns a list of messages in the
+          format expected by the OpenAI chat completions API:
+
+            * `context`: the context text with the target span marked up
+            * `examples`: a list of few-shot examples, each being a tuple of
+              (context, answer)
+    span_getter : Optional[SpanGetterArg]
+        Spans to classify. Defaults to `{"ents": True}`.
+    context_getter : Optional[ContextWindow]
+        Optional context window specification (e.g. `"sent"`, `"words[-10:10]"`).
+        If `None`, the whole document text is used.
+    context_formatter : Optional[Callable[[Doc], str]]
+        Callable used to render the context passed to the LLM. Defaults to
+        `lambda doc: doc.text`.
+    attributes : Optional[AttributesArg]
+        Attributes to predict. If omitted, the keys are inferred from the provided
+        schema.
+    output_schema : Optional[Union[Type[BaseModel], Type[Any], Annotated[Any, Any]]]
+        Pydantic model class used to validate responses and serialize few-shot
+        examples. If the schema is a mapping/object, it will also be used to
+        force the model to output a valid JSON object.
+    examples : Optional[Iterable[Doc]]
+        Few-shot examples used in prompts.
+    max_few_shot_examples : int
+        Maximum number of few-shot examples per request (`-1` means all).
+    use_retriever : Optional[bool]
+        Whether to select few-shot examples with BM25 (defaults to an automatic
+        choice): if few-shot examples are provided and `max_few_shot_examples > 0`,
+        the retriever is enabled by default.
+    seed : Optional[int]
+        Optional seed forwarded to the API.
+    max_concurrent_requests : int
+        Maximum number of concurrent requests sent to the API.
+    api_kwargs : Dict[str, Any]
+        Extra keyword arguments forwarded to `chat.completions.create`.
+    on_error : Literal["raise", "warn"]
+        Error handling strategy. If `"raise"`, exceptions are raised. If `"warn"`,
+        exceptions are logged as warnings and processing continues.
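+
+    The component reads the API key from the `OPENAI_API_KEY` environment
+    variable when it is created. A minimal sketch (the key value below is a
+    placeholder), to be run before adding the pipe:
+
+    ```python { .no-check }
+    import os
+
+    # Placeholder value; only needed if your server enforces authentication.
+    os.environ["OPENAI_API_KEY"] = "sk-..."
+    ```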
+ ''' # noqa: E501 + + def __init__( + self, + nlp: PipelineProtocol, + name: str = "llm_span_qualifier", + *, + api_url: str, + model: str, + prompt: Union[ + str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]] + ], + span_getter: Optional[SpanGetterArg] = None, + context_getter: Optional[ContextWindow] = None, + context_formatter: Optional[Callable[[Doc], str]] = None, + attributes: Optional[AttributesArg] = None, # confit will auto cast to dict + output_schema: Optional[ + Union[ + Type[BaseModel], + Type[Any], + Annotated[Any, Any], + ] + ] = None, + examples: Optional[Iterable[Doc]] = None, + max_few_shot_examples: int = -1, + use_retriever: Optional[bool] = None, + seed: Optional[int] = None, + max_concurrent_requests: int = 1, + api_kwargs: Optional[Dict[str, Any]] = None, + on_error: Literal["raise", "warn"] = "raise", + ): + import openai + + span_getter = span_getter or {"ents": True} + self.lang = nlp.lang + self.api_url = api_url + self.model = model + self.prompt = prompt + self.context_window = ( + ContextWindow.validate(context_getter) + if context_getter is not None + else None + ) + self.context_formatter = context_formatter or (lambda doc: doc.text) + self.seed = seed + self.api_kwargs = api_kwargs or {} + self.max_concurrent_requests = max_concurrent_requests + self.on_error = on_error + + if attributes is None: + if hasattr(output_schema, "model_fields"): + attr_map = {name: True for name in output_schema.model_fields.keys()} + else: + raise ValueError( + "You must provide either `attributes` or a valid pydantic" + "`output_schema` for llm_span_qualifier." + ) + else: + attr_map = attributes + + self.scalar_schema = True + if output_schema is not None: + try: + self.scalar_schema = not issubclass(output_schema, BaseModel) + except TypeError: + self.scalar_schema = True + + if self.scalar_schema and output_schema is not None: + if not attributes or len(attributes) != 1: + raise ValueError( + "When the provided output schema is a scalar type, you must " + "provide exactly one attribute." + ) + + # This class name is produced in the json output_schema so the model + # may see this depending on API implementation ! 
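+            # Wrapping the scalar annotation in a pydantic RootModel lets it be
+            # validated and serialized like a regular BaseModel.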
+ from pydantic import RootModel + + class Output(RootModel): + root: output_schema # type: ignore + + self.output_schema = Output # type: ignore + + else: + self.output_schema = output_schema + + self.response_format = ( + self._build_response_format(self.output_schema) + if self.output_schema + else None + ) + + self.bindings: List[ + Tuple[str, Union[bool, List[str]], str, Callable, Callable] + ] = [] + for attr_name, labels in attr_map.items(): + if ( + attr_name.startswith("_.") + or attr_name.endswith("_") + or attr_name in {"label_", "kb_id_"} + ): + attr_path = attr_name + else: + attr_path = f"_.{attr_name}" + json_key = attr_path[2:] if attr_path.startswith("_.") else attr_path + setter = BINDING_SETTERS[attr_path] + getter = BINDING_GETTERS[attr_path] + self.bindings.append((attr_path, labels, json_key, setter, getter)) + self.attributes = {path: labels for path, labels, *_ in self.bindings} + + self.examples: List[Tuple[str, str]] = [] + for doc in examples or []: + for span in get_spans(doc, span_getter): + context_doc = self._build_context_doc(span) + context_text = self.context_formatter(context_doc) + values: Dict[str, Any] = {} + for _, labels, json_key, _, getter in self.bindings: + if ( + labels is False + or labels is not True + and span.label_ not in (labels or []) + ): + continue + try: + values[json_key] = getter(span) + except Exception: # pragma: no cover + self._handle_err( + f"Failed to get attribute {attr_path!r} for span " + f"{span.text!r} in example doc {span.doc._.note_id}" + ) + if self.scalar_schema: + values = next(iter(values.values())) + if self.output_schema is not None: + try: + answer = self.output_schema.model_validate( + values + ).model_dump_json(exclude_none=True) + except Exception: # pragma: no cover + self._handle_err( + f"[llm_span_qualifier] Failed to validate example " + f"values against the output schema: {values!r}" + ) + continue + else: + answer = json.dumps(values) + self.examples.append((context_text, answer)) + + self.max_few_shot_examples = max_few_shot_examples + self.retriever = None + self.retriever_stemmer = None + if self.max_few_shot_examples > 0 and use_retriever is not False: + self.build_few_shot_retriever_(self.examples) + + api_key = os.getenv("OPENAI_API_KEY", "") + self.client = openai.Client(base_url=self.api_url, api_key=api_key) + self._async_client = openai.AsyncOpenAI(base_url=self.api_url, api_key=api_key) + + super().__init__(nlp=nlp, name=name, span_getter=span_getter) + + def _handle_err(self, msg): + if self.on_error == "raise": + raise RuntimeError(msg) + else: + warnings.warn(msg) + + def set_extensions(self) -> None: + super().set_extensions() + for attr_path, *_ in self.bindings: + if attr_path.startswith("_."): + ext_name = attr_path[2:].split(".")[0] + if not Span.has_extension(ext_name): + Span.set_extension(ext_name, default=None) + + def build_few_shot_retriever_(self, samples: List[Tuple[str, str]]) -> None: + # Same BM25 strategy as llm_markup_extractor + import bm25s + import Stemmer + + lang = {"eds": "french"}.get(self.lang, self.lang) + stemmer = Stemmer.Stemmer(lang) + corpus = bm25s.tokenize( + [text for text, _ in samples], stemmer=stemmer, stopwords=lang + ) + retriever = bm25s.BM25() + retriever.index(corpus) + self.retriever = retriever + self.retriever_stemmer = stemmer + + def build_prompt(self, context_text: str) -> List[Dict[str, str]]: + import bm25s + + few_shot_examples: List[Tuple[str, str]] = [] + if self.retriever is not None: + closest, _ = self.retriever.retrieve( + 
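+                # Rank stored examples by BM25 similarity to the current context
+                # and keep the `max_few_shot_examples` closest ones.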
bm25s.tokenize( + context_text, + stemmer=self.retriever_stemmer, + show_progress=False, + ), + k=self.max_few_shot_examples, + show_progress=False, + ) + for i in closest[0][: self.max_few_shot_examples]: + few_shot_examples.append(self.examples[i]) + few_shot_examples = few_shot_examples[::-1] + else: + few_shot_examples = self.examples[: self.max_few_shot_examples] + + if isinstance(self.prompt, str): + messages = [{"role": "system", "content": self.prompt}] + for ctx, ans in few_shot_examples: + messages.append({"role": "user", "content": ctx}) + messages.append({"role": "assistant", "content": ans}) + messages.append({"role": "user", "content": context_text}) + return messages + return self.prompt(context_text, few_shot_examples) + + def _llm_request_sync(self, messages: List[Dict[str, str]]) -> str: + call_kwargs = dict(self.api_kwargs) + if "response_format" not in call_kwargs and self.response_format is not None: + call_kwargs["response_format"] = self.response_format + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + seed=self.seed, + **call_kwargs, + ) + return response.choices[0].message.content or "" + + def _llm_request_coro( + self, messages: List[Dict[str, str]] + ) -> Coroutine[Any, Any, str]: + async def _coro(): + call_kwargs = dict(self.api_kwargs) + if ( + "response_format" not in call_kwargs + and self.response_format is not None + ): + call_kwargs["response_format"] = self.response_format + response = await self._async_client.chat.completions.create( + model=self.model, + messages=messages, + seed=self.seed, + **call_kwargs, + ) + return response.choices[0].message.content or "" + + return _coro() + + def _parse_response(self, raw: str) -> Optional[Dict[str, Any]]: + text = raw.strip() + data = text + if self.output_schema is not None: + if self.scalar_schema: + start = 0 + end = len(text) + else: + if text.startswith("```"): + text = text.strip("`") + text = text.strip() + if text.startswith("json"): + text = text[4:].strip() + start = text.find("{") + end = text.rfind("}") + 1 + if start == -1 or end <= 0 or end <= start: + return None + try: + data = self.output_schema.model_validate_json(text).model_dump() + except Exception: + try: + # Interpret as a string + data = self.output_schema.model_validate_json( + json.dumps(text) + ).model_dump() + except Exception: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] Failed to validate LLM response" + f" against the output schema: {text!r}", + ) + data = raw + if self.scalar_schema: + data = {next(iter(self.attributes.keys())): data} + return data + + def _build_context_doc(self, span: Span) -> Doc: + ctx_source = ( + span.doc[:] if self.context_window is None else self.context_window(span) + ) + context = ctx_source.as_doc() + offset = ctx_source.start + rel_start = max(0, span.start - offset) + rel_end = max(rel_start, min(len(context), span.end - offset)) + ent = Span(context, rel_start, rel_end, label=span.label_) + context.ents = (ent,) + return context + + def _set_values_on_span(self, span: Span, data: Optional[Dict[str, Any]]) -> None: + for attr_path, labels, json_key, setter, _ in self.bindings: + if ( + labels is False + or labels is not True + and span.label_ not in (labels or []) + ): + continue + # only when scalar mode we use attr_path as key, maybe change that later ? 
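+            # Look the value up under its JSON key first, then fall back to the
+            # full attribute path if the response used that instead.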
+ value = ( + None + if data is None + else data.get(json_key) + if json_key in data + else data.get(attr_path) + ) + try: + setter(span, value) + except Exception as exc: # pragma: no cover + self._handle_err( + f"[llm_span_qualifier] Failed to set attribute {attr_path!r} " + f"for span {span.text!r} in doc {span.doc._.note_id}: {exc!r}" + ) + + def _build_response_format(self, schema: Type[BaseModel]) -> Dict[str, Any]: + raw_schema = schema.model_json_schema() + json_schema = json.loads(json.dumps(raw_schema)) + + if isinstance(json_schema, dict): + json_schema.setdefault("type", "object") + json_schema.setdefault("additionalProperties", False) + + return { + "type": "json_schema", + "json_schema": { + "name": schema.__name__.replace(" ", "_"), + "schema": json_schema, + }, + } + + def _process_docs_async(self, docs: Iterable[Doc]) -> Iterable[Doc]: + worker = AsyncRequestWorker.instance() + pending: Dict[int, Tuple[Dict[str, Any], Span]] = {} + doc_states: List[Dict[str, Any]] = [] + docs_iter = iter(docs) + exhausted = False + next_yield = 0 + + def make_state(doc: Doc) -> Dict[str, Any]: + spans = list(get_spans(doc, self.span_getter)) + return { + "doc": doc, + "spans": spans, + "next_span": 0, + "pending": 0, + } + + def doc_done(state: Dict[str, Any]) -> bool: + return state["next_span"] >= len(state["spans"]) and state["pending"] == 0 + + def schedule() -> None: + if next_yield >= len(doc_states): + return + for state in doc_states[next_yield:]: + while ( + state["next_span"] < len(state["spans"]) + and len(pending) < self.max_concurrent_requests + ): + span = state["spans"][state["next_span"]] + state["next_span"] += 1 + context_doc = self._build_context_doc(span) + context_text = self.context_formatter(context_doc) + messages = self.build_prompt(context_text) + task_id = worker.submit(self._llm_request_coro(messages)) + pending[task_id] = (state, span) + state["pending"] += 1 + if len(pending) >= self.max_concurrent_requests: + return + + while True: + while not exhausted and len(pending) < self.max_concurrent_requests: + try: + doc = next(docs_iter) + except StopIteration: + exhausted = True + break + doc_states.append(make_state(doc)) + schedule() + + while next_yield < len(doc_states) and doc_done( + doc_states[next_yield] + ): # pragma: no cover + yield doc_states[next_yield]["doc"] + next_yield += 1 + + if exhausted and len(pending) == 0 and next_yield == len(doc_states): + break + + if len(pending) == 0: # pragma: no cover + if exhausted and next_yield == len(doc_states): + break + continue + + done_task = worker.wait_for_any(pending.keys()) + result = worker.pop_result(done_task) + state, span = pending.pop(done_task) + state["pending"] -= 1 + raw = None + err = None + if result is not None: + raw, err = result + if err is not None: # pragma: no cover + self._handle_err( + f"[llm_span_qualifier] request failed for span " + f"'{span.text}' in doc {span.doc._.note_id}: {err!r}" + ) + data = None + else: + data = self._parse_response(str(raw)) + if data is None: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] Failed to parse LLM response for span " + f"'{span.text}' in doc {span.doc._.note_id}: {raw!r}" + ) + self._set_values_on_span(span, data) + schedule() + while next_yield < len(doc_states) and doc_done(doc_states[next_yield]): + yield doc_states[next_yield]["doc"] + next_yield += 1 + + def process(self, doc: Doc) -> Doc: + spans = list(get_spans(doc, self.span_getter)) + + for span in spans: + context_doc = self._build_context_doc(span) + context_text = 
self.context_formatter(context_doc) + messages = self.build_prompt(context_text) + data = None + try: + raw = self._llm_request_sync(messages) + except Exception as err: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] request failed for span " + f"'{span.text}' in doc {doc._.note_id}: {err!r}" + ) + else: + data = self._parse_response(raw) + if data is None: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] Failed to parse LLM response for span " + f"'{span.text}' in doc {doc._.note_id}: {raw!r}" + ) + self._set_values_on_span(span, data) + return doc + + def __call__(self, doc: Doc) -> Doc: + return self.process(doc) + + def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]: + if self.max_concurrent_requests <= 1: # pragma: no cover + for doc in docs: + yield self(doc) + return + + yield from self._process_docs_async(docs) diff --git a/edsnlp/utils/asynchronous.py b/edsnlp/utils/asynchronous.py new file mode 100644 index 000000000..58cadb39e --- /dev/null +++ b/edsnlp/utils/asynchronous.py @@ -0,0 +1,37 @@ +import asyncio +from typing import Any, Coroutine, Optional, TypeVar + +T = TypeVar("T") + + +def run_async(coro: Coroutine[Any, Any, T]) -> T: + """ + Runs an asynchronous coroutine and always waits for the result, + whether or not an event loop is already running. + + In a standard Python script (no active event loop), it uses `asyncio.run()`. + In a notebook or environment with a running event loop, it applies a patch + using `nest_asyncio` and runs the coroutine via `loop.run_until_complete`. + + Parameters + ---------- + coro : Coroutine + The coroutine to run. + + Returns + ------- + T + The result returned by the coroutine. + """ + try: + loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): # pragma: no cover + import nest_asyncio + + nest_asyncio.apply() + return asyncio.get_running_loop().run_until_complete(coro) + else: + return asyncio.run(coro) diff --git a/mkdocs.yml b/mkdocs.yml index e648267f2..42bf425ec 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -138,6 +138,7 @@ nav: - LLM based: - pipes/llm/index.md - pipes/llm/llm-markup-extraction.md + - pipes/llm/llm-span-qualifier.md - tokenizers.md - Data Connectors: - data/index.md diff --git a/pyproject.toml b/pyproject.toml index 7afc1ca38..826dfdedb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,7 +204,8 @@ where = ["."] "eds.markup_to_doc" = "edsnlp.data.converters:MarkupToDocConverter" # LLM -"eds.llm_markup_extractor" = "edsnlp.pipes.llm.llm_markup_extractor.factory:create_component" +"eds.llm_markup_extractor" = "edsnlp.pipes.llm.llm_markup_extractor.factory:create_component" +"eds.llm_span_qualifier" = "edsnlp.pipes.llm.llm_span_qualifier.factory:create_component" # Deprecated (links to the same factories as above) "SOFA" = "edsnlp.pipes.ner.scores.sofa.factory:create_component" diff --git a/tests/conftest.py b/tests/conftest.py index 98e272633..bf5bb0825 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -246,3 +246,13 @@ def doc2md(nlp): @pytest.fixture def md2doc(nlp): return MarkupToDocConverter(preset="md") + + +@pytest.fixture +def doc2xml(nlp): + return DocToMarkupConverter(preset="xml") + + +@pytest.fixture +def xml2doc(nlp): + return MarkupToDocConverter(preset="xml") diff --git a/tests/pipelines/llm/test_llm_span_qualifier.py b/tests/pipelines/llm/test_llm_span_qualifier.py new file mode 100644 index 000000000..244a8db03 --- /dev/null +++ 
b/tests/pipelines/llm/test_llm_span_qualifier.py @@ -0,0 +1,529 @@ +import datetime +import json +import re + +import pytest + +# Also serves as a py37 skip since we don't install openai in py37 CI +pytest.importorskip("openai") + + +from typing import Optional + +from mock_llm_service import mock_llm_service +from pydantic import ( + BaseModel, + BeforeValidator, + PlainSerializer, + WithJsonSchema, + field_validator, +) +from typing_extensions import Annotated + +import edsnlp +import edsnlp.pipes as eds + +PROMPT = """\ +Predict JSON attributes for the highlighted entity. +Return keys `negation` (bool) and `date` (YYYY-MM-DD string, optional). +""" + + +class QualifierSchema(BaseModel): + negation: bool + date: Optional[datetime.date] = None + + +def assert_response_schema(response_format): + assert response_format["type"] == "json_schema" + payload = response_format["json_schema"] + assert payload["name"] + schema = payload["schema"] + assert schema.get("type") == "object" + assert schema.get("additionalProperties") is False + props = schema.get("properties", {}) + assert "negation" in props + neg_type = props["negation"].get("type") + assert neg_type in {"boolean", "bool"} + assert "date" in props + + +def test_llm_span_qualifier_sets_attributes(xml2doc, doc2xml): + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=PROMPT, + output_schema=QualifierSchema, + context_formatter=doc2xml, + ) + ) + + doc = xml2doc("Le patient n'a pas de tuberculose.") + + def responder(*, messages, response_format, **_): + assert_response_schema(response_format) + user_prompt = messages[-1]["content"] + assert "tuberculose" in user_prompt + return '{"negation": true, "date": "2024-01-02"}' + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + (ent,) = doc.ents + assert ent._.negation is True + assert ent._.date is not None + + +def test_llm_span_qualifier_async_multiple_spans(xml2doc, doc2xml): + def prompt(context, examples): + assert len(examples) == 0 + messages = [] + system_content = ( + "You are a span classifier.\n" + "Answer with JSON using the keys: biopsy_procedure.\n" + "Here are some examples:\n" + ) + for ex_context, ex_json in examples: + system_content += f"- User: {ex_context}\n" + system_content += f" Assistant: {ex_json}\n" + messages.append({"role": "system", "content": system_content}) + messages.append({"role": "user", "content": context}) + return messages + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=prompt, + output_schema=QualifierSchema, + context_formatter=doc2xml, + max_concurrent_requests=2, + on_error="raise", + context_getter="words[-2:2]", + ) + ) + + doc = xml2doc( + "Le patient a une tuberculose et une pneumonie." + ) + + def responder(*, messages, response_format, **_): + assert_response_schema(response_format) + content = messages[-1]["content"] + if content == "a une tuberculose et une ": + return '{"negation": true}' + assert content == "et une pneumonie." 
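+        # Second entity ("pneumonie"): not negated, with an associated date.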
+ return '{"negation": false, "date": "2024-06-01"}' + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + values = { + ent.text: ( + ent._.negation, + ent._.date, + ) + for ent in doc.ents + } + assert values == { + "tuberculose": (True, None), + "pneumonie": (False, datetime.date(2024, 6, 1)), + } + + +def test_llm_span_qualifier_multi_text(xml2doc, doc2xml): + nlp = edsnlp.blank("eds") + + example_docs = edsnlp.data.from_iterable( + [ + "Le patient est atteint de covid.", + ( + "Diagnostic de pneumonie " + "non retenu le 1er juin 2025." + ), + ], + converter="markup", + preset="xml", + bool_attributes=["negation"], + ) + + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=PROMPT, + output_schema=QualifierSchema, + context_formatter=doc2xml, + max_concurrent_requests=3, + on_error="warn", + attributes={"_.negation": ["DIAG", "PAT"], "date": ["DIAG"]}, + examples=example_docs, + use_retriever=True, + max_few_shot_examples=2, + ) + ) + + docs = [ + xml2doc("Le patient n'a pas la covid le 10/11/12."), + xml2doc("Une pneumonie a été diagnostiquée le 15 juin 2024."), + xml2doc("On suspecte une grippe."), + ] + + def responder(*, messages, response_format, **_): + assert_response_schema(response_format) + content = messages[-1]["content"] + if "covid" in content: + return '{"negation": true, "date": "2012-11-10"}' + if "pneumonie" in content: + return '{"negation": false, "date": "2024-06-15"}' + if "patient" in content: + return '{"negation": false, "date": null}' + assert "grippe" in content + return '```json\n{"negation": false}```' + + with mock_llm_service(responder=responder): + processed = list(nlp.pipe(docs)) + + results = [ + [(ent.text, ent._.negation, ent._.date) for ent in doc.ents] + for doc in processed + ] + assert results == [ + [("patient", False, None), ("covid", True, datetime.date(2012, 11, 10))], + [("pneumonie", False, datetime.date(2024, 6, 15))], + [("grippe", False, None)], + ] + + +def test_llm_span_qualifier_async_error(xml2doc, doc2xml): + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=PROMPT, + output_schema=QualifierSchema, + context_formatter=doc2xml, + max_concurrent_requests=2, + on_error="warn", + ) + ) + + doc = xml2doc( + "Le patient a une tuberculose et une pneumonie." + ) + + def responder(*, response_format, **_): + assert_response_schema(response_format) + raise ValueError("Simulated error") + + with mock_llm_service(responder=responder), pytest.warns( + UserWarning, match="request failed" + ): + doc = nlp(doc) + + for ent in doc.ents: + assert ent._.negation is None + assert ent._.date is None + + +def test_yes_no_schema(xml2doc, doc2xml): + YesBool = Annotated[ + bool, + WithJsonSchema({"type": "string", "enum": ["yes", "no"]}), + BeforeValidator(lambda v: v.lower() in {"yes", "y", "true", "1"}), + PlainSerializer(lambda v: "yes" if v else "no", when_used="json"), + ] + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt="""\ +Determine if the highlighted entity is negated. 
+Answer "yes" or "no".""", + output_schema=YesBool, + attributes="negation", + context_formatter=doc2xml, + on_error="warn", + examples=edsnlp.data.from_iterable( + [ + "Le patient a eu une pneumonie le 1er mai 2024.", + "Le patient n'a pas le covid.", + ], + converter="markup", + preset="xml", + bool_attributes=["negation"], + ), + ) + ) + + doc = xml2doc("Le patient n'a pas de tuberculose.") + + def responder(*, messages, **_): + user_prompt = messages[-1]["content"] + # Ideally, LLM service should support scalar schemas + # but this isn't the case yet. + + # assert response_format["type"] == "json_schema" + # payload = response_format["json_schema"] + # assert payload["name"] + # schema = payload["schema"] + # assert schema.get("type") == "string" + # assert schema.get("enum") == ["yes", "no"] + + assert "tuberculose" in user_prompt + return "yes" + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + (ent,) = doc.ents + assert ent._.negation is True + assert ent._.date is None + + +def test_empty_schema(xml2doc, doc2xml): + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt="""\ +For the highlighted entity, determine the type of the entity as a single phrase. +""", + attributes="type", + context_formatter=doc2xml, + examples=edsnlp.data.from_iterable( + [ + "On prescrit du paracétamol.", + "Le patient n'a pas le covid.", + ], + converter="markup", + preset="xml", + ), + ), + ) + doc = xml2doc("Le patient n'a pas de tuberculose.") + + def responder(*, messages, **_): + user_prompt = messages[-1]["content"] + assert "tuberculose" in user_prompt + return "disease" + + with mock_llm_service(responder=responder): + doc = nlp(doc) + (ent,) = doc.ents + assert ent._.type == "disease" + + +def make_prompt_builder(system_prompt): + def prompt(context, retrieved_examples): + messages = [ + {"role": "system", "content": system_prompt}, + ] + for ctx, answer in retrieved_examples: + messages.append({"role": "user", "content": ctx}) + messages.append({"role": "assistant", "content": answer}) + messages.append({"role": "user", "content": context}) + return messages + + return prompt + + +def make_context_formatter(prefix_prompt, suffix_prompt): + def formatter(doc): + (span,) = doc.ents + context_text = doc.text.strip() + if prefix_prompt is None or suffix_prompt is None: + return context_text + prefix = prefix_prompt.format(span=span.text) + suffix = suffix_prompt.format(span=span.text) + return f"{prefix}{context_text}{suffix}" + + return formatter + + +@pytest.mark.parametrize("label", ["True", None]) +@pytest.mark.parametrize("response_mapping", [{"^True$": "1", "^False$": "0"}, None]) +def test_llm_span_qualifier_custom_formatter_sets_attributes( + xml2doc, label, response_mapping +): + system_prompt = ( + "You are a medical assistant, build to help identify dates in the text." + ) + prefix_prompt = "Is '{span}' a date? The text is as follows:\n<<< " + suffix_prompt = " >>>" + example_doc = xml2doc( + "07/12/2020 : Anapath / biopsies rectales." 
# noqa: E501 + ) + + class ResponseSchema(BaseModel): + test_attr: Optional[str] = None + + @field_validator("test_attr", mode="before") + def apply_mapping(cls, value): + if value is None: + return None + value_str = str(value) + if response_mapping is None: + return value_str + for pattern, mapped in response_mapping.items(): + if re.match(pattern, value_str): + return mapped + return value_str + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="dummy", + name="llm", + prompt=make_prompt_builder(system_prompt), + output_schema=ResponseSchema, + context_getter="words[-5:5]", + context_formatter=make_context_formatter(prefix_prompt, suffix_prompt), + max_concurrent_requests=1, + max_few_shot_examples=1, + examples=[example_doc], + api_kwargs={ + "max_tokens": 10, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + ) + ) + + doc = xml2doc( + "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." # noqa: E501 + ) + qualifier = nlp.get_pipe("llm") + retrieved_examples = qualifier.examples[:1] + assert len(retrieved_examples) == 1 + example_context, example_answer = retrieved_examples[0] + assert ( + example_context + == "Is '07/12/2020' a date? The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales >>>" # noqa: E501 + ) + context_doc = qualifier._build_context_doc(doc.ents[0]) + expected_context = qualifier.context_formatter(context_doc) + expected_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": example_context}, + {"role": "assistant", "content": example_answer}, + {"role": "user", "content": expected_context}, + ] + assert qualifier.build_prompt(expected_context) == expected_messages + + def responder(*, messages, response_format=None, **_): + if response_format is not None: + assert_response_schema(response_format) + payload = {"test_attr": label} + return json.dumps(payload) + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + for span in doc.ents: + assert hasattr(span._, "test_attr") + if label == "True": + if response_mapping is None: + assert span._.test_attr == "True" + else: + assert span._.test_attr == "1" + else: + assert span._.test_attr is None + + assert qualifier.attributes == {"_.test_attr": True} + + +@pytest.mark.parametrize( + ("prefix_prompt", "suffix_prompt"), + [ + ("Is '{span}' a date? The text is as follows:\n<<< ", " >>>"), + (None, None), + ], +) +def test_llm_span_qualifier_custom_formatter_prompt( + xml2doc, prefix_prompt, suffix_prompt +): + system_prompt = ( + "You are a medical assistant, build to help identify dates in the text." + ) + example_doc = xml2doc( + "07/12/2020 : Anapath / biopsies rectales." # noqa: E501 + ) + + class ResponseSchema(BaseModel): + test_attr: Optional[str] = None + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="dummy", + name="llm", + prompt=make_prompt_builder(system_prompt), + output_schema=ResponseSchema, + context_getter="words[-5:5]", + context_formatter=make_context_formatter(prefix_prompt, suffix_prompt), + max_concurrent_requests=1, + max_few_shot_examples=1, + examples=[example_doc], + api_kwargs={ + "max_tokens": 10, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + ) + ) + + doc = xml2doc( + "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." 
# noqa: E501 + ) + + qualifier = nlp.get_pipe("llm") + retrieved_examples = qualifier.examples[:1] + assert len(retrieved_examples) == 1 + example_context, example_answer = retrieved_examples[0] + expected_example_context = ( + "Is '07/12/2020' a date? The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales >>>" # noqa: E501 + if prefix_prompt is not None + else "07/12/2020 : Anapath / biopsies rectales" + ) + assert example_context == expected_example_context + span = doc.ents[0] + context_doc = qualifier._build_context_doc(span) + context_text = qualifier.context_formatter(context_doc) + expected_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": example_context}, + {"role": "assistant", "content": example_answer}, + {"role": "user", "content": context_text}, + ] + + expected_context = ( + "Is '20/02/2025' a date? The text is as follows:\n<<< En RCP du 20/02/2025, patient classé cT3 >>>" # noqa: E501 + if prefix_prompt is not None + else "En RCP du 20/02/2025, patient classé cT3" + ) + assert context_text == expected_context + assert qualifier.build_prompt(context_text) == expected_messages + + def responder(*, messages, response_format=None, **_): + if response_format is not None: + assert_response_schema(response_format) + return json.dumps({"test_attr": "True"}) + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + span = doc.ents[0] + assert span._.test_attr == "True" + assert qualifier.attributes == {"_.test_attr": True}
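+
+
+def test_llm_span_qualifier_sync_parse_warning(doc2xml):
+    # Minimal extra sketch: when the answer cannot be parsed against the schema
+    # and on_error="warn", the synchronous path should warn and leave the span
+    # attributes unset. Assumes eds.covid() stores its match in doc.ents, as in
+    # the component docstring, so no markup fixture is needed here.
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe(eds.covid())
+    nlp.add_pipe(
+        eds.llm_span_qualifier(
+            api_url="http://localhost:8080/v1",
+            model="my-custom-model",
+            prompt=PROMPT,
+            output_schema=QualifierSchema,
+            context_formatter=doc2xml,
+            on_error="warn",
+        )
+    )
+
+    def responder(*, messages, **_):
+        # Deliberately return something that is not a JSON object
+        return "not a JSON object"
+
+    with mock_llm_service(responder=responder), pytest.warns(
+        UserWarning, match="Failed to parse"
+    ):
+        doc = nlp("Le patient n'a pas le covid.")
+
+    (ent,) = doc.ents
+    assert ent._.negation is None
+    assert ent._.date is None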