From 90187803ce553d311c6d8a6d9bdbdd77de8404cd Mon Sep 17 00:00:00 2001
From: Ariel Cohen
Date: Fri, 8 Aug 2025 17:25:37 +0200
Subject: [PATCH 1/2] feat: new qualifier llm pipe

---
 docs/pipes/qualifiers/llm-qualifier.md        |  26 ++
 .../tutorials/qualifying-entities-with-llm.md | 229 ++++
 .../pipes/llm/llm_span_qualifier/__init__.py  |   0
 .../pipes/llm/llm_span_qualifier/factory.py   |   7 +
 .../llm_span_qualifier/llm_span_qualifier.py  | 419 ++++++++++++++++++
 .../pipes/llm/llm_span_qualifier/llm_utils.py | 387 ++++++++++++++++
 edsnlp/utils/asynchronous.py                  |  37 ++
 .../pipelines/llm/test_llm_span_qualifier.py  | 182 ++++++++
 tests/pipelines/llm/test_llm_utils.py         | 213 +++++++++
 9 files changed, 1500 insertions(+)
 create mode 100644 docs/pipes/qualifiers/llm-qualifier.md
 create mode 100644 docs/tutorials/qualifying-entities-with-llm.md
 create mode 100644 edsnlp/pipes/llm/llm_span_qualifier/__init__.py
 create mode 100644 edsnlp/pipes/llm/llm_span_qualifier/factory.py
 create mode 100644 edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py
 create mode 100644 edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py
 create mode 100644 edsnlp/utils/asynchronous.py
 create mode 100644 tests/pipelines/llm/test_llm_span_qualifier.py
 create mode 100644 tests/pipelines/llm/test_llm_utils.py

diff --git a/docs/pipes/qualifiers/llm-qualifier.md b/docs/pipes/qualifiers/llm-qualifier.md
new file mode 100644
index 0000000000..ba85519f4f
--- /dev/null
+++ b/docs/pipes/qualifiers/llm-qualifier.md
@@ -0,0 +1,26 @@
+## LLM Span Classifier {: #edsnlp.pipes.qualifiers.llm.factory.create_component }
+
+::: edsnlp.pipes.qualifiers.llm.factory.create_component
+    options:
+        heading_level: 3
+        show_bases: false
+        show_source: false
+        only_class_level: true
+
+## APIParams {: #edsnlp.pipes.qualifiers.llm.llm_qualifier.APIParams }
+
+::: edsnlp.pipes.qualifiers.llm.llm_qualifier.APIParams
+    options:
+        heading_level: 3
+        show_bases: false
+        show_source: false
+        only_class_level: true
+
+## PromptConfig {: #edsnlp.pipes.qualifiers.llm.llm_qualifier.PromptConfig }
+
+::: edsnlp.pipes.qualifiers.llm.llm_qualifier.PromptConfig
+    options:
+        heading_level: 3
+        show_bases: false
+        show_source: false
+        only_class_level: true
diff --git a/docs/tutorials/qualifying-entities-with-llm.md b/docs/tutorials/qualifying-entities-with-llm.md
new file mode 100644
index 0000000000..7465a42368
--- /dev/null
+++ b/docs/tutorials/qualifying-entities-with-llm.md
@@ -0,0 +1,229 @@
+# Using an LLM as a span qualifier
+In this tutorial, we will learn how to use the `LLMSpanClassifier` pipe to qualify spans.
+First, install the extra dependencies in a Python environment (python>='3.8'):
+```bash
+pip install edsnlp[llm]
+```
+
+## Using a local LLM server
+We assume that an LLM server compatible with the OpenAI API is available.
+For example, using the vLLM library, you can launch an LLM server from the command line as follows:
+```bash
+vllm serve Qwen/Qwen3-8B --port 8000 --enable-prefix-caching --tensor-parallel-size 1 --max-num-seqs=10 --max-num-batched-tokens=35000
+```
+
+## Using an external API
+You can also use the [OpenAI API](https://openai.com/index/openai-api/) or the [Groq API](https://groq.com/).
+
+!!! warning
+
+    As you are probably working with sensitive medical data, please check whether you are allowed to use an external API or whether you need to expose an API within your own infrastructure.
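+
+Before building the pipeline, you can optionally check that your server answers OpenAI-compatible requests. The snippet below is a minimal sketch using the `openai` client with the local vLLM settings assumed in this tutorial (`Qwen/Qwen3-8B` served at `http://localhost:8000/v1` with a placeholder API key); adapt these values to your own deployment or external provider.
+
+```{ .python .no-check }
+from openai import OpenAI
+
+# Optional sanity check: adjust base_url, api_key and model to your deployment
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY_API_KEY")
+completion = client.chat.completions.create(
+    model="Qwen/Qwen3-8B",
+    messages=[{"role": "user", "content": "Reply with the single word: ok"}],
+    max_tokens=5,
+    temperature=0,
+)
+print(completion.choices[0].message.content)
+```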
+ +## Import dependencies +```{ .python .no-check } +from datetime import datetime + +import pandas as pd + +import edsnlp +import edsnlp.pipes as eds +from edsnlp.pipes.qualifiers.llm.llm_qualifier import LLMSpanClassifier +from edsnlp.utils.span_getters import make_span_context_getter +``` +## Define prompt and examples +```{ .python .no-check } +task_prompts = { + 0: { + "normalized_task_name": "biopsy_procedure", + "system_prompt": "You are a medical assistant and you will help answering questions about dates present in clinical notes. Don't answer reasoning. " + + "We are interested in detecting biopsy dates (either procedure, analysis or result). " + + "You should answer in a JSON object following this schema {'biopsy':bool}. " + + "If there is not enough information, answer {'biopsy':'False'}." + + "\n\n#### Examples:\n", + "examples": [ + ( + "07/12/2020", + "07/12/2020 : Anapath / biopsies rectales : Muqueuse rectale normale sous réserve de fragments de petite taille.", + "{'biopsy':'True'}", + ), + ( + "24/12/2021", + "Chirurgie 24/12/2021 : Colectomie gauche + anastomose colo rectale + clearance hépatique gauche (une méta posée sur", + "{'biopsy':'False'}", + ), + ], + "prefix_prompt": "\nDetermine if '{span}' corresponds to a biopsy date. The text is as follows:\n<<< ", + "suffix_prompt": " >>>", + "json_schema": { + "properties": { + "biopsy": {"title": "Biopsy", "type": "boolean"}, + }, + "required": [ + "biopsy", + ], + "title": "DateModel", + "type": "object", + }, + "response_mapping": { + "(?i)(oui)|(yes)|(true)": "1", + "(?i)(non)|(no)|(false)|(don't)|(not)": "0", + }, + }, +} +``` + +## Format these examples for few-shot learning +```{ .python .no-check } +def format_examples(raw_examples, prefix_prompt, suffix_prompt): + examples = [] + + for date, context, answer in raw_examples: + prompt = prefix_prompt.format(span=date) + context + suffix_prompt + examples.append((prompt, answer)) + + return examples +``` + +## Set parameters and prompts +```{ .python .no-check } +# Set prompt +prompt_id = 0 +raw_examples = task_prompts.get(prompt_id).get("examples") +prefix_prompt = task_prompts.get(prompt_id).get("prefix_prompt") +user_prompt = task_prompts.get(prompt_id).get("user_prompt") +system_prompt = task_prompts.get(prompt_id).get("system_prompt") +suffix_prompt = task_prompts.get(prompt_id).get("suffix_prompt") +examples = format_examples(raw_examples, prefix_prompt, suffix_prompt) + +# Define JSON schema +response_format = { + "type": "json_schema", + "json_schema": { + "name": "DateModel", + # "strict": True, + "schema": task_prompts.get(prompt_id)["json_schema"], + }, +} + +# Set parameters +response_mapping = None +max_tokens = 200 +extra_body = { + # "chat_template_kwargs": {"enable_thinking": False}, +} +temperature = 0 +``` + +=== "For local serving" + + ```{ .python .no-check } + ### For local serving + model_name = "Qwen/Qwen3-8B" + api_url = "http://localhost:8000/v1" + api_key = "EMPTY_API_KEY" + ``` + + +=== "Using the Groq API" + !!! warning + ⚠️ This section involves the use of an external API. Please ensure you have the necessary credentials and understand the potential risks associated with external API usage. 
+ + ```{ .python .no-check } + ### Using Groq API + model_name = "openai/gpt-oss-20b" + api_url = "https://api.groq.com/openai/v1" + api_key = "TOKEN" ## your API KEY + ``` + +## Define the pipeline +```{ .python .no-check } +nlp = edsnlp.blank("eds") +nlp.add_pipe("sentencizer") +nlp.add_pipe(eds.dates()) +nlp.add_pipe( + LLMSpanClassifier( + name="llm", + model=model_name, + span_getter=["dates"], + attributes={"_.biopsy_procedure": True}, + context_getter=make_span_context_getter( + context_sents=(3, 3), + context_words=(1, 1), + ), + prompt=dict( + system_prompt=system_prompt, + user_prompt=user_prompt, + prefix_prompt=prefix_prompt, + suffix_prompt=suffix_prompt, + examples=examples, + ), + api_params=dict( + max_tokens=max_tokens, + temperature=temperature, + response_format=response_format, + extra_body=extra_body, + ), + api_url=api_url, + api_key=api_key, + response_mapping=response_mapping, + n_concurrent_tasks=4, + ) +) +``` + +## Apply it on a document + +```{ .python .no-check } +# Let's try with a fake LLM generated text +text = """ +Centre Hospitalier Départemental – RCP Prostate – 20/02/2025 + +M. Bernard P., 69 ans, retraité, consulte après avoir noté une faiblesse du jet urinaire et des levers nocturnes répétés depuis un an. PSA à 15,2 ng/mL (05/02/2025). TR : nodule ferme sur lobe gauche. + +IRM multiparamétrique du 10/02/2025 : lésion PIRADS 5, 2,1 cm, atteinte de la capsule suspectée. +Biopsies du 12/02/2025 : adénocarcinome Gleason 4+4=8, toutes les carottes gauches positives. +Scanner TAP et scintigraphie osseuse du 14/02 : absence de métastases viscérales ou osseuses. + +En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. Décision : radiothérapie externe + hormonothérapie longue (24 mois). Planification de la simulation scanner le 25/02. 
+""" +``` + +```{ .python .no-check } +t0 = datetime.now() +doc = nlp(text) +t1 = datetime.now() +print("Execution time", t1 - t0) + +for span in doc.spans["dates"]: + print(span, span._.biopsy_procedure) +``` + +Lets check the type +```{ .python .no-check } +type(span._.biopsy_procedure) +``` +# Apply on multiple documents +```{ .python .no-check } +texts = [ + text, +] * 2 + +notes = pd.DataFrame({"note_id": range(len(texts)), "note_text": texts}) +docs = edsnlp.data.from_pandas(notes, nlp=nlp, converter="omop") +predicted_docs = docs.map_pipeline(nlp, 2) +``` + +```{ .python .no-check } +t0 = datetime.now() +note_nlp = edsnlp.data.to_pandas( + predicted_docs, + converter="ents", + span_getter="dates", + span_attributes=[ + "biopsy_procedure", + ], +) +t1 = datetime.now() +print("Execution time", t1 - t0) +note_nlp.head() +``` diff --git a/edsnlp/pipes/llm/llm_span_qualifier/__init__.py b/edsnlp/pipes/llm/llm_span_qualifier/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/edsnlp/pipes/llm/llm_span_qualifier/factory.py b/edsnlp/pipes/llm/llm_span_qualifier/factory.py new file mode 100644 index 0000000000..d204a077d8 --- /dev/null +++ b/edsnlp/pipes/llm/llm_span_qualifier/factory.py @@ -0,0 +1,7 @@ +from edsnlp.core import registry + +from .llm_span_qualifier import LLMSpanClassifier + +create_component = registry.factory.register( + "eds.llm_span_qualifier", +)(LLMSpanClassifier) diff --git a/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py b/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py new file mode 100644 index 0000000000..4683ce3b6e --- /dev/null +++ b/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +import logging +import re +from typing import ( + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) + +from spacy.tokens import Doc, Span +from typing_extensions import TypedDict + +from edsnlp.core.pipeline import Pipeline +from edsnlp.pipes.base import BaseSpanAttributeClassifierComponent +from edsnlp.pipes.qualifiers.llm.llm_utils import ( + AsyncLLM, + create_prompt_messages, +) +from edsnlp.utils.asynchronous import run_async +from edsnlp.utils.bindings import ( + BINDING_SETTERS, + Attributes, + AttributesArg, +) +from edsnlp.utils.span_getters import SpanGetterArg, get_spans + +logger = logging.getLogger(__name__) + +LLMSpanClassifierBatchInput = TypedDict( + "LLMSpanClassifierBatchInput", + { + "queries": List[str], + }, +) +""" +queries: List[str] + List of queries to send to the LLM for classification. + Each query corresponds to a span and its context. +""" + +LLMSpanClassifierBatchOutput = TypedDict( + "LLMSpanClassifierBatchOutput", + { + "labels": Optional[Union[List[str], List[List[str]]]], + }, +) +""" +labels: Optional[Union[List[str], List[List[str]]]] + The predicted labels for each query. + If `n > 1`, this will be a list of lists, where each inner list contains the + predictions for a single query. + If `n == 1`, this will be a list of strings, where each string is the prediction + for a single query. + + If the API call fails or no predictions are made, this will be None. + If `n > 1`, it will be a list of None values for each query. + If `n == 1`, it will be a single None value. + +""" + + +class PromptConfig(TypedDict, total=False): + """ + Parameters + ---------- + system_prompt : Optional[str] + A system prompt to use for the LLM. This is a general prompt that will be + prepended to each query. 
This prompt will be passed under the `system` role + in the OpenAI API call. + Example: "You are a medical expert. Classify the following text." + If None, no system prompt will be used. + Note: This is not the same as the `user_prompt` parameter. + user_prompt : Optional[str] + A general prompt to use for all spans. This is a prompt that will be prepended + to each span's specific prompt. This will be passed under the `user` role + in the OpenAI API call. + prefix_prompt : Optional[str] + A prefix prompt to paste after the `user_prompt` and before the selected context + of the span (using the `context_getter`). + It will be formatted specifically for each span, using the `span` variable. + Example: "Is '{span}' a Colonoscopy (procedure) date?" + suffix_prompt: Optional[str] + A suffix prompt to append at the end of the prompt. + examples : Optional[List[Tuple[str, str]]] + A list of examples to use for the prompt. Each example is a tuple of + (input, output). The input is the text to classify and the output is the + expected classification. + If None, no examples will be used. + Example: [("This is a colonoscopy date.", "colonoscopy_date")] + """ + + system_prompt: Optional[str] + user_prompt: Optional[str] + prefix_prompt: Optional[str] + suffix_prompt: Optional[str] + examples: Optional[List[Tuple[str, str]]] + + +class APIParams(TypedDict, total=False): + """ + Parameters + ---------- + extra_body : Optional[Dict[str, Any]] + Additional body parameters to pass to the vLLM API. + This can be used to pass additional parameters to the model, such as + `reasoning_parser` or `enable_reasoning`. + response_format : Optional[Dict[str, Any]] + The response format to use for the vLLM API call. + This can be used to specify how the response should be formatted. + temperature : float + The temperature for the vLLM API call. Default is 0.0 (deterministic). + max_tokens : int + The maximum number of tokens to generate in the response. + Default is 50. + """ + + max_tokens: int + temperature: float + response_format: Optional[Dict[str, Any]] + extra_body: Optional[Dict[str, Any]] + + +class LLMSpanClassifier( + BaseSpanAttributeClassifierComponent, +): + """ + The `LLMSpanClassifier` component is a LLM attribute predictor. + In this context, the span classification task consists in assigning values (boolean, + strings or any object) to attributes/extensions of spans such as: + + - `span._.negation`, + - `span._.date.mode` + - `span._.cui` + + This pipe will use an LLM API to classify previously identified spans using + the context and instructions around each span. + + Check out the LLM classifier tutorial + for examples ! + + Python >= 3.8 is required. + + Parameters + ---------- + nlp : PipelineProtocol + The pipeline object + name : str + Name of the component + prompt : Optional[PromptConfig] + The prompt configuration to use for the LLM. + api_url : str + The base URL of the vLLM OpenAI-compatible server to call. + Default: "http://localhost:8000/v1" + model : str + The name of the model to use for classification. + Default: "Qwen/Qwen3-8B" + span_getter : SpanGetterArg + How to extract the candidate spans and the attributes to predict or train on. + context_getter : Optional[Union[Callable, SpanGetterArg]] + What context to use when computing the span embeddings (defaults to the whole + document). This can be: + + - a `SpanGetterArg` to retrieve contexts from a whole document. 
For example + `{"section": "conclusion"}` to only use the conclusion as context (you + must ensure that all spans produced by the `span_getter` argument do fall + in the conclusion in this case) + - a callable, that gets a span and should return a context for this span. + For instance, `lambda span: span.sent` to use the sentence as context. + attributes : AttributesArg + The attributes to predict or train on. If a dict is given, keys are the + attributes and values are the labels for which the attr is allowed, or True + if the attr is allowed for all labels. + api_params : APIParams + Additional parameters for the vLLM API call. + response_mapping : Optional[Dict[str, Any]] + A mapping from regex patterns to values that will be used to map the + responses from the model to the bindings. If not provided, the raw + responses will be used. The first matching regex will be used to map the + response to the binding. + Example: `{"^yes$": True, "^no$": False}` will map "yes" to True and "no" to + False. + timeout : float + The timeout for the vLLM API call. Default is 15.0 seconds. + n_concurrent_tasks : int + The number of concurrent tasks to run when calling the vLLM API. + Default is 4. + kwargs: Dict[str, Any] + Additional keyword arguments passed to the vLLM API call. + This can include parameters like `n` for the number of responses to generate, + or any other OpenAI API parameters. + + Authors and citation + -------------------- + The `eds.llm_qualifier` component was developed by AP-HP's Data Science team. + """ + + def __init__( + self, + nlp: Optional[Pipeline] = None, + name: str = "span_classifier", + prompt: Optional[PromptConfig] = None, + api_url: str = "http://localhost:8000/v1", + model: str = "Qwen/Qwen3-8B", + *, + attributes: AttributesArg = None, + span_getter: SpanGetterArg = None, + context_getter: Optional[SpanGetterArg] = None, + response_mapping: Optional[Dict[str, Any]] = None, + api_params: APIParams = { + "max_tokens": 50, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + timeout: float = 15.0, + n_concurrent_tasks: int = 4, + **kwargs, + ): + if attributes is None: + raise TypeError( + "The `attributes` parameter is required. Please provide a dict of " + "attributes to predict or train on." 
+ ) + + span_getter = span_getter or {"ents": True} + + self.bindings: List[Tuple[str, List[str], List[Any]]] = [ + (k if k.startswith("_.") else f"_.{k}", v, []) + for k, v in attributes.items() + ] + + # Store API configuration + self.api_url = api_url + self.model = model + self.extra_body = api_params.get("extra_body") + self.temperature = api_params.get("temperature") + self.max_tokens = api_params.get("max_tokens") + self.response_format = api_params.get("response_format") + self.response_mapping = response_mapping + self.kwargs = kwargs.get("kwargs") or {} + self.timeout = timeout + self.n_concurrent_tasks = n_concurrent_tasks + + # Prompt config + prompt = prompt or {} + self.prompt = prompt + self.system_prompt = prompt.get("system_prompt") + self.user_prompt = prompt.get("user_prompt") + self.prefix_prompt = prompt.get("prefix_prompt") + self.suffix_prompt = prompt.get("suffix_prompt") + self.examples = prompt.get("examples") + + super().__init__(nlp, name, span_getter=span_getter) + self.context_getter = context_getter + + if self.response_mapping: + self.get_response_mapping_regex_dict() + + @property + def attributes(self) -> Attributes: + return {qlf: labels for qlf, labels, _ in self.bindings} + + def set_extensions(self): + super().set_extensions() + for group in self.bindings: + qlf = group[0] + if qlf.startswith("_."): + qlf = qlf[2:] + if not Span.has_extension(qlf): + Span.set_extension(qlf, default=None) + + def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]: + spans = list(get_spans(doc, self.span_getter)) + spans_text = [span.text for span in spans] + if self.context_getter is None or not callable(self.context_getter): + contexts = list(get_spans(doc, self.context_getter)) + else: + contexts = [self.context_getter(span) for span in spans] + + contexts_text = [context.text for context in contexts] + + doc_batch_messages = [] + for span_text, context_text in zip(spans_text, contexts_text): + if self.prefix_prompt: + final_user_prompt = ( + self.prefix_prompt.format(span=span_text) + context_text + ) + else: + final_user_prompt = context_text + if self.suffix_prompt: + final_user_prompt += self.suffix_prompt + + messages = create_prompt_messages( + system_prompt=self.system_prompt, + user_prompt=self.user_prompt, + examples=self.examples, + final_user_prompt=final_user_prompt, + ) + doc_batch_messages.append(messages) + + return { + "$spans": spans, + "spans_text": spans_text, + "contexts": contexts, + "contexts_text": contexts_text, + "doc_batch_messages": doc_batch_messages, + } + + def collate(self, batch: Dict[str, Sequence[Any]]) -> LLMSpanClassifierBatchInput: + collated = { + "batch_messages": [ + message for item in batch for message in item["doc_batch_messages"] + ] + } + + return collated + + # noinspection SpellCheckingInspection + def forward( + self, + batch: LLMSpanClassifierBatchInput, + ) -> Dict[str, List[Any]]: + """ + Apply the span classifier module to the document embeddings and given spans to: + - compute the loss + - and/or predict the labels of spans + + Parameters + ---------- + batch: SpanClassifierBatchInput + The input batch + + Returns + ------- + BatchOutput + """ + + # Here call the LLM API + llm = AsyncLLM( + model_name=self.model, + api_url=self.api_url, + extra_body=self.extra_body, + temperature=self.temperature, + max_tokens=self.max_tokens, + response_format=self.response_format, + timeout=self.timeout, + n_concurrent_tasks=self.n_concurrent_tasks, + **self.kwargs, + ) + pred = 
run_async(llm(batch_messages=batch["batch_messages"])) + + return { + "labels": pred, + } + + def get_response_mapping_regex_dict(self) -> Dict[str, str]: + self.response_mapping_regex = { + re.compile(regex): mapping_value + for regex, mapping_value in self.response_mapping.items() + } + return self.response_mapping_regex + + def map_response(self, value: str) -> str: + for ( + compiled_regex, + mapping_value, + ) in self.response_mapping_regex.items(): + if compiled_regex.search(value): + mapped_value = mapping_value + break + else: + mapped_value = None + return mapped_value + + def postprocess( + self, + docs: Sequence[Doc], + results: LLMSpanClassifierBatchOutput, + inputs: List[Dict[str, Any]], + ) -> Sequence[Doc]: + # Preprocessed docs should still be in the cache + spans = [span for sample in inputs for span in sample["$spans"]] + all_labels = results["labels"] + # For each prediction group (exclusive bindings)... + + for qlf, labels, _ in self.bindings: + for value, span in zip(all_labels, spans): + if labels is True or span.label_ in labels: + if value is None: + mapped_value = None + elif self.response_mapping is not None: + # ...assign the mapped value to the span + mapped_value = self.map_response(value) + else: + mapped_value = value + BINDING_SETTERS[qlf](span, mapped_value) + + return docs + + def batch_process(self, docs): + inputs = [self.preprocess(doc) for doc in docs] + collated = self.collate(inputs) + res = self.forward(collated) + docs = self.postprocess(docs, res, inputs) + + return docs + + def enable_cache(self, cache_id=None): + # For compatibility + pass + + def disable_cache(self, cache_id=None): + # For compatibility + pass diff --git a/edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py b/edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py new file mode 100644 index 0000000000..2e834f99f3 --- /dev/null +++ b/edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py @@ -0,0 +1,387 @@ +import asyncio +import json +import logging +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Tuple, +) + +from openai import AsyncOpenAI +from openai.types.chat.chat_completion import ChatCompletion + +logger = logging.getLogger(__name__) + + +class AsyncLLM: + """ + AsyncLLM is an asynchronous interface for interacting with Large Language Models + (LLMs) via an API, supporting concurrent requests and batch processing. + """ + + def __init__( + self, + model_name="Qwen/Qwen3-8B", + api_url: str = "http://localhost:8000/v1", + temperature: float = 0.0, + max_tokens: int = 50, + extra_body: Optional[Dict[str, Any]] = None, + response_format: Optional[Dict[str, Any]] = None, + api_key: str = "EMPTY_API_KEY", + n_completions: int = 1, + timeout: float = 15.0, + n_concurrent_tasks: int = 4, + **kwargs, + ): + """ + Initializes the AsyncLLM class with configuration parameters for interacting + with an OpenAI-compatible API server. + + Parameters + ---------- + model_name : str, optional + Name of the model to use (default: "Qwen/Qwen3-8B"). + api_url : str, optional + Base URL of the API server (default: "http://localhost:8000/v1"). + temperature : float, optional + Sampling temperature for generation (default: 0.0). + max_tokens : int, optional + Maximum number of tokens to generate per completion (default: 50). + extra_body : Optional[Dict[str, Any]], optional + Additional parameters to include in the API request body (default: None). + response_format : Optional[Dict[str, Any]], optional + Format specification for the API response (default: None). 
+ api_key : str, optional + API key for authentication (default: "EMPTY_API_KEY"). + n_completions : int, optional + Number of completions to request per prompt (default: 1). + timeout : float, optional + Timeout for API requests in seconds (default: 15.0). + n_concurrent_tasks : int, optional + Maximum number of concurrent tasks for API requests (default: 4). + **kwargs + Additional keyword arguments for further customization. + """ + + # Set OpenAI's API key and API base to use vLLM's API server. + self.model_name = model_name + self.temperature = temperature + self.max_tokens = max_tokens + self.extra_body = extra_body + self.response_format = response_format + self.n_completions = n_completions + self.timeout = timeout + self.kwargs = kwargs + self.n_concurrent_tasks = n_concurrent_tasks + self.responses = [] + self._lock = None + + self.client = AsyncOpenAI( + api_key=api_key, + base_url=api_url, + default_headers={"Connection": "close"}, + ) + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit - properly close the client.""" + await self.aclose() + + async def aclose(self): + """Properly close the AsyncOpenAI client to prevent resource leaks.""" + if hasattr(self, "client") and self.client is not None: + await self.client.close() + + @property + def lock(self): + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock + + async def async_id_message_generator( + self, batch_messages: List[List[Dict[str, str]]] + ) -> AsyncIterator[Tuple[int, List[Dict[str, str]]]]: + """ + Generator + """ + for i, messages in enumerate(batch_messages): + yield (i, messages) + + def parse_messages( + self, response: ChatCompletion, response_format: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Parse the response from the LLM and return the content. + """ + if (response_format is not None) and (isinstance(response, ChatCompletion)): + prediction = [ + parse_json_response( + choice.message.content, response_format=response_format + ) + for choice in response.choices + ] + if self.n_completions == 1: + prediction = prediction[0] + + return prediction + + else: + return response + + async def call_llm( + self, id: int, messages: List[Dict[str, str]] + ) -> Tuple[int, ChatCompletion]: + """ + Call the LLM with the given messages and return the response. + + Parameters + ---------- + id : int + Unique identifier for the call + messages : List[Dict[str, str]] + List of messages to send to the LLM, where each message is a dictionary + with keys 'role' and 'content'. 
+ + Returns + ------- + Tuple[int, ChatCompletion] + The id of the call and the ChatCompletion object corresponding to the + LLM response + """ + + raw_response = await asyncio.wait_for( + self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=self.max_tokens, + n=self.n_completions, + temperature=self.temperature, + stream=False, + response_format=self.response_format, + extra_body=self.extra_body, + **self.kwargs, + ), + timeout=self.timeout, + ) + + # Parse the response + parsed_response = self.parse_messages(raw_response, self.response_format) + + return id, parsed_response + + def store_responses(self, p_id, abbreviation_list): + """ """ + self.responses.append((p_id, abbreviation_list)) + + async def async_worker( + self, + name: str, + id_messages_tuples: AsyncIterator[Tuple[int, List[List[Dict[str, str]]]]], + ): + while True: + try: + ( + idx, + message, + ) = await anext(id_messages_tuples) # noqa: F821 + idx, response = await self.call_llm(idx, message) + + logger.info(f"Worker {name} has finished process {idx}") + except StopAsyncIteration: + # Everything has been parsed! + logger.info( + f"[{name}] Received StopAsyncIteration, worker will shutdown" + ) + break + except TimeoutError as e: + logger.error(f"[{name}] TimeoutError on chunk {idx}\n{e}") + logger.error(f"Timeout was set to {self.timeout} seconds") + if self.n_completions == 1: + response = "" + else: + response = [""] * self.n_completions + except BaseException as e: + logger.error( + f"[{name}] Exception raised on chunk {idx}\n{e}" + ) # type(e) + if self.n_completions == 1: + response = "" + else: + response = [""] * self.n_completions + async with self.lock: + self.store_responses( + idx, + response, + ) + + def sort_responses(self): + sorted_responses = [] + for i, output in sorted(self.responses, key=lambda x: x[0]): + if isinstance(output, ChatCompletion): + if self.n_completions == 1: + sorted_responses.append(output.choices[0].message.content) + else: + sorted_responses.append( + [choice.message.content for choice in output.choices] + ) + else: + sorted_responses.append(output) + + return sorted_responses + + def clean_storage(self): + del self.responses + self.responses = [] + + async def __call__(self, batch_messages: List[List[Dict[str, str]]]): + """ + Asynchronous coroutine, it should be called using the + `edsnlp.utils.asynchronous.run_async` function. + + Parameters + ---------- + batch_messages : List[List[Dict[str, str]]] + List of message batches to send to the LLM, where each batch is a list + of dictionaries with keys 'role' and 'content'. + """ + try: + # Shared prompt generator + id_messages_tuples = self.async_id_message_generator(batch_messages) + + # n concurrent tasks + tasks = { + asyncio.create_task( + self.async_worker(f"Worker-{i}", id_messages_tuples) + ) + for i in range(self.n_concurrent_tasks) + } + + await asyncio.gather(*tasks) + tasks.clear() + predictions = self.sort_responses() + self.clean_storage() + + return predictions + except Exception: + # Ensure cleanup even if an exception occurs + await self.aclose() + raise + + +def create_prompt_messages( + system_prompt: Optional[str] = None, + user_prompt: Optional[str] = None, + examples: Optional[List[Tuple[str, str]]] = None, + final_user_prompt: Optional[str] = None, +) -> List[Dict[str, str]]: + """ + Create a list of prompt messages formatted for use with a language model (LLM) API. 
+ + system_prompt : Optional[str], default=None + The initial system prompt to set the behavior or context for the LLM. + user_prompt : Optional[str], default=None + An initial user prompt to provide context or instructions to the LLM. + examples : Optional[List[Tuple[str, str]]], default=None + A list of example (prompt, response) pairs to guide the LLM's behavior. + final_user_prompt : Optional[str], default=None + The final user prompt to be appended at the end of the message sequence. + + Returns + ------- + List[Dict[str, str]] + A list of message dictionaries, each containing a 'role' + (e.g., 'system', 'user', 'assistant') + and corresponding 'content', formatted for LLM input. + + """ + + messages = [] + if system_prompt: + messages.append( + { + "role": "system", + "content": system_prompt, + } + ) + if user_prompt: + messages.append( + { + "role": "user", + "content": user_prompt, + } + ) + if examples: + for prompt, response in examples: + messages.append( + { + "role": "user", + "content": prompt, + } + ) + messages.append( + { + "role": "assistant", + "content": response, + } + ) + if final_user_prompt: + messages.append( + { + "role": "user", + "content": final_user_prompt, + } + ) + + return messages + + +def parse_json_response( + response: str, + response_format: Optional[Dict[str, Any]] = None, + errors: str = "ignore", +) -> Dict[str, Any]: + """ + Parses a response string as JSON if a JSON schema format is specified, + otherwise returns the raw response. + + Parameters + ---------- + response : str + The response string to parse. + response_format : Optional[Dict[str, Any]], optional + A dictionary specifying the expected response format. + If it contains {"type": "json_schema"}, the response will be parsed as JSON. + Defaults to None. + errors : str, optional + Determines error handling behavior when JSON decoding fails. + If set to "ignore", returns an empty dictionary on failure. + Otherwise, returns the raw response. Defaults to "ignore". + + Returns + ------- + Dict[str, Any] + The parsed JSON object if parsing is successful and a JSON schema is specified. + If parsing fails and errors is "ignore", returns an empty dictionary. + If parsing fails and errors is not "ignore", returns the raw response string. + If no response format is specified, returns the raw response string. + """ + if response is None: + return {} + + if (response_format is not None) and (response_format.get("type") == "json_schema"): + try: + return json.loads(response.strip()) + except json.JSONDecodeError: + if errors == "ignore": + return {} + else: + return response + else: + # If no response format is specified, return the raw response + return response diff --git a/edsnlp/utils/asynchronous.py b/edsnlp/utils/asynchronous.py new file mode 100644 index 0000000000..58cadb39ec --- /dev/null +++ b/edsnlp/utils/asynchronous.py @@ -0,0 +1,37 @@ +import asyncio +from typing import Any, Coroutine, Optional, TypeVar + +T = TypeVar("T") + + +def run_async(coro: Coroutine[Any, Any, T]) -> T: + """ + Runs an asynchronous coroutine and always waits for the result, + whether or not an event loop is already running. + + In a standard Python script (no active event loop), it uses `asyncio.run()`. + In a notebook or environment with a running event loop, it applies a patch + using `nest_asyncio` and runs the coroutine via `loop.run_until_complete`. + + Parameters + ---------- + coro : Coroutine + The coroutine to run. + + Returns + ------- + T + The result returned by the coroutine. 
+ """ + try: + loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): # pragma: no cover + import nest_asyncio + + nest_asyncio.apply() + return asyncio.get_running_loop().run_until_complete(coro) + else: + return asyncio.run(coro) diff --git a/tests/pipelines/llm/test_llm_span_qualifier.py b/tests/pipelines/llm/test_llm_span_qualifier.py new file mode 100644 index 0000000000..578c851ac9 --- /dev/null +++ b/tests/pipelines/llm/test_llm_span_qualifier.py @@ -0,0 +1,182 @@ +from pytest import mark + +import edsnlp +from edsnlp.pipes.qualifiers.llm.llm_qualifier import LLMSpanClassifier +from edsnlp.utils.examples import parse_example +from edsnlp.utils.span_getters import make_span_context_getter + + +@mark.parametrize("label", ["True", None]) +@mark.parametrize("response_mapping", [{"^True$": "1", "^False$": "0"}, None]) +def test_llm_span_classifier(label, response_mapping): + # Patch AsyncLLM to avoid real API calls + class DummyAsyncLLM: + def __init__(self, *args, **kwargs): + # Initialize the dummy LLM + pass + + async def __call__(self, batch_messages): + # Return a dummy label for each message + return [label for _ in batch_messages] + + import edsnlp.pipes.qualifiers.llm.llm_qualifier as llm_mod + + llm_mod.AsyncLLM = DummyAsyncLLM + + nlp = edsnlp.blank("eds") + example = "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." # noqa: E501 + text, entities = parse_example(example) + doc = nlp(text) + doc.ents = [ + doc.char_span(ent.start_char, ent.end_char, label="date") for ent in entities + ] + + # LLMSpanClassifier + nlp.add_pipe( + LLMSpanClassifier( + nlp=nlp, + name="llm", + model="dummy", + span_getter={"ents": True}, + attributes={"_.test_attr": True}, + context_getter=make_span_context_getter( + context_sents=0, + context_words=(5, 5), + ), + prompt={ + "system_prompt": "You are a medical assistant.", + "user_prompt": "You should help us identify dates in the text.", + "prefix_prompt": "Is '{span}' a date? The text is as follows:\n<<< ", + "suffix_prompt": " >>>", + "examples": [ + ( + "\nIs'07/12/2020' a date. The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales. >>>", # noqa: E501 + "False", + ) + ], + }, + api_url="https://dummy", + api_params={ + "max_tokens": 10, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + response_mapping=response_mapping, + n_concurrent_tasks=1, + ) + ) + doc = nlp(doc) + + # Check that the extension is set and the dummy label is applied + for span in doc.ents: + assert hasattr(span._, "test_attr") + if response_mapping is not None: + if label == "True": + assert span._.test_attr == "1" + elif label is None: + assert span._.test_attr is None + + if response_mapping is None: + if label == "True": + assert span._.test_attr == label + elif label is None: + assert span._.test_attr is None + + assert nlp.get_pipe("llm").attributes == {"_.test_attr": True} + + +@mark.parametrize( + "prefix_prompt,suffix_prompt", + [("Is '{span}' a date? 
The text is as follows:\n<<< ", " >>>"), (None, None)], +) +def test_llm_span_classifier_preprocess(prefix_prompt, suffix_prompt): + # Patch AsyncLLM to avoid real API calls + class DummyAsyncLLM: + def __init__(self, *args, **kwargs): + # Initialize the dummy LLM + pass + + async def __call__(self, batch_messages): + # Return a dummy label for each message + return ["True" for _ in batch_messages] + + import edsnlp.pipes.qualifiers.llm.llm_qualifier as llm_mod + + llm_mod.AsyncLLM = DummyAsyncLLM + + nlp = edsnlp.blank("eds") + example = "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." # noqa: E501 + text, entities = parse_example(example) + doc = nlp(text) + doc.ents = [ + doc.char_span(ent.start_char, ent.end_char, label="date") for ent in entities + ] + + system_prompt = "You are a medical assistant." + user_prompt = "You should help us identify dates in the text." + examples = [ + ( + "\nIs'07/12/2020' a date. The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales. >>>", # noqa: E501 + "False", + ) + ] + + # LLMSpanClassifier + llm = LLMSpanClassifier( + nlp=nlp, + name="llm", + model="dummy", + span_getter={"ents": True}, + attributes={"_.test_attr": True}, + context_getter=make_span_context_getter( + context_sents=0, + context_words=(5, 5), + ), + prompt={ + "system_prompt": system_prompt, + "user_prompt": user_prompt, + "prefix_prompt": prefix_prompt, + "suffix_prompt": suffix_prompt, + "examples": examples, + }, + api_url="https://dummy", + api_params={ + "max_tokens": 10, + "temperature": 0.0, + "response_format": None, + "extra_body": None, + }, + response_mapping=None, + n_concurrent_tasks=1, + ) + + inputs = llm.preprocess(doc) + if (prefix_prompt is not None) and (suffix_prompt is not None): + assert inputs["doc_batch_messages"][0] == [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + { + "role": "user", + "content": examples[0][0], + }, + {"role": "assistant", "content": examples[0][1]}, + { + "role": "user", + "content": "Is '20/02/2025' a date? 
The text is as follows:\n<<< En RCP du 20/02/2025, patient classé cT3 >>>", # noqa: E501 + }, + ] + else: + assert inputs["doc_batch_messages"][0] == [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + { + "role": "user", + "content": examples[0][0], + }, + {"role": "assistant", "content": examples[0][1]}, + { + "role": "user", + "content": "En RCP du 20/02/2025, patient classé cT3", # noqa: E501 + }, + ] diff --git a/tests/pipelines/llm/test_llm_utils.py b/tests/pipelines/llm/test_llm_utils.py new file mode 100644 index 0000000000..b9d936e587 --- /dev/null +++ b/tests/pipelines/llm/test_llm_utils.py @@ -0,0 +1,213 @@ +from typing import Optional + +import httpx +import respx +from openai.types.chat.chat_completion import ChatCompletion +from pytest import mark + +from edsnlp.pipes.qualifiers.llm.llm_utils import ( + AsyncLLM, + create_prompt_messages, + parse_json_response, +) +from edsnlp.utils.asynchronous import run_async + + +@mark.parametrize("n_concurrent_tasks", [1, 2]) +def test_async_llm(n_concurrent_tasks): + api_url = "http://localhost:8000/v1/" + suffix_url = "chat/completions" + llm_api = AsyncLLM(n_concurrent_tasks=n_concurrent_tasks, api_url=api_url) + + with respx.mock: + respx.post(api_url + suffix_url).mock( + side_effect=[ + httpx.Response( + 200, json={"choices": [{"message": {"content": "positive"}}]} + ), + httpx.Response( + 200, json={"choices": [{"message": {"content": "negative"}}]} + ), + ] + ) + + response = run_async( + llm_api( + batch_messages=[ + [ + {"role": "user", "content": "your prompt here"}, + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "your second prompt here"}, + ], + [{"role": "user", "content": "your second prompt here"}], + ] + ) + ) + assert response == ["positive", "negative"] + + +def test_create_prompt_messages(): + messages = create_prompt_messages( + system_prompt="Hello", + user_prompt="Hi", + examples=[("One", "1")], + final_user_prompt="What is your name?", + ) + messages_expected = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "Hi"}, + {"role": "user", "content": "One"}, + {"role": "assistant", "content": "1"}, + {"role": "user", "content": "What is your name?"}, + ] + assert messages == messages_expected + messages2 = create_prompt_messages( + system_prompt="Hello", + user_prompt=None, + examples=[("One", "1")], + final_user_prompt="What is your name?", + ) + messages_expected2 = [ + {"role": "system", "content": "Hello"}, + {"role": "user", "content": "One"}, + {"role": "assistant", "content": "1"}, + {"role": "user", "content": "What is your name?"}, + ] + assert messages2 == messages_expected2 + + +def create_fake_chat_completion( + choices: int = 1, content: Optional[str] = '{"biopsy":false}' +): + fake_response_data = { + "id": "chatcmpl-fake123", + "object": "chat.completion", + "created": 1699999999, + "model": "toto", + "choices": [ + { + "index": i, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + for i in range(choices) + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25}, + } + + # Create the ChatCompletion object + fake_completion = ChatCompletion.model_validate(fake_response_data) + return fake_completion + + +def test_parse_json_response(): + response = create_fake_chat_completion() + response_format = { + "type": "json_schema", + "json_schema": { + "name": "DateModel", + "schema": { + "properties": {"biopsy": {"title": "Biopsy", "type": "boolean"}}, + 
"required": ["biopsy"], + "title": "DateModel", + "type": "object", + }, + }, + } + + llm = AsyncLLM(n_concurrent_tasks=1) + parsed_response = llm.parse_messages(response, response_format) + assert parsed_response == {"biopsy": False} + + response = create_fake_chat_completion(content=None) + parsed_response = llm.parse_messages(response, response_format=response_format) + assert parsed_response == {} + + parsed_response = llm.parse_messages(response, response_format=None) + assert parsed_response == response + + parsed_response = llm.parse_messages(None, response_format=None) + assert parsed_response is None + + +@mark.parametrize("n_completions", [1, 2]) +def test_exception_handling(n_completions): + api_url = "http://localhost:8000/v1/" + suffix_url = "chat/completions" + llm_api = AsyncLLM( + n_concurrent_tasks=1, api_url=api_url, n_completions=n_completions + ) + + with respx.mock: + respx.post(api_url + suffix_url).mock( + side_effect=[ + httpx.Response(404, json={"choices": [{}]}), + ] + ) + + response = run_async( + llm_api( + batch_messages=[ + [{"role": "user", "content": "your prompt here"}], + ] + ) + ) + if n_completions == 1: + assert response == [""] + else: + assert response == [[""] * n_completions] + + +@mark.parametrize("errors", ["ignore", "raw"]) +def test_json_decode_error(errors): + raw_response = '{"biopsy";false}' + response_format = { + "type": "json_schema", + "json_schema": { + "name": "DateModel", + "schema": { + "properties": {"biopsy": {"title": "Biopsy", "type": "boolean"}}, + "required": ["biopsy"], + "title": "DateModel", + "type": "object", + }, + }, + } + + response = parse_json_response(raw_response, response_format, errors=errors) + if errors == "ignore": + assert response == {} + else: + assert response == raw_response + + +def test_decode_no_format(): + raw_response = '{"biopsy":false}' + + response = parse_json_response(raw_response, response_format=None) + + assert response == raw_response + + +def test_multiple_completions(n_completions=2): + api_url = "http://localhost:8000/v1/" + suffix_url = "chat/completions" + llm_api = AsyncLLM( + n_concurrent_tasks=1, api_url=api_url, n_completions=n_completions + ) + completion = create_fake_chat_completion(n_completions, content="false") + with respx.mock: + respx.post(api_url + suffix_url).mock( + side_effect=[ + httpx.Response(200, json=completion.model_dump()), + ] + ) + + response = run_async( + llm_api( + batch_messages=[ + [{"role": "user", "content": "your prompt here"}], + ] + ) + ) + assert response == [["false", "false"]] From 5ba13467e87437bccaa96b472583bf5e68c855fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Fri, 3 Oct 2025 02:31:12 +0200 Subject: [PATCH 2/2] feat: improve llm qualifier concurrency, rationalize args, standardize schema w/ pydantic --- docs/pipes/llm/index.md | 1 + docs/pipes/llm/llm-span-qualifier.md | 8 + docs/pipes/qualifiers/llm-qualifier.md | 26 - .../tutorials/qualifying-entities-with-llm.md | 229 ---- edsnlp/pipes/__init__.py | 3 +- .../pipes/llm/llm_span_qualifier/__init__.py | 1 + .../pipes/llm/llm_span_qualifier/factory.py | 7 +- .../llm_span_qualifier/llm_span_qualifier.py | 1067 +++++++++++------ .../pipes/llm/llm_span_qualifier/llm_utils.py | 387 ------ mkdocs.yml | 1 + pyproject.toml | 3 +- tests/conftest.py | 10 + .../pipelines/llm/test_llm_span_qualifier.py | 639 +++++++--- tests/pipelines/llm/test_llm_utils.py | 213 ---- 14 files changed, 1225 insertions(+), 1370 deletions(-) create mode 100644 docs/pipes/llm/llm-span-qualifier.md 
delete mode 100644 docs/pipes/qualifiers/llm-qualifier.md delete mode 100644 docs/tutorials/qualifying-entities-with-llm.md delete mode 100644 edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py delete mode 100644 tests/pipelines/llm/test_llm_utils.py diff --git a/docs/pipes/llm/index.md b/docs/pipes/llm/index.md index 53ad30809e..377f7fa60b 100644 --- a/docs/pipes/llm/index.md +++ b/docs/pipes/llm/index.md @@ -10,5 +10,6 @@ to perform various information extraction tasks. | Component | Description | |----------------------------|-----------------------------------------------------------| | `eds.llm_markup_extractor` | Extract structured information using LLMs through markup. | +| `eds.llm_span_qualifier` | Predict attributes of spans using LLMs. | diff --git a/docs/pipes/llm/llm-span-qualifier.md b/docs/pipes/llm/llm-span-qualifier.md new file mode 100644 index 0000000000..455757c129 --- /dev/null +++ b/docs/pipes/llm/llm-span-qualifier.md @@ -0,0 +1,8 @@ +# LLM Span Qualifier {: #edsnlp.pipes.llm.llm_span_qualifier.factory.create_component } + +::: edsnlp.pipes.llm.llm_span_qualifier.factory.create_component + options: + heading_level: 3 + show_bases: false + show_source: false + only_class_level: true diff --git a/docs/pipes/qualifiers/llm-qualifier.md b/docs/pipes/qualifiers/llm-qualifier.md deleted file mode 100644 index ba85519f4f..0000000000 --- a/docs/pipes/qualifiers/llm-qualifier.md +++ /dev/null @@ -1,26 +0,0 @@ -## LLM Span Classifier {: #edsnlp.pipes.qualifiers.llm.factory.create_component } - -::: edsnlp.pipes.qualifiers.llm.factory.create_component - options: - heading_level: 3 - show_bases: false - show_source: false - only_class_level: true - -## APIParams {: #edsnlp.pipes.qualifiers.llm.llm_qualifier.APIParams } - -::: edsnlp.pipes.qualifiers.llm.llm_qualifier.APIParams - options: - heading_level: 3 - show_bases: false - show_source: false - only_class_level: true - -## PromptConfig {: #edsnlp.pipes.qualifiers.llm.llm_qualifier.PromptConfig } - -::: edsnlp.pipes.qualifiers.llm.llm_qualifier.PromptConfig - options: - heading_level: 3 - show_bases: false - show_source: false - only_class_level: true diff --git a/docs/tutorials/qualifying-entities-with-llm.md b/docs/tutorials/qualifying-entities-with-llm.md deleted file mode 100644 index 7465a42368..0000000000 --- a/docs/tutorials/qualifying-entities-with-llm.md +++ /dev/null @@ -1,229 +0,0 @@ -# Using a LLM as a span qualifier -In this tutorial we woud learn how to use the `LLMSpanClassifier` pipe to qualify spans. -You should install the extra dependencies before in a python environment (python>='3.8'): -```bash -pip install edsnlp[llm] -``` - -## Using a local LLM server -We suppose that there is an available LLM server compatible with OpenAI API. -For example, using the library vllm you can launch an LLM server as follows in command line: -```bash -vllm serve Qwen/Qwen3-8B --port 8000 --enable-prefix-caching --tensor-parallel-size 1 --max-num-seqs=10 --max-num-batched-tokens=35000 -``` - -## Using an external API -You can also use the [Openai API](https://openai.com/index/openai-api/) or the [Groq API](https://groq.com/). - -!!! warning - - As you are probably working with sensitive medical data, please check whether you can use an external API or if you need to expose an API in your own infrastructure. 
- -## Import dependencies -```{ .python .no-check } -from datetime import datetime - -import pandas as pd - -import edsnlp -import edsnlp.pipes as eds -from edsnlp.pipes.qualifiers.llm.llm_qualifier import LLMSpanClassifier -from edsnlp.utils.span_getters import make_span_context_getter -``` -## Define prompt and examples -```{ .python .no-check } -task_prompts = { - 0: { - "normalized_task_name": "biopsy_procedure", - "system_prompt": "You are a medical assistant and you will help answering questions about dates present in clinical notes. Don't answer reasoning. " - + "We are interested in detecting biopsy dates (either procedure, analysis or result). " - + "You should answer in a JSON object following this schema {'biopsy':bool}. " - + "If there is not enough information, answer {'biopsy':'False'}." - + "\n\n#### Examples:\n", - "examples": [ - ( - "07/12/2020", - "07/12/2020 : Anapath / biopsies rectales : Muqueuse rectale normale sous réserve de fragments de petite taille.", - "{'biopsy':'True'}", - ), - ( - "24/12/2021", - "Chirurgie 24/12/2021 : Colectomie gauche + anastomose colo rectale + clearance hépatique gauche (une méta posée sur", - "{'biopsy':'False'}", - ), - ], - "prefix_prompt": "\nDetermine if '{span}' corresponds to a biopsy date. The text is as follows:\n<<< ", - "suffix_prompt": " >>>", - "json_schema": { - "properties": { - "biopsy": {"title": "Biopsy", "type": "boolean"}, - }, - "required": [ - "biopsy", - ], - "title": "DateModel", - "type": "object", - }, - "response_mapping": { - "(?i)(oui)|(yes)|(true)": "1", - "(?i)(non)|(no)|(false)|(don't)|(not)": "0", - }, - }, -} -``` - -## Format these examples for few-shot learning -```{ .python .no-check } -def format_examples(raw_examples, prefix_prompt, suffix_prompt): - examples = [] - - for date, context, answer in raw_examples: - prompt = prefix_prompt.format(span=date) + context + suffix_prompt - examples.append((prompt, answer)) - - return examples -``` - -## Set parameters and prompts -```{ .python .no-check } -# Set prompt -prompt_id = 0 -raw_examples = task_prompts.get(prompt_id).get("examples") -prefix_prompt = task_prompts.get(prompt_id).get("prefix_prompt") -user_prompt = task_prompts.get(prompt_id).get("user_prompt") -system_prompt = task_prompts.get(prompt_id).get("system_prompt") -suffix_prompt = task_prompts.get(prompt_id).get("suffix_prompt") -examples = format_examples(raw_examples, prefix_prompt, suffix_prompt) - -# Define JSON schema -response_format = { - "type": "json_schema", - "json_schema": { - "name": "DateModel", - # "strict": True, - "schema": task_prompts.get(prompt_id)["json_schema"], - }, -} - -# Set parameters -response_mapping = None -max_tokens = 200 -extra_body = { - # "chat_template_kwargs": {"enable_thinking": False}, -} -temperature = 0 -``` - -=== "For local serving" - - ```{ .python .no-check } - ### For local serving - model_name = "Qwen/Qwen3-8B" - api_url = "http://localhost:8000/v1" - api_key = "EMPTY_API_KEY" - ``` - - -=== "Using the Groq API" - !!! warning - ⚠️ This section involves the use of an external API. Please ensure you have the necessary credentials and understand the potential risks associated with external API usage. 
- - ```{ .python .no-check } - ### Using Groq API - model_name = "openai/gpt-oss-20b" - api_url = "https://api.groq.com/openai/v1" - api_key = "TOKEN" ## your API KEY - ``` - -## Define the pipeline -```{ .python .no-check } -nlp = edsnlp.blank("eds") -nlp.add_pipe("sentencizer") -nlp.add_pipe(eds.dates()) -nlp.add_pipe( - LLMSpanClassifier( - name="llm", - model=model_name, - span_getter=["dates"], - attributes={"_.biopsy_procedure": True}, - context_getter=make_span_context_getter( - context_sents=(3, 3), - context_words=(1, 1), - ), - prompt=dict( - system_prompt=system_prompt, - user_prompt=user_prompt, - prefix_prompt=prefix_prompt, - suffix_prompt=suffix_prompt, - examples=examples, - ), - api_params=dict( - max_tokens=max_tokens, - temperature=temperature, - response_format=response_format, - extra_body=extra_body, - ), - api_url=api_url, - api_key=api_key, - response_mapping=response_mapping, - n_concurrent_tasks=4, - ) -) -``` - -## Apply it on a document - -```{ .python .no-check } -# Let's try with a fake LLM generated text -text = """ -Centre Hospitalier Départemental – RCP Prostate – 20/02/2025 - -M. Bernard P., 69 ans, retraité, consulte après avoir noté une faiblesse du jet urinaire et des levers nocturnes répétés depuis un an. PSA à 15,2 ng/mL (05/02/2025). TR : nodule ferme sur lobe gauche. - -IRM multiparamétrique du 10/02/2025 : lésion PIRADS 5, 2,1 cm, atteinte de la capsule suspectée. -Biopsies du 12/02/2025 : adénocarcinome Gleason 4+4=8, toutes les carottes gauches positives. -Scanner TAP et scintigraphie osseuse du 14/02 : absence de métastases viscérales ou osseuses. - -En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. Décision : radiothérapie externe + hormonothérapie longue (24 mois). Planification de la simulation scanner le 25/02. 
-""" -``` - -```{ .python .no-check } -t0 = datetime.now() -doc = nlp(text) -t1 = datetime.now() -print("Execution time", t1 - t0) - -for span in doc.spans["dates"]: - print(span, span._.biopsy_procedure) -``` - -Lets check the type -```{ .python .no-check } -type(span._.biopsy_procedure) -``` -# Apply on multiple documents -```{ .python .no-check } -texts = [ - text, -] * 2 - -notes = pd.DataFrame({"note_id": range(len(texts)), "note_text": texts}) -docs = edsnlp.data.from_pandas(notes, nlp=nlp, converter="omop") -predicted_docs = docs.map_pipeline(nlp, 2) -``` - -```{ .python .no-check } -t0 = datetime.now() -note_nlp = edsnlp.data.to_pandas( - predicted_docs, - converter="ents", - span_getter="dates", - span_attributes=[ - "biopsy_procedure", - ], -) -t1 = datetime.now() -print("Execution time", t1 - t0) -note_nlp.head() -``` diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index 7a1baa3329..93a502775f 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -84,4 +84,5 @@ from .trainable.embeddings.text_cnn.factory import create_component as text_cnn from .misc.split import Split as split from .misc.explode import Explode as explode - from .llm.llm_markup_extractor import LlmMarkupExtractor as llm_markup_extractor + from .llm.llm_markup_extractor.factory import create_component as llm_markup_extractor + from .llm.llm_span_qualifier.factory import create_component as llm_span_qualifier diff --git a/edsnlp/pipes/llm/llm_span_qualifier/__init__.py b/edsnlp/pipes/llm/llm_span_qualifier/__init__.py index e69de29bb2..7a3024be31 100644 --- a/edsnlp/pipes/llm/llm_span_qualifier/__init__.py +++ b/edsnlp/pipes/llm/llm_span_qualifier/__init__.py @@ -0,0 +1 @@ +from .llm_span_qualifier import LlmSpanQualifier diff --git a/edsnlp/pipes/llm/llm_span_qualifier/factory.py b/edsnlp/pipes/llm/llm_span_qualifier/factory.py index d204a077d8..00a8f5e5e6 100644 --- a/edsnlp/pipes/llm/llm_span_qualifier/factory.py +++ b/edsnlp/pipes/llm/llm_span_qualifier/factory.py @@ -1,7 +1,8 @@ -from edsnlp.core import registry +from edsnlp import registry -from .llm_span_qualifier import LLMSpanClassifier +from .llm_span_qualifier import LlmSpanQualifier create_component = registry.factory.register( "eds.llm_span_qualifier", -)(LLMSpanClassifier) + assigns=["doc.ents", "doc.spans"], +)(LlmSpanQualifier) diff --git a/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py b/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py index 4683ce3b6e..ea94e390e0 100644 --- a/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py +++ b/edsnlp/pipes/llm/llm_span_qualifier/llm_span_qualifier.py @@ -1,419 +1,758 @@ -from __future__ import annotations - -import logging -import re +import json +import os +import warnings from typing import ( Any, + Callable, + Coroutine, Dict, + Iterable, List, Optional, - Sequence, Tuple, + Type, Union, ) +from pydantic import BaseModel from spacy.tokens import Doc, Span -from typing_extensions import TypedDict +from typing_extensions import Annotated, Literal -from edsnlp.core.pipeline import Pipeline +from edsnlp.core import PipelineProtocol from edsnlp.pipes.base import BaseSpanAttributeClassifierComponent -from edsnlp.pipes.qualifiers.llm.llm_utils import ( - AsyncLLM, - create_prompt_messages, -) -from edsnlp.utils.asynchronous import run_async -from edsnlp.utils.bindings import ( - BINDING_SETTERS, - Attributes, - AttributesArg, -) -from edsnlp.utils.span_getters import SpanGetterArg, get_spans +from edsnlp.utils.bindings import BINDING_GETTERS, 
BINDING_SETTERS, AttributesArg +from edsnlp.utils.span_getters import ContextWindow, SpanGetterArg, get_spans -logger = logging.getLogger(__name__) +from ..async_worker import AsyncRequestWorker -LLMSpanClassifierBatchInput = TypedDict( - "LLMSpanClassifierBatchInput", - { - "queries": List[str], - }, -) -""" -queries: List[str] - List of queries to send to the LLM for classification. - Each query corresponds to a span and its context. -""" - -LLMSpanClassifierBatchOutput = TypedDict( - "LLMSpanClassifierBatchOutput", - { - "labels": Optional[Union[List[str], List[List[str]]]], - }, -) -""" -labels: Optional[Union[List[str], List[List[str]]]] - The predicted labels for each query. - If `n > 1`, this will be a list of lists, where each inner list contains the - predictions for a single query. - If `n == 1`, this will be a list of strings, where each string is the prediction - for a single query. - If the API call fails or no predictions are made, this will be None. - If `n > 1`, it will be a list of None values for each query. - If `n == 1`, it will be a single None value. +class LlmSpanQualifier(BaseSpanAttributeClassifierComponent): + r''' + The `eds.llm_span_qualifier` component qualifies spans using a + Large Language Model (LLM) that returns structured JSON attributes. -""" + This component takes existing spans, wraps them with `` markers inside a + context window and prompts an LLM to answer with a JSON object that matches the + configured schema. The response is validated and written back on the span + extensions. + In practice, along with a system prompt that constrains the allowed attributes + and optional few-shot examples provided as previous user / assistant messages, + the component sends snippets such as: + ``` + Biopsies du 12/02/2025 : adénocarcinome. + ``` -class PromptConfig(TypedDict, total=False): - """ - Parameters - ---------- - system_prompt : Optional[str] - A system prompt to use for the LLM. This is a general prompt that will be - prepended to each query. This prompt will be passed under the `system` role - in the OpenAI API call. - Example: "You are a medical expert. Classify the following text." - If None, no system prompt will be used. - Note: This is not the same as the `user_prompt` parameter. - user_prompt : Optional[str] - A general prompt to use for all spans. This is a prompt that will be prepended - to each span's specific prompt. This will be passed under the `user` role - in the OpenAI API call. - prefix_prompt : Optional[str] - A prefix prompt to paste after the `user_prompt` and before the selected context - of the span (using the `context_getter`). - It will be formatted specifically for each span, using the `span` variable. - Example: "Is '{span}' a Colonoscopy (procedure) date?" - suffix_prompt: Optional[str] - A suffix prompt to append at the end of the prompt. - examples : Optional[List[Tuple[str, str]]] - A list of examples to use for the prompt. Each example is a tuple of - (input, output). The input is the text to classify and the output is the - expected classification. - If None, no examples will be used. - Example: [("This is a colonoscopy date.", "colonoscopy_date")] - """ - - system_prompt: Optional[str] - user_prompt: Optional[str] - prefix_prompt: Optional[str] - suffix_prompt: Optional[str] - examples: Optional[List[Tuple[str, str]]] - - -class APIParams(TypedDict, total=False): - """ - Parameters - ---------- - extra_body : Optional[Dict[str, Any]] - Additional body parameters to pass to the vLLM API. 
- This can be used to pass additional parameters to the model, such as - `reasoning_parser` or `enable_reasoning`. - response_format : Optional[Dict[str, Any]] - The response format to use for the vLLM API call. - This can be used to specify how the response should be formatted. - temperature : float - The temperature for the vLLM API call. Default is 0.0 (deterministic). - max_tokens : int - The maximum number of tokens to generate in the response. - Default is 50. - """ - - max_tokens: int - temperature: float - response_format: Optional[Dict[str, Any]] - extra_body: Optional[Dict[str, Any]] - - -class LLMSpanClassifier( - BaseSpanAttributeClassifierComponent, -): - """ - The `LLMSpanClassifier` component is a LLM attribute predictor. - In this context, the span classification task consists in assigning values (boolean, - strings or any object) to attributes/extensions of spans such as: - - - `span._.negation`, - - `span._.date.mode` - - `span._.cui` - - This pipe will use an LLM API to classify previously identified spans using - the context and instructions around each span. - - Check out the LLM classifier tutorial - for examples ! - - Python >= 3.8 is required. + and expects a minimal JSON answer, for example: + ```json + {"biopsy_procedure": "yes"} + ``` + which is then parsed and assigned to the span attributes. + + + !!! warning "Experimental" + + This component is experimental. The API and behavior may change in future + versions. Make sure to pin your `edsnlp` version if you use it in a project. + + !!! note "Dependencies" + + This component requires several dependencies. Run the following command to + install them: + ```bash { data-md-color-scheme="slate" } + pip install openai bm25s Stemmer + ``` + We recommend even to add them to your `pyproject.toml` or `requirements.txt`. + + Examples + -------- + If your data is sensitive, we recommend you to use a self-hosted + model with an OpenAI-compatible API, such as + [vLLM](https://github.com/vllm-project/vllm). + + Start a server with the model of your choice: + + ```bash { data-md-color-scheme="slate" } + python -m vllm.entrypoints.openai.api_server \ + --model mistral-small-24b-instruct-2501 \ + --port 8080 \ + --enable-prefix-caching + ``` + + You can then use the `llm_span_qualifier` component as follows: + + + + === "Yes/no bool classification" + + ```python { .no-check } + from typing import Annotated + from pydantic import BeforeValidator, PlainSerializer, WithJsonSchema + import edsnlp, edsnlp.pipes as eds + + BiopsySchema = Annotated[ + bool, + BeforeValidator(lambda v: str(v).lower() in {"yes", "y", "true"}), + PlainSerializer(lambda v: "yes" if v else "no", when_used="json"), + ] + + PROMPT = """ + You are a span classifier. The user sends text where the target is + marked with .... Answer ONLY with a JSON value: "yes" or + "no" indicating whether the span is a biopsy date. + """.strip() + + nlp = edsnlp.blank("eds") + nlp.add_pipe(eds.sentences()) + nlp.add_pipe(eds.dates(span_setter="ents")) + + # EDS-NLP util to create documents from Markdown or XML markup. + # This has nothing to do with the LLM component itself. The following + # will create docs with entities labelled "date", store them in doc.ents, + # and set their span._.biopsy_procedure attribute. + examples = list(edsnlp.data.from_iterable( + [ + "IRM du 10/02/2025. Biopsies du 12/02/2025 : adénocarcinome.", + "Chirurgie le 24/12/2021. Colectomie. 
Consultation du 26/12/2021.", + ], + converter="markup", + preset="xml", + ).map(nlp.pipes.sentences)) + + doc_to_xml = edsnlp.data.converters.DocToMarkupConverter(preset="xml") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="mistral-small-24b-instruct-2501", + prompt=PROMPT, + span_getter="ents", + context_getter="sent", + context_formatter=doc_to_xml, + attributes=["biopsy_procedure"], + output_schema=BiopsySchema, + examples=examples, + max_few_shot_examples=2, + max_concurrent_requests=4, + seed=0, + ) + ) + + text = """ + RCP Prostate – 20/02/2025 + Biopsies du 12/02/2025 : adénocarcinome Gleason 4+4=8. + Simulation scanner le 25/02/2025. + """ + doc = nlp(text) + for d in doc.ents: + print(d.text, "→ biopsy_procedure:", d._.biopsy_procedure) + # Out: 20/02/2025 → biopsy_procedure: False + # Out: 12/02/2025 → biopsy_procedure: True + # Out: 25/02/2025 → biopsy_procedure: False + ``` + + === "Multi-attribute classification" + + ```python { .no-check } + from typing import Annotated, Optional + import datetime + from pydantic import BaseModel, Field + import edsnlp, edsnlp.pipes as eds + + # Pydantic schema used to validate the LLM response, serialize the + # few-shot example answers constrain the model output. + class CovidMentionSchema(BaseModel): + negation: bool = Field(..., description="Is the span negated or not") + date: Optional[datetime.date] = Field( + None, description="Date associated with the span, if any" + ) + + PROMPT = """ + You are a span classifier. For every piece of markup-annotated text the + user provides, you predict the attributes of the annotated spans. + You must follow these rules strictly: + - Be consistent, similar queries must lead to similar answers. + - Do not add any comment or explanation, just provide the answer. + Example with a negation and a date: + User: "Le 1er mai 2024, le patient a été testé covid négatif" + Assistant: "{"negation": true, "date": "2024-05-01"}" + For each span, provide a JSON with a "negation" boolean attribute, set to + true if the span is negated, false otherwise. If a date is associated with + the span, provide it as a "date" attribute in ISO format (YYYY-MM-DD). + """.strip() + + nlp = edsnlp.blank("eds") + nlp.add_pipe(eds.sentences()) + nlp.add_pipe(eds.covid()) + + # EDS-NLP util to create documents from Markdown or XML markup. + # This has nothing to do with the LLM component itself. + examples = list(edsnlp.data.from_iterable( + [ + "Covid positif le 1er mai 2024.", + "Pas de covid", + # ... add more examples if you can + ], + converter="markup", preset="xml", + ).map(nlp.pipes.sentences)) + + doc_to_xml = edsnlp.data.converters.DocToMarkupConverter(preset="xml") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="https://api.openai.com/v1", + model="gpt-5-mini", + prompt=PROMPT, + span_getter="ents", + context_getter="words[-10:10]", + context_formatter=doc_to_xml, + output_schema=CovidMentionSchema, + examples=examples, + max_few_shot_examples=1, + max_concurrent_requests=4, + seed=0, + ) + ) + doc = nlp("Pas d'indication de covid le 3 mai 2024.") + (ent,) = doc.ents + print(ent.text, "→ negation:", ent._.negation, "date:", ent._.date) + # Out: covid → negation: True date: 2024-05-03 + ``` + + + + You can also control the prompt more finely by providing a callable instead of a + string. 
For example, to put few-shot examples in the system message and keep the + span context as the user payload: + + ```python { .no-check } + # Use this for the `prompt` argument instead of PROMPT above + def prompt(context_text, examples): + messages = [] + system_content = ( + "You are a span classifier.\n" + "Answer with JSON using the keys: biopsy_procedure.\n" + "Here are some examples:\n" + ) + for ex_context, ex_json in examples: + system_content += f"- Context: {ex_context}\n" + system_content += f" JSON: {ex_json}\n" + messages.append({"role": "system", "content": system_content}) + messages.append({"role": "user", "content": context_text}) + return messages + ``` Parameters ---------- nlp : PipelineProtocol - The pipeline object + Pipeline object. name : str - Name of the component - prompt : Optional[PromptConfig] - The prompt configuration to use for the LLM. + Component name. api_url : str - The base URL of the vLLM OpenAI-compatible server to call. - Default: "http://localhost:8000/v1" + Base URL of the OpenAI-compatible API. model : str - The name of the model to use for classification. - Default: "Qwen/Qwen3-8B" - span_getter : SpanGetterArg - How to extract the candidate spans and the attributes to predict or train on. - context_getter : Optional[Union[Callable, SpanGetterArg]] - What context to use when computing the span embeddings (defaults to the whole - document). This can be: - - - a `SpanGetterArg` to retrieve contexts from a whole document. For example - `{"section": "conclusion"}` to only use the conclusion as context (you - must ensure that all spans produced by the `span_getter` argument do fall - in the conclusion in this case) - - a callable, that gets a span and should return a context for this span. - For instance, `lambda span: span.sent` to use the sentence as context. - attributes : AttributesArg - The attributes to predict or train on. If a dict is given, keys are the - attributes and values are the labels for which the attr is allowed, or True - if the attr is allowed for all labels. - api_params : APIParams - Additional parameters for the vLLM API call. - response_mapping : Optional[Dict[str, Any]] - A mapping from regex patterns to values that will be used to map the - responses from the model to the bindings. If not provided, the raw - responses will be used. The first matching regex will be used to map the - response to the binding. - Example: `{"^yes$": True, "^no$": False}` will map "yes" to True and "no" to - False. - timeout : float - The timeout for the vLLM API call. Default is 15.0 seconds. - n_concurrent_tasks : int - The number of concurrent tasks to run when calling the vLLM API. - Default is 4. - kwargs: Dict[str, Any] - Additional keyword arguments passed to the vLLM API call. - This can include parameters like `n` for the number of responses to generate, - or any other OpenAI API parameters. - - Authors and citation - -------------------- - The `eds.llm_qualifier` component was developed by AP-HP's Data Science team. - """ + Model identifier exposed by the API. + prompt : Union[str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]]] + The prompt is the main way to control the model's behavior. + It can be either: + + - A string, which will be used as a system prompt. + Few-shot examples (if any) will be provided as user/assistant + messages before the actual user query. + - A callable that takes two arguments and returns a list of messages in the + format expected by the OpenAI chat completions API. 
+ + * `context`: the context text with the target span marked up + * `examples`: a list of few-shot examples, each being a tuple of + (context, answer) + span_getter : Optional[SpanGetterArg] + Spans to classify. Defaults to `{"ents": True}`. + context_getter : Optional[ContextWindow] + Optional context window specification (e.g. `"sent"`, `"words[-10:10]"`). + If `None`, the whole document text is used. + context_formatter : Optional[Callable[[Doc], str]] + Callable used to render the context passed to the LLM. Defaults to + `lambda doc: doc.text`. + attributes : Optional[AttributesArg] + Attributes to predict. If omitted, the keys are inferred from the provided + schema. + output_schema : Optional[Union[Type[BaseModel], Type[Any], Annotated[Any, Any]]] + Pydantic model class used to validate responses and serialise few-shot + examples. If the schema is a mapping/object, it will also be used to + force the model to output a valid JSON object. + examples : Optional[Iterable[Doc]] + Few-shot examples used in prompts. + max_few_shot_examples : int + Maximum number of few-shot examples per request (`-1` means all). + use_retriever : Optional[bool] + Whether to select few-shot examples with BM25 (defaults to automatic choice). + If there are few shot examples and `max_few_shot_examples > 0`, this enabled + by default. + seed : Optional[int] + Optional seed forwarded to the API. + max_concurrent_requests : int + Maximum number of concurrent span requests per document. + api_kwargs : Dict[str, Any] + Extra keyword arguments forwarded to `chat.completions.create`. + on_error : Literal["raise", "warn"] + Error handling strategy. If `"raise"`, exceptions are raised. If `"warn"`, + exceptions are logged as warnings and processing continues. + ''' # noqa: E501 def __init__( self, - nlp: Optional[Pipeline] = None, - name: str = "span_classifier", - prompt: Optional[PromptConfig] = None, - api_url: str = "http://localhost:8000/v1", - model: str = "Qwen/Qwen3-8B", + nlp: PipelineProtocol, + name: str = "llm_span_qualifier", *, - attributes: AttributesArg = None, - span_getter: SpanGetterArg = None, - context_getter: Optional[SpanGetterArg] = None, - response_mapping: Optional[Dict[str, Any]] = None, - api_params: APIParams = { - "max_tokens": 50, - "temperature": 0.0, - "response_format": None, - "extra_body": None, - }, - timeout: float = 15.0, - n_concurrent_tasks: int = 4, - **kwargs, + api_url: str, + model: str, + prompt: Union[ + str, Callable[[str, List[Tuple[str, str]]], List[Dict[str, str]]] + ], + span_getter: Optional[SpanGetterArg] = None, + context_getter: Optional[ContextWindow] = None, + context_formatter: Optional[Callable[[Doc], str]] = None, + attributes: Optional[AttributesArg] = None, # confit will auto cast to dict + output_schema: Optional[ + Union[ + Type[BaseModel], + Type[Any], + Annotated[Any, Any], + ] + ] = None, + examples: Optional[Iterable[Doc]] = None, + max_few_shot_examples: int = -1, + use_retriever: Optional[bool] = None, + seed: Optional[int] = None, + max_concurrent_requests: int = 1, + api_kwargs: Optional[Dict[str, Any]] = None, + on_error: Literal["raise", "warn"] = "raise", ): - if attributes is None: - raise TypeError( - "The `attributes` parameter is required. Please provide a dict of " - "attributes to predict or train on." 
- ) + import openai span_getter = span_getter or {"ents": True} - - self.bindings: List[Tuple[str, List[str], List[Any]]] = [ - (k if k.startswith("_.") else f"_.{k}", v, []) - for k, v in attributes.items() - ] - - # Store API configuration + self.lang = nlp.lang self.api_url = api_url self.model = model - self.extra_body = api_params.get("extra_body") - self.temperature = api_params.get("temperature") - self.max_tokens = api_params.get("max_tokens") - self.response_format = api_params.get("response_format") - self.response_mapping = response_mapping - self.kwargs = kwargs.get("kwargs") or {} - self.timeout = timeout - self.n_concurrent_tasks = n_concurrent_tasks - - # Prompt config - prompt = prompt or {} self.prompt = prompt - self.system_prompt = prompt.get("system_prompt") - self.user_prompt = prompt.get("user_prompt") - self.prefix_prompt = prompt.get("prefix_prompt") - self.suffix_prompt = prompt.get("suffix_prompt") - self.examples = prompt.get("examples") + self.context_window = ( + ContextWindow.validate(context_getter) + if context_getter is not None + else None + ) + self.context_formatter = context_formatter or (lambda doc: doc.text) + self.seed = seed + self.api_kwargs = api_kwargs or {} + self.max_concurrent_requests = max_concurrent_requests + self.on_error = on_error + + if attributes is None: + if hasattr(output_schema, "model_fields"): + attr_map = {name: True for name in output_schema.model_fields.keys()} + else: + raise ValueError( + "You must provide either `attributes` or a valid pydantic" + "`output_schema` for llm_span_qualifier." + ) + else: + attr_map = attributes + + self.scalar_schema = True + if output_schema is not None: + try: + self.scalar_schema = not issubclass(output_schema, BaseModel) + except TypeError: + self.scalar_schema = True + + if self.scalar_schema and output_schema is not None: + if not attributes or len(attributes) != 1: + raise ValueError( + "When the provided output schema is a scalar type, you must " + "provide exactly one attribute." + ) - super().__init__(nlp, name, span_getter=span_getter) - self.context_getter = context_getter + # This class name is produced in the json output_schema so the model + # may see this depending on API implementation ! 
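            # Wrapping the scalar type in a pydantic RootModel gives it the same
            # model_validate / model_dump_json interface as the object schemas,
            # so the rest of the component can treat both cases uniformly.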
+ from pydantic import RootModel - if self.response_mapping: - self.get_response_mapping_regex_dict() + class Output(RootModel): + root: output_schema # type: ignore - @property - def attributes(self) -> Attributes: - return {qlf: labels for qlf, labels, _ in self.bindings} + self.output_schema = Output # type: ignore - def set_extensions(self): - super().set_extensions() - for group in self.bindings: - qlf = group[0] - if qlf.startswith("_."): - qlf = qlf[2:] - if not Span.has_extension(qlf): - Span.set_extension(qlf, default=None) - - def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]: - spans = list(get_spans(doc, self.span_getter)) - spans_text = [span.text for span in spans] - if self.context_getter is None or not callable(self.context_getter): - contexts = list(get_spans(doc, self.context_getter)) else: - contexts = [self.context_getter(span) for span in spans] + self.output_schema = output_schema - contexts_text = [context.text for context in contexts] + self.response_format = ( + self._build_response_format(self.output_schema) + if self.output_schema + else None + ) - doc_batch_messages = [] - for span_text, context_text in zip(spans_text, contexts_text): - if self.prefix_prompt: - final_user_prompt = ( - self.prefix_prompt.format(span=span_text) + context_text - ) + self.bindings: List[ + Tuple[str, Union[bool, List[str]], str, Callable, Callable] + ] = [] + for attr_name, labels in attr_map.items(): + if ( + attr_name.startswith("_.") + or attr_name.endswith("_") + or attr_name in {"label_", "kb_id_"} + ): + attr_path = attr_name else: - final_user_prompt = context_text - if self.suffix_prompt: - final_user_prompt += self.suffix_prompt - - messages = create_prompt_messages( - system_prompt=self.system_prompt, - user_prompt=self.user_prompt, - examples=self.examples, - final_user_prompt=final_user_prompt, - ) - doc_batch_messages.append(messages) + attr_path = f"_.{attr_name}" + json_key = attr_path[2:] if attr_path.startswith("_.") else attr_path + setter = BINDING_SETTERS[attr_path] + getter = BINDING_GETTERS[attr_path] + self.bindings.append((attr_path, labels, json_key, setter, getter)) + self.attributes = {path: labels for path, labels, *_ in self.bindings} + + self.examples: List[Tuple[str, str]] = [] + for doc in examples or []: + for span in get_spans(doc, span_getter): + context_doc = self._build_context_doc(span) + context_text = self.context_formatter(context_doc) + values: Dict[str, Any] = {} + for _, labels, json_key, _, getter in self.bindings: + if ( + labels is False + or labels is not True + and span.label_ not in (labels or []) + ): + continue + try: + values[json_key] = getter(span) + except Exception: # pragma: no cover + self._handle_err( + f"Failed to get attribute {attr_path!r} for span " + f"{span.text!r} in example doc {span.doc._.note_id}" + ) + if self.scalar_schema: + values = next(iter(values.values())) + if self.output_schema is not None: + try: + answer = self.output_schema.model_validate( + values + ).model_dump_json(exclude_none=True) + except Exception: # pragma: no cover + self._handle_err( + f"[llm_span_qualifier] Failed to validate example " + f"values against the output schema: {values!r}" + ) + continue + else: + answer = json.dumps(values) + self.examples.append((context_text, answer)) + + self.max_few_shot_examples = max_few_shot_examples + self.retriever = None + self.retriever_stemmer = None + if self.max_few_shot_examples > 0 and use_retriever is not False: + self.build_few_shot_retriever_(self.examples) + + api_key = 
os.getenv("OPENAI_API_KEY", "") + self.client = openai.Client(base_url=self.api_url, api_key=api_key) + self._async_client = openai.AsyncOpenAI(base_url=self.api_url, api_key=api_key) + + super().__init__(nlp=nlp, name=name, span_getter=span_getter) + + def _handle_err(self, msg): + if self.on_error == "raise": + raise RuntimeError(msg) + else: + warnings.warn(msg) - return { - "$spans": spans, - "spans_text": spans_text, - "contexts": contexts, - "contexts_text": contexts_text, - "doc_batch_messages": doc_batch_messages, - } + def set_extensions(self) -> None: + super().set_extensions() + for attr_path, *_ in self.bindings: + if attr_path.startswith("_."): + ext_name = attr_path[2:].split(".")[0] + if not Span.has_extension(ext_name): + Span.set_extension(ext_name, default=None) + + def build_few_shot_retriever_(self, samples: List[Tuple[str, str]]) -> None: + # Same BM25 strategy as llm_markup_extractor + import bm25s + import Stemmer + + lang = {"eds": "french"}.get(self.lang, self.lang) + stemmer = Stemmer.Stemmer(lang) + corpus = bm25s.tokenize( + [text for text, _ in samples], stemmer=stemmer, stopwords=lang + ) + retriever = bm25s.BM25() + retriever.index(corpus) + self.retriever = retriever + self.retriever_stemmer = stemmer + + def build_prompt(self, context_text: str) -> List[Dict[str, str]]: + import bm25s + + few_shot_examples: List[Tuple[str, str]] = [] + if self.retriever is not None: + closest, _ = self.retriever.retrieve( + bm25s.tokenize( + context_text, + stemmer=self.retriever_stemmer, + show_progress=False, + ), + k=self.max_few_shot_examples, + show_progress=False, + ) + for i in closest[0][: self.max_few_shot_examples]: + few_shot_examples.append(self.examples[i]) + few_shot_examples = few_shot_examples[::-1] + else: + few_shot_examples = self.examples[: self.max_few_shot_examples] + + if isinstance(self.prompt, str): + messages = [{"role": "system", "content": self.prompt}] + for ctx, ans in few_shot_examples: + messages.append({"role": "user", "content": ctx}) + messages.append({"role": "assistant", "content": ans}) + messages.append({"role": "user", "content": context_text}) + return messages + return self.prompt(context_text, few_shot_examples) + + def _llm_request_sync(self, messages: List[Dict[str, str]]) -> str: + call_kwargs = dict(self.api_kwargs) + if "response_format" not in call_kwargs and self.response_format is not None: + call_kwargs["response_format"] = self.response_format + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + seed=self.seed, + **call_kwargs, + ) + return response.choices[0].message.content or "" + + def _llm_request_coro( + self, messages: List[Dict[str, str]] + ) -> Coroutine[Any, Any, str]: + async def _coro(): + call_kwargs = dict(self.api_kwargs) + if ( + "response_format" not in call_kwargs + and self.response_format is not None + ): + call_kwargs["response_format"] = self.response_format + response = await self._async_client.chat.completions.create( + model=self.model, + messages=messages, + seed=self.seed, + **call_kwargs, + ) + return response.choices[0].message.content or "" - def collate(self, batch: Dict[str, Sequence[Any]]) -> LLMSpanClassifierBatchInput: - collated = { - "batch_messages": [ - message for item in batch for message in item["doc_batch_messages"] - ] - } + return _coro() - return collated + def _parse_response(self, raw: str) -> Optional[Dict[str, Any]]: + text = raw.strip() + data = text + if self.output_schema is not None: + if self.scalar_schema: + start = 0 + end = 
len(text) + else: + if text.startswith("```"): + text = text.strip("`") + text = text.strip() + if text.startswith("json"): + text = text[4:].strip() + start = text.find("{") + end = text.rfind("}") + 1 + if start == -1 or end <= 0 or end <= start: + return None + try: + data = self.output_schema.model_validate_json(text).model_dump() + except Exception: + try: + # Interpret as a string + data = self.output_schema.model_validate_json( + json.dumps(text) + ).model_dump() + except Exception: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] Failed to validate LLM response" + f" against the output schema: {text!r}", + ) + data = raw + if self.scalar_schema: + data = {next(iter(self.attributes.keys())): data} + return data + + def _build_context_doc(self, span: Span) -> Doc: + ctx_source = ( + span.doc[:] if self.context_window is None else self.context_window(span) + ) + context = ctx_source.as_doc() + offset = ctx_source.start + rel_start = max(0, span.start - offset) + rel_end = max(rel_start, min(len(context), span.end - offset)) + ent = Span(context, rel_start, rel_end, label=span.label_) + context.ents = (ent,) + return context + + def _set_values_on_span(self, span: Span, data: Optional[Dict[str, Any]]) -> None: + for attr_path, labels, json_key, setter, _ in self.bindings: + if ( + labels is False + or labels is not True + and span.label_ not in (labels or []) + ): + continue + # only when scalar mode we use attr_path as key, maybe change that later ? + value = ( + None + if data is None + else data.get(json_key) + if json_key in data + else data.get(attr_path) + ) + try: + setter(span, value) + except Exception as exc: # pragma: no cover + self._handle_err( + f"[llm_span_qualifier] Failed to set attribute {attr_path!r} " + f"for span {span.text!r} in doc {span.doc._.note_id}: {exc!r}" + ) - # noinspection SpellCheckingInspection - def forward( - self, - batch: LLMSpanClassifierBatchInput, - ) -> Dict[str, List[Any]]: - """ - Apply the span classifier module to the document embeddings and given spans to: - - compute the loss - - and/or predict the labels of spans - - Parameters - ---------- - batch: SpanClassifierBatchInput - The input batch - - Returns - ------- - BatchOutput - """ + def _build_response_format(self, schema: Type[BaseModel]) -> Dict[str, Any]: + raw_schema = schema.model_json_schema() + json_schema = json.loads(json.dumps(raw_schema)) - # Here call the LLM API - llm = AsyncLLM( - model_name=self.model, - api_url=self.api_url, - extra_body=self.extra_body, - temperature=self.temperature, - max_tokens=self.max_tokens, - response_format=self.response_format, - timeout=self.timeout, - n_concurrent_tasks=self.n_concurrent_tasks, - **self.kwargs, - ) - pred = run_async(llm(batch_messages=batch["batch_messages"])) + if isinstance(json_schema, dict): + json_schema.setdefault("type", "object") + json_schema.setdefault("additionalProperties", False) return { - "labels": pred, + "type": "json_schema", + "json_schema": { + "name": schema.__name__.replace(" ", "_"), + "schema": json_schema, + }, } - def get_response_mapping_regex_dict(self) -> Dict[str, str]: - self.response_mapping_regex = { - re.compile(regex): mapping_value - for regex, mapping_value in self.response_mapping.items() - } - return self.response_mapping_regex - - def map_response(self, value: str) -> str: - for ( - compiled_regex, - mapping_value, - ) in self.response_mapping_regex.items(): - if compiled_regex.search(value): - mapped_value = mapping_value + def _process_docs_async(self, docs: 
Iterable[Doc]) -> Iterable[Doc]: + worker = AsyncRequestWorker.instance() + pending: Dict[int, Tuple[Dict[str, Any], Span]] = {} + doc_states: List[Dict[str, Any]] = [] + docs_iter = iter(docs) + exhausted = False + next_yield = 0 + + def make_state(doc: Doc) -> Dict[str, Any]: + spans = list(get_spans(doc, self.span_getter)) + return { + "doc": doc, + "spans": spans, + "next_span": 0, + "pending": 0, + } + + def doc_done(state: Dict[str, Any]) -> bool: + return state["next_span"] >= len(state["spans"]) and state["pending"] == 0 + + def schedule() -> None: + if next_yield >= len(doc_states): + return + for state in doc_states[next_yield:]: + while ( + state["next_span"] < len(state["spans"]) + and len(pending) < self.max_concurrent_requests + ): + span = state["spans"][state["next_span"]] + state["next_span"] += 1 + context_doc = self._build_context_doc(span) + context_text = self.context_formatter(context_doc) + messages = self.build_prompt(context_text) + task_id = worker.submit(self._llm_request_coro(messages)) + pending[task_id] = (state, span) + state["pending"] += 1 + if len(pending) >= self.max_concurrent_requests: + return + + while True: + while not exhausted and len(pending) < self.max_concurrent_requests: + try: + doc = next(docs_iter) + except StopIteration: + exhausted = True + break + doc_states.append(make_state(doc)) + schedule() + + while next_yield < len(doc_states) and doc_done( + doc_states[next_yield] + ): # pragma: no cover + yield doc_states[next_yield]["doc"] + next_yield += 1 + + if exhausted and len(pending) == 0 and next_yield == len(doc_states): break + + if len(pending) == 0: # pragma: no cover + if exhausted and next_yield == len(doc_states): + break + continue + + done_task = worker.wait_for_any(pending.keys()) + result = worker.pop_result(done_task) + state, span = pending.pop(done_task) + state["pending"] -= 1 + raw = None + err = None + if result is not None: + raw, err = result + if err is not None: # pragma: no cover + self._handle_err( + f"[llm_span_qualifier] request failed for span " + f"'{span.text}' in doc {span.doc._.note_id}: {err!r}" + ) + data = None else: - mapped_value = None - return mapped_value + data = self._parse_response(str(raw)) + if data is None: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] Failed to parse LLM response for span " + f"'{span.text}' in doc {span.doc._.note_id}: {raw!r}" + ) + self._set_values_on_span(span, data) + schedule() + while next_yield < len(doc_states) and doc_done(doc_states[next_yield]): + yield doc_states[next_yield]["doc"] + next_yield += 1 + + def process(self, doc: Doc) -> Doc: + spans = list(get_spans(doc, self.span_getter)) - def postprocess( - self, - docs: Sequence[Doc], - results: LLMSpanClassifierBatchOutput, - inputs: List[Dict[str, Any]], - ) -> Sequence[Doc]: - # Preprocessed docs should still be in the cache - spans = [span for sample in inputs for span in sample["$spans"]] - all_labels = results["labels"] - # For each prediction group (exclusive bindings)... 
- - for qlf, labels, _ in self.bindings: - for value, span in zip(all_labels, spans): - if labels is True or span.label_ in labels: - if value is None: - mapped_value = None - elif self.response_mapping is not None: - # ...assign the mapped value to the span - mapped_value = self.map_response(value) - else: - mapped_value = value - BINDING_SETTERS[qlf](span, mapped_value) - - return docs - - def batch_process(self, docs): - inputs = [self.preprocess(doc) for doc in docs] - collated = self.collate(inputs) - res = self.forward(collated) - docs = self.postprocess(docs, res, inputs) - - return docs - - def enable_cache(self, cache_id=None): - # For compatibility - pass - - def disable_cache(self, cache_id=None): - # For compatibility - pass + for span in spans: + context_doc = self._build_context_doc(span) + context_text = self.context_formatter(context_doc) + messages = self.build_prompt(context_text) + data = None + try: + raw = self._llm_request_sync(messages) + except Exception as err: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] request failed for span " + f"'{span.text}' in doc {doc._.note_id}: {err!r}" + ) + else: + data = self._parse_response(raw) + if data is None: # pragma: no cover + self._handle_err( + "[llm_span_qualifier] Failed to parse LLM response for span " + f"'{span.text}' in doc {doc._.note_id}: {raw!r}" + ) + self._set_values_on_span(span, data) + return doc + + def __call__(self, doc: Doc) -> Doc: + return self.process(doc) + + def pipe(self, docs: Iterable[Doc]) -> Iterable[Doc]: + if self.max_concurrent_requests <= 1: # pragma: no cover + for doc in docs: + yield self(doc) + return + + yield from self._process_docs_async(docs) diff --git a/edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py b/edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py deleted file mode 100644 index 2e834f99f3..0000000000 --- a/edsnlp/pipes/llm/llm_span_qualifier/llm_utils.py +++ /dev/null @@ -1,387 +0,0 @@ -import asyncio -import json -import logging -from typing import ( - Any, - AsyncIterator, - Dict, - List, - Optional, - Tuple, -) - -from openai import AsyncOpenAI -from openai.types.chat.chat_completion import ChatCompletion - -logger = logging.getLogger(__name__) - - -class AsyncLLM: - """ - AsyncLLM is an asynchronous interface for interacting with Large Language Models - (LLMs) via an API, supporting concurrent requests and batch processing. - """ - - def __init__( - self, - model_name="Qwen/Qwen3-8B", - api_url: str = "http://localhost:8000/v1", - temperature: float = 0.0, - max_tokens: int = 50, - extra_body: Optional[Dict[str, Any]] = None, - response_format: Optional[Dict[str, Any]] = None, - api_key: str = "EMPTY_API_KEY", - n_completions: int = 1, - timeout: float = 15.0, - n_concurrent_tasks: int = 4, - **kwargs, - ): - """ - Initializes the AsyncLLM class with configuration parameters for interacting - with an OpenAI-compatible API server. - - Parameters - ---------- - model_name : str, optional - Name of the model to use (default: "Qwen/Qwen3-8B"). - api_url : str, optional - Base URL of the API server (default: "http://localhost:8000/v1"). - temperature : float, optional - Sampling temperature for generation (default: 0.0). - max_tokens : int, optional - Maximum number of tokens to generate per completion (default: 50). - extra_body : Optional[Dict[str, Any]], optional - Additional parameters to include in the API request body (default: None). - response_format : Optional[Dict[str, Any]], optional - Format specification for the API response (default: None). 
- api_key : str, optional - API key for authentication (default: "EMPTY_API_KEY"). - n_completions : int, optional - Number of completions to request per prompt (default: 1). - timeout : float, optional - Timeout for API requests in seconds (default: 15.0). - n_concurrent_tasks : int, optional - Maximum number of concurrent tasks for API requests (default: 4). - **kwargs - Additional keyword arguments for further customization. - """ - - # Set OpenAI's API key and API base to use vLLM's API server. - self.model_name = model_name - self.temperature = temperature - self.max_tokens = max_tokens - self.extra_body = extra_body - self.response_format = response_format - self.n_completions = n_completions - self.timeout = timeout - self.kwargs = kwargs - self.n_concurrent_tasks = n_concurrent_tasks - self.responses = [] - self._lock = None - - self.client = AsyncOpenAI( - api_key=api_key, - base_url=api_url, - default_headers={"Connection": "close"}, - ) - - async def __aenter__(self): - """Async context manager entry.""" - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit - properly close the client.""" - await self.aclose() - - async def aclose(self): - """Properly close the AsyncOpenAI client to prevent resource leaks.""" - if hasattr(self, "client") and self.client is not None: - await self.client.close() - - @property - def lock(self): - if self._lock is None: - self._lock = asyncio.Lock() - return self._lock - - async def async_id_message_generator( - self, batch_messages: List[List[Dict[str, str]]] - ) -> AsyncIterator[Tuple[int, List[Dict[str, str]]]]: - """ - Generator - """ - for i, messages in enumerate(batch_messages): - yield (i, messages) - - def parse_messages( - self, response: ChatCompletion, response_format: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: - """ - Parse the response from the LLM and return the content. - """ - if (response_format is not None) and (isinstance(response, ChatCompletion)): - prediction = [ - parse_json_response( - choice.message.content, response_format=response_format - ) - for choice in response.choices - ] - if self.n_completions == 1: - prediction = prediction[0] - - return prediction - - else: - return response - - async def call_llm( - self, id: int, messages: List[Dict[str, str]] - ) -> Tuple[int, ChatCompletion]: - """ - Call the LLM with the given messages and return the response. - - Parameters - ---------- - id : int - Unique identifier for the call - messages : List[Dict[str, str]] - List of messages to send to the LLM, where each message is a dictionary - with keys 'role' and 'content'. 
- - Returns - ------- - Tuple[int, ChatCompletion] - The id of the call and the ChatCompletion object corresponding to the - LLM response - """ - - raw_response = await asyncio.wait_for( - self.client.chat.completions.create( - model=self.model_name, - messages=messages, - max_tokens=self.max_tokens, - n=self.n_completions, - temperature=self.temperature, - stream=False, - response_format=self.response_format, - extra_body=self.extra_body, - **self.kwargs, - ), - timeout=self.timeout, - ) - - # Parse the response - parsed_response = self.parse_messages(raw_response, self.response_format) - - return id, parsed_response - - def store_responses(self, p_id, abbreviation_list): - """ """ - self.responses.append((p_id, abbreviation_list)) - - async def async_worker( - self, - name: str, - id_messages_tuples: AsyncIterator[Tuple[int, List[List[Dict[str, str]]]]], - ): - while True: - try: - ( - idx, - message, - ) = await anext(id_messages_tuples) # noqa: F821 - idx, response = await self.call_llm(idx, message) - - logger.info(f"Worker {name} has finished process {idx}") - except StopAsyncIteration: - # Everything has been parsed! - logger.info( - f"[{name}] Received StopAsyncIteration, worker will shutdown" - ) - break - except TimeoutError as e: - logger.error(f"[{name}] TimeoutError on chunk {idx}\n{e}") - logger.error(f"Timeout was set to {self.timeout} seconds") - if self.n_completions == 1: - response = "" - else: - response = [""] * self.n_completions - except BaseException as e: - logger.error( - f"[{name}] Exception raised on chunk {idx}\n{e}" - ) # type(e) - if self.n_completions == 1: - response = "" - else: - response = [""] * self.n_completions - async with self.lock: - self.store_responses( - idx, - response, - ) - - def sort_responses(self): - sorted_responses = [] - for i, output in sorted(self.responses, key=lambda x: x[0]): - if isinstance(output, ChatCompletion): - if self.n_completions == 1: - sorted_responses.append(output.choices[0].message.content) - else: - sorted_responses.append( - [choice.message.content for choice in output.choices] - ) - else: - sorted_responses.append(output) - - return sorted_responses - - def clean_storage(self): - del self.responses - self.responses = [] - - async def __call__(self, batch_messages: List[List[Dict[str, str]]]): - """ - Asynchronous coroutine, it should be called using the - `edsnlp.utils.asynchronous.run_async` function. - - Parameters - ---------- - batch_messages : List[List[Dict[str, str]]] - List of message batches to send to the LLM, where each batch is a list - of dictionaries with keys 'role' and 'content'. - """ - try: - # Shared prompt generator - id_messages_tuples = self.async_id_message_generator(batch_messages) - - # n concurrent tasks - tasks = { - asyncio.create_task( - self.async_worker(f"Worker-{i}", id_messages_tuples) - ) - for i in range(self.n_concurrent_tasks) - } - - await asyncio.gather(*tasks) - tasks.clear() - predictions = self.sort_responses() - self.clean_storage() - - return predictions - except Exception: - # Ensure cleanup even if an exception occurs - await self.aclose() - raise - - -def create_prompt_messages( - system_prompt: Optional[str] = None, - user_prompt: Optional[str] = None, - examples: Optional[List[Tuple[str, str]]] = None, - final_user_prompt: Optional[str] = None, -) -> List[Dict[str, str]]: - """ - Create a list of prompt messages formatted for use with a language model (LLM) API. 
- - system_prompt : Optional[str], default=None - The initial system prompt to set the behavior or context for the LLM. - user_prompt : Optional[str], default=None - An initial user prompt to provide context or instructions to the LLM. - examples : Optional[List[Tuple[str, str]]], default=None - A list of example (prompt, response) pairs to guide the LLM's behavior. - final_user_prompt : Optional[str], default=None - The final user prompt to be appended at the end of the message sequence. - - Returns - ------- - List[Dict[str, str]] - A list of message dictionaries, each containing a 'role' - (e.g., 'system', 'user', 'assistant') - and corresponding 'content', formatted for LLM input. - - """ - - messages = [] - if system_prompt: - messages.append( - { - "role": "system", - "content": system_prompt, - } - ) - if user_prompt: - messages.append( - { - "role": "user", - "content": user_prompt, - } - ) - if examples: - for prompt, response in examples: - messages.append( - { - "role": "user", - "content": prompt, - } - ) - messages.append( - { - "role": "assistant", - "content": response, - } - ) - if final_user_prompt: - messages.append( - { - "role": "user", - "content": final_user_prompt, - } - ) - - return messages - - -def parse_json_response( - response: str, - response_format: Optional[Dict[str, Any]] = None, - errors: str = "ignore", -) -> Dict[str, Any]: - """ - Parses a response string as JSON if a JSON schema format is specified, - otherwise returns the raw response. - - Parameters - ---------- - response : str - The response string to parse. - response_format : Optional[Dict[str, Any]], optional - A dictionary specifying the expected response format. - If it contains {"type": "json_schema"}, the response will be parsed as JSON. - Defaults to None. - errors : str, optional - Determines error handling behavior when JSON decoding fails. - If set to "ignore", returns an empty dictionary on failure. - Otherwise, returns the raw response. Defaults to "ignore". - - Returns - ------- - Dict[str, Any] - The parsed JSON object if parsing is successful and a JSON schema is specified. - If parsing fails and errors is "ignore", returns an empty dictionary. - If parsing fails and errors is not "ignore", returns the raw response string. - If no response format is specified, returns the raw response string. 
- """ - if response is None: - return {} - - if (response_format is not None) and (response_format.get("type") == "json_schema"): - try: - return json.loads(response.strip()) - except json.JSONDecodeError: - if errors == "ignore": - return {} - else: - return response - else: - # If no response format is specified, return the raw response - return response diff --git a/mkdocs.yml b/mkdocs.yml index e648267f23..42bf425ecd 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -138,6 +138,7 @@ nav: - LLM based: - pipes/llm/index.md - pipes/llm/llm-markup-extraction.md + - pipes/llm/llm-span-qualifier.md - tokenizers.md - Data Connectors: - data/index.md diff --git a/pyproject.toml b/pyproject.toml index 7afc1ca381..826dfdedb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -204,7 +204,8 @@ where = ["."] "eds.markup_to_doc" = "edsnlp.data.converters:MarkupToDocConverter" # LLM -"eds.llm_markup_extractor" = "edsnlp.pipes.llm.llm_markup_extractor.factory:create_component" +"eds.llm_markup_extractor" = "edsnlp.pipes.llm.llm_markup_extractor.factory:create_component" +"eds.llm_span_qualifier" = "edsnlp.pipes.llm.llm_span_qualifier.factory:create_component" # Deprecated (links to the same factories as above) "SOFA" = "edsnlp.pipes.ner.scores.sofa.factory:create_component" diff --git a/tests/conftest.py b/tests/conftest.py index 98e2726335..bf5bb0825c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -246,3 +246,13 @@ def doc2md(nlp): @pytest.fixture def md2doc(nlp): return MarkupToDocConverter(preset="md") + + +@pytest.fixture +def doc2xml(nlp): + return DocToMarkupConverter(preset="xml") + + +@pytest.fixture +def xml2doc(nlp): + return MarkupToDocConverter(preset="xml") diff --git a/tests/pipelines/llm/test_llm_span_qualifier.py b/tests/pipelines/llm/test_llm_span_qualifier.py index 578c851ac9..244a8db034 100644 --- a/tests/pipelines/llm/test_llm_span_qualifier.py +++ b/tests/pipelines/llm/test_llm_span_qualifier.py @@ -1,182 +1,529 @@ -from pytest import mark +import datetime +import json +import re + +import pytest + +# Also serves as a py37 skip since we don't install openai in py37 CI +pytest.importorskip("openai") + + +from typing import Optional + +from mock_llm_service import mock_llm_service +from pydantic import ( + BaseModel, + BeforeValidator, + PlainSerializer, + WithJsonSchema, + field_validator, +) +from typing_extensions import Annotated import edsnlp -from edsnlp.pipes.qualifiers.llm.llm_qualifier import LLMSpanClassifier -from edsnlp.utils.examples import parse_example -from edsnlp.utils.span_getters import make_span_context_getter +import edsnlp.pipes as eds + +PROMPT = """\ +Predict JSON attributes for the highlighted entity. +Return keys `negation` (bool) and `date` (YYYY-MM-DD string, optional). 
+""" + + +class QualifierSchema(BaseModel): + negation: bool + date: Optional[datetime.date] = None + + +def assert_response_schema(response_format): + assert response_format["type"] == "json_schema" + payload = response_format["json_schema"] + assert payload["name"] + schema = payload["schema"] + assert schema.get("type") == "object" + assert schema.get("additionalProperties") is False + props = schema.get("properties", {}) + assert "negation" in props + neg_type = props["negation"].get("type") + assert neg_type in {"boolean", "bool"} + assert "date" in props + + +def test_llm_span_qualifier_sets_attributes(xml2doc, doc2xml): + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=PROMPT, + output_schema=QualifierSchema, + context_formatter=doc2xml, + ) + ) + + doc = xml2doc("Le patient n'a pas de tuberculose.") + + def responder(*, messages, response_format, **_): + assert_response_schema(response_format) + user_prompt = messages[-1]["content"] + assert "tuberculose" in user_prompt + return '{"negation": true, "date": "2024-01-02"}' + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + (ent,) = doc.ents + assert ent._.negation is True + assert ent._.date is not None + + +def test_llm_span_qualifier_async_multiple_spans(xml2doc, doc2xml): + def prompt(context, examples): + assert len(examples) == 0 + messages = [] + system_content = ( + "You are a span classifier.\n" + "Answer with JSON using the keys: biopsy_procedure.\n" + "Here are some examples:\n" + ) + for ex_context, ex_json in examples: + system_content += f"- User: {ex_context}\n" + system_content += f" Assistant: {ex_json}\n" + messages.append({"role": "system", "content": system_content}) + messages.append({"role": "user", "content": context}) + return messages + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=prompt, + output_schema=QualifierSchema, + context_formatter=doc2xml, + max_concurrent_requests=2, + on_error="raise", + context_getter="words[-2:2]", + ) + ) + + doc = xml2doc( + "Le patient a une tuberculose et une pneumonie." + ) + + def responder(*, messages, response_format, **_): + assert_response_schema(response_format) + content = messages[-1]["content"] + if content == "a une tuberculose et une ": + return '{"negation": true}' + assert content == "et une pneumonie." + return '{"negation": false, "date": "2024-06-01"}' + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + values = { + ent.text: ( + ent._.negation, + ent._.date, + ) + for ent in doc.ents + } + assert values == { + "tuberculose": (True, None), + "pneumonie": (False, datetime.date(2024, 6, 1)), + } -@mark.parametrize("label", ["True", None]) -@mark.parametrize("response_mapping", [{"^True$": "1", "^False$": "0"}, None]) -def test_llm_span_classifier(label, response_mapping): - # Patch AsyncLLM to avoid real API calls - class DummyAsyncLLM: - def __init__(self, *args, **kwargs): - # Initialize the dummy LLM - pass +def test_llm_span_qualifier_multi_text(xml2doc, doc2xml): + nlp = edsnlp.blank("eds") - async def __call__(self, batch_messages): - # Return a dummy label for each message - return [label for _ in batch_messages] + example_docs = edsnlp.data.from_iterable( + [ + "Le patient est atteint de covid.", + ( + "Diagnostic de pneumonie " + "non retenu le 1er juin 2025." 
+ ), + ], + converter="markup", + preset="xml", + bool_attributes=["negation"], + ) - import edsnlp.pipes.qualifiers.llm.llm_qualifier as llm_mod + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=PROMPT, + output_schema=QualifierSchema, + context_formatter=doc2xml, + max_concurrent_requests=3, + on_error="warn", + attributes={"_.negation": ["DIAG", "PAT"], "date": ["DIAG"]}, + examples=example_docs, + use_retriever=True, + max_few_shot_examples=2, + ) + ) + + docs = [ + xml2doc("Le patient n'a pas la covid le 10/11/12."), + xml2doc("Une pneumonie a été diagnostiquée le 15 juin 2024."), + xml2doc("On suspecte une grippe."), + ] - llm_mod.AsyncLLM = DummyAsyncLLM + def responder(*, messages, response_format, **_): + assert_response_schema(response_format) + content = messages[-1]["content"] + if "covid" in content: + return '{"negation": true, "date": "2012-11-10"}' + if "pneumonie" in content: + return '{"negation": false, "date": "2024-06-15"}' + if "patient" in content: + return '{"negation": false, "date": null}' + assert "grippe" in content + return '```json\n{"negation": false}```' + with mock_llm_service(responder=responder): + processed = list(nlp.pipe(docs)) + + results = [ + [(ent.text, ent._.negation, ent._.date) for ent in doc.ents] + for doc in processed + ] + assert results == [ + [("patient", False, None), ("covid", True, datetime.date(2012, 11, 10))], + [("pneumonie", False, datetime.date(2024, 6, 15))], + [("grippe", False, None)], + ] + + +def test_llm_span_qualifier_async_error(xml2doc, doc2xml): nlp = edsnlp.blank("eds") - example = "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." # noqa: E501 - text, entities = parse_example(example) - doc = nlp(text) - doc.ents = [ - doc.char_span(ent.start_char, ent.end_char, label="date") for ent in entities + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt=PROMPT, + output_schema=QualifierSchema, + context_formatter=doc2xml, + max_concurrent_requests=2, + on_error="warn", + ) + ) + + doc = xml2doc( + "Le patient a une tuberculose et une pneumonie." + ) + + def responder(*, response_format, **_): + assert_response_schema(response_format) + raise ValueError("Simulated error") + + with mock_llm_service(responder=responder), pytest.warns( + UserWarning, match="request failed" + ): + doc = nlp(doc) + + for ent in doc.ents: + assert ent._.negation is None + assert ent._.date is None + + +def test_yes_no_schema(xml2doc, doc2xml): + YesBool = Annotated[ + bool, + WithJsonSchema({"type": "string", "enum": ["yes", "no"]}), + BeforeValidator(lambda v: v.lower() in {"yes", "y", "true", "1"}), + PlainSerializer(lambda v: "yes" if v else "no", when_used="json"), ] - # LLMSpanClassifier + nlp = edsnlp.blank("eds") nlp.add_pipe( - LLMSpanClassifier( - nlp=nlp, - name="llm", - model="dummy", - span_getter={"ents": True}, - attributes={"_.test_attr": True}, - context_getter=make_span_context_getter( - context_sents=0, - context_words=(5, 5), + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt="""\ +Determine if the highlighted entity is negated. 
+Answer "yes" or "no".""", + output_schema=YesBool, + attributes="negation", + context_formatter=doc2xml, + on_error="warn", + examples=edsnlp.data.from_iterable( + [ + "Le patient a eu une pneumonie le 1er mai 2024.", + "Le patient n'a pas le covid.", + ], + converter="markup", + preset="xml", + bool_attributes=["negation"], ), - prompt={ - "system_prompt": "You are a medical assistant.", - "user_prompt": "You should help us identify dates in the text.", - "prefix_prompt": "Is '{span}' a date? The text is as follows:\n<<< ", - "suffix_prompt": " >>>", - "examples": [ - ( - "\nIs'07/12/2020' a date. The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales. >>>", # noqa: E501 - "False", - ) + ) + ) + + doc = xml2doc("Le patient n'a pas de tuberculose.") + + def responder(*, messages, **_): + user_prompt = messages[-1]["content"] + # Ideally, LLM service should support scalar schemas + # but this isn't the case yet. + + # assert response_format["type"] == "json_schema" + # payload = response_format["json_schema"] + # assert payload["name"] + # schema = payload["schema"] + # assert schema.get("type") == "string" + # assert schema.get("enum") == ["yes", "no"] + + assert "tuberculose" in user_prompt + return "yes" + + with mock_llm_service(responder=responder): + doc = nlp(doc) + + (ent,) = doc.ents + assert ent._.negation is True + assert ent._.date is None + + +def test_empty_schema(xml2doc, doc2xml): + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="my-custom-model", + prompt="""\ +For the highlighted entity, determine the type of the entity as a single phrase. +""", + attributes="type", + context_formatter=doc2xml, + examples=edsnlp.data.from_iterable( + [ + "On prescrit du paracétamol.", + "Le patient n'a pas le covid.", ], - }, - api_url="https://dummy", - api_params={ + converter="markup", + preset="xml", + ), + ), + ) + doc = xml2doc("Le patient n'a pas de tuberculose.") + + def responder(*, messages, **_): + user_prompt = messages[-1]["content"] + assert "tuberculose" in user_prompt + return "disease" + + with mock_llm_service(responder=responder): + doc = nlp(doc) + (ent,) = doc.ents + assert ent._.type == "disease" + + +def make_prompt_builder(system_prompt): + def prompt(context, retrieved_examples): + messages = [ + {"role": "system", "content": system_prompt}, + ] + for ctx, answer in retrieved_examples: + messages.append({"role": "user", "content": ctx}) + messages.append({"role": "assistant", "content": answer}) + messages.append({"role": "user", "content": context}) + return messages + + return prompt + + +def make_context_formatter(prefix_prompt, suffix_prompt): + def formatter(doc): + (span,) = doc.ents + context_text = doc.text.strip() + if prefix_prompt is None or suffix_prompt is None: + return context_text + prefix = prefix_prompt.format(span=span.text) + suffix = suffix_prompt.format(span=span.text) + return f"{prefix}{context_text}{suffix}" + + return formatter + + +@pytest.mark.parametrize("label", ["True", None]) +@pytest.mark.parametrize("response_mapping", [{"^True$": "1", "^False$": "0"}, None]) +def test_llm_span_qualifier_custom_formatter_sets_attributes( + xml2doc, label, response_mapping +): + system_prompt = ( + "You are a medical assistant, build to help identify dates in the text." + ) + prefix_prompt = "Is '{span}' a date? The text is as follows:\n<<< " + suffix_prompt = " >>>" + example_doc = xml2doc( + "07/12/2020 : Anapath / biopsies rectales." 
# noqa: E501 + ) + + class ResponseSchema(BaseModel): + test_attr: Optional[str] = None + + @field_validator("test_attr", mode="before") + def apply_mapping(cls, value): + if value is None: + return None + value_str = str(value) + if response_mapping is None: + return value_str + for pattern, mapped in response_mapping.items(): + if re.match(pattern, value_str): + return mapped + return value_str + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.llm_span_qualifier( + api_url="http://localhost:8080/v1", + model="dummy", + name="llm", + prompt=make_prompt_builder(system_prompt), + output_schema=ResponseSchema, + context_getter="words[-5:5]", + context_formatter=make_context_formatter(prefix_prompt, suffix_prompt), + max_concurrent_requests=1, + max_few_shot_examples=1, + examples=[example_doc], + api_kwargs={ "max_tokens": 10, "temperature": 0.0, "response_format": None, "extra_body": None, }, - response_mapping=response_mapping, - n_concurrent_tasks=1, ) ) - doc = nlp(doc) - # Check that the extension is set and the dummy label is applied + doc = xml2doc( + "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025." # noqa: E501 + ) + qualifier = nlp.get_pipe("llm") + retrieved_examples = qualifier.examples[:1] + assert len(retrieved_examples) == 1 + example_context, example_answer = retrieved_examples[0] + assert ( + example_context + == "Is '07/12/2020' a date? The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales >>>" # noqa: E501 + ) + context_doc = qualifier._build_context_doc(doc.ents[0]) + expected_context = qualifier.context_formatter(context_doc) + expected_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": example_context}, + {"role": "assistant", "content": example_answer}, + {"role": "user", "content": expected_context}, + ] + assert qualifier.build_prompt(expected_context) == expected_messages + + def responder(*, messages, response_format=None, **_): + if response_format is not None: + assert_response_schema(response_format) + payload = {"test_attr": label} + return json.dumps(payload) + + with mock_llm_service(responder=responder): + doc = nlp(doc) + for span in doc.ents: assert hasattr(span._, "test_attr") - if response_mapping is not None: - if label == "True": + if label == "True": + if response_mapping is None: + assert span._.test_attr == "True" + else: assert span._.test_attr == "1" - elif label is None: - assert span._.test_attr is None + else: + assert span._.test_attr is None - if response_mapping is None: - if label == "True": - assert span._.test_attr == label - elif label is None: - assert span._.test_attr is None + assert qualifier.attributes == {"_.test_attr": True} - assert nlp.get_pipe("llm").attributes == {"_.test_attr": True} - -@mark.parametrize( - "prefix_prompt,suffix_prompt", - [("Is '{span}' a date? The text is as follows:\n<<< ", " >>>"), (None, None)], +@pytest.mark.parametrize( + ("prefix_prompt", "suffix_prompt"), + [ + ("Is '{span}' a date? The text is as follows:\n<<< ", " >>>"), + (None, None), + ], ) -def test_llm_span_classifier_preprocess(prefix_prompt, suffix_prompt): - # Patch AsyncLLM to avoid real API calls - class DummyAsyncLLM: - def __init__(self, *args, **kwargs): - # Initialize the dummy LLM - pass +def test_llm_span_qualifier_custom_formatter_prompt( + xml2doc, prefix_prompt, suffix_prompt +): + system_prompt = ( + "You are a medical assistant, build to help identify dates in the text." 
+    )
+    example_doc = xml2doc(
+        "07/12/2020 : Anapath / biopsies rectales."  # noqa: E501
+    )
 
-        async def __call__(self, batch_messages):
-            # Return a dummy label for each message
-            return ["True" for _ in batch_messages]
+    class ResponseSchema(BaseModel):
+        test_attr: Optional[str] = None
 
-    import edsnlp.pipes.qualifiers.llm.llm_qualifier as llm_mod
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe(
+        eds.llm_span_qualifier(
+            api_url="http://localhost:8080/v1",
+            model="dummy",
+            name="llm",
+            prompt=make_prompt_builder(system_prompt),
+            output_schema=ResponseSchema,
+            context_getter="words[-5:5]",
+            context_formatter=make_context_formatter(prefix_prompt, suffix_prompt),
+            max_concurrent_requests=1,
+            max_few_shot_examples=1,
+            examples=[example_doc],
+            api_kwargs={
+                "max_tokens": 10,
+                "temperature": 0.0,
+                "response_format": None,
+                "extra_body": None,
+            },
+        )
+    )
 
-    llm_mod.AsyncLLM = DummyAsyncLLM
+    doc = xml2doc(
+        "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025."  # noqa: E501
+    )
 
-    nlp = edsnlp.blank("eds")
-    example = "En RCP du 20/02/2025, patient classé cT3a N0 M0, haut risque. IRM multiparamétrique du 10/02/2025."  # noqa: E501
-    text, entities = parse_example(example)
-    doc = nlp(text)
-    doc.ents = [
-        doc.char_span(ent.start_char, ent.end_char, label="date") for ent in entities
+    qualifier = nlp.get_pipe("llm")
+    retrieved_examples = qualifier.examples[:1]
+    assert len(retrieved_examples) == 1
+    example_context, example_answer = retrieved_examples[0]
+    expected_example_context = (
+        "Is '07/12/2020' a date? The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales >>>"  # noqa: E501
+        if prefix_prompt is not None
+        else "07/12/2020 : Anapath / biopsies rectales"
+    )
+    assert example_context == expected_example_context
+    span = doc.ents[0]
+    context_doc = qualifier._build_context_doc(span)
+    context_text = qualifier.context_formatter(context_doc)
+    expected_messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": example_context},
+        {"role": "assistant", "content": example_answer},
+        {"role": "user", "content": context_text},
     ]
-    system_prompt = "You are a medical assistant."
-    user_prompt = "You should help us identify dates in the text."
-    examples = [
-        (
-            "\nIs'07/12/2020' a date. The text is as follows:\n<<< 07/12/2020 : Anapath / biopsies rectales. >>>",  # noqa: E501
-            "False",
-        )
-    ]
+    expected_context = (
+        "Is '20/02/2025' a date? The text is as follows:\n<<< En RCP du 20/02/2025, patient classé cT3 >>>"  # noqa: E501
+        if prefix_prompt is not None
+        else "En RCP du 20/02/2025, patient classé cT3"
+    )
+    assert context_text == expected_context
+    assert qualifier.build_prompt(context_text) == expected_messages
 
-    # LLMSpanClassifier
-    llm = LLMSpanClassifier(
-        nlp=nlp,
-        name="llm",
-        model="dummy",
-        span_getter={"ents": True},
-        attributes={"_.test_attr": True},
-        context_getter=make_span_context_getter(
-            context_sents=0,
-            context_words=(5, 5),
-        ),
-        prompt={
-            "system_prompt": system_prompt,
-            "user_prompt": user_prompt,
-            "prefix_prompt": prefix_prompt,
-            "suffix_prompt": suffix_prompt,
-            "examples": examples,
-        },
-        api_url="https://dummy",
-        api_params={
-            "max_tokens": 10,
-            "temperature": 0.0,
-            "response_format": None,
-            "extra_body": None,
-        },
-        response_mapping=None,
-        n_concurrent_tasks=1,
-    )
-
-    inputs = llm.preprocess(doc)
-    if (prefix_prompt is not None) and (suffix_prompt is not None):
-        assert inputs["doc_batch_messages"][0] == [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_prompt},
-            {
-                "role": "user",
-                "content": examples[0][0],
-            },
-            {"role": "assistant", "content": examples[0][1]},
-            {
-                "role": "user",
-                "content": "Is '20/02/2025' a date? The text is as follows:\n<<< En RCP du 20/02/2025, patient classé cT3 >>>",  # noqa: E501
-            },
-        ]
-    else:
-        assert inputs["doc_batch_messages"][0] == [
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": user_prompt},
-            {
-                "role": "user",
-                "content": examples[0][0],
-            },
-            {"role": "assistant", "content": examples[0][1]},
-            {
-                "role": "user",
-                "content": "En RCP du 20/02/2025, patient classé cT3",  # noqa: E501
-            },
-        ]
+    def responder(*, messages, response_format=None, **_):
+        if response_format is not None:
+            assert_response_schema(response_format)
+        return json.dumps({"test_attr": "True"})
+
+    with mock_llm_service(responder=responder):
+        doc = nlp(doc)
+
+    span = doc.ents[0]
+    assert span._.test_attr == "True"
+    assert qualifier.attributes == {"_.test_attr": True}
diff --git a/tests/pipelines/llm/test_llm_utils.py b/tests/pipelines/llm/test_llm_utils.py
deleted file mode 100644
index b9d936e587..0000000000
--- a/tests/pipelines/llm/test_llm_utils.py
+++ /dev/null
@@ -1,213 +0,0 @@
-from typing import Optional
-
-import httpx
-import respx
-from openai.types.chat.chat_completion import ChatCompletion
-from pytest import mark
-
-from edsnlp.pipes.qualifiers.llm.llm_utils import (
-    AsyncLLM,
-    create_prompt_messages,
-    parse_json_response,
-)
-from edsnlp.utils.asynchronous import run_async
-
-
-@mark.parametrize("n_concurrent_tasks", [1, 2])
-def test_async_llm(n_concurrent_tasks):
-    api_url = "http://localhost:8000/v1/"
-    suffix_url = "chat/completions"
-    llm_api = AsyncLLM(n_concurrent_tasks=n_concurrent_tasks, api_url=api_url)
-
-    with respx.mock:
-        respx.post(api_url + suffix_url).mock(
-            side_effect=[
-                httpx.Response(
-                    200, json={"choices": [{"message": {"content": "positive"}}]}
-                ),
-                httpx.Response(
-                    200, json={"choices": [{"message": {"content": "negative"}}]}
-                ),
-            ]
-        )
-
-        response = run_async(
-            llm_api(
-                batch_messages=[
-                    [
-                        {"role": "user", "content": "your prompt here"},
-                        {"role": "assistant", "content": "Hello!"},
-                        {"role": "user", "content": "your second prompt here"},
-                    ],
-                    [{"role": "user", "content": "your second prompt here"}],
-                ]
-            )
-        )
-        assert response == ["positive", "negative"]
-
-
-def test_create_prompt_messages():
-    messages = create_prompt_messages(
-        system_prompt="Hello",
-        user_prompt="Hi",
-        examples=[("One", "1")],
-        final_user_prompt="What is your name?",
-    )
-    messages_expected = [
-        {"role": "system", "content": "Hello"},
-        {"role": "user", "content": "Hi"},
-        {"role": "user", "content": "One"},
-        {"role": "assistant", "content": "1"},
-        {"role": "user", "content": "What is your name?"},
-    ]
-    assert messages == messages_expected
-    messages2 = create_prompt_messages(
-        system_prompt="Hello",
-        user_prompt=None,
-        examples=[("One", "1")],
-        final_user_prompt="What is your name?",
-    )
-    messages_expected2 = [
-        {"role": "system", "content": "Hello"},
-        {"role": "user", "content": "One"},
-        {"role": "assistant", "content": "1"},
-        {"role": "user", "content": "What is your name?"},
-    ]
-    assert messages2 == messages_expected2
-
-
-def create_fake_chat_completion(
-    choices: int = 1, content: Optional[str] = '{"biopsy":false}'
-):
-    fake_response_data = {
-        "id": "chatcmpl-fake123",
-        "object": "chat.completion",
-        "created": 1699999999,
-        "model": "toto",
-        "choices": [
-            {
-                "index": i,
-                "message": {"role": "assistant", "content": content},
-                "finish_reason": "stop",
-            }
-            for i in range(choices)
-        ],
-        "usage": {"prompt_tokens": 10, "completion_tokens": 15, "total_tokens": 25},
-    }
-
-    # Create the ChatCompletion object
-    fake_completion = ChatCompletion.model_validate(fake_response_data)
-    return fake_completion
-
-
-def test_parse_json_response():
-    response = create_fake_chat_completion()
-    response_format = {
-        "type": "json_schema",
-        "json_schema": {
-            "name": "DateModel",
-            "schema": {
-                "properties": {"biopsy": {"title": "Biopsy", "type": "boolean"}},
-                "required": ["biopsy"],
-                "title": "DateModel",
-                "type": "object",
-            },
-        },
-    }
-
-    llm = AsyncLLM(n_concurrent_tasks=1)
-    parsed_response = llm.parse_messages(response, response_format)
-    assert parsed_response == {"biopsy": False}
-
-    response = create_fake_chat_completion(content=None)
-    parsed_response = llm.parse_messages(response, response_format=response_format)
-    assert parsed_response == {}
-
-    parsed_response = llm.parse_messages(response, response_format=None)
-    assert parsed_response == response
-
-    parsed_response = llm.parse_messages(None, response_format=None)
-    assert parsed_response is None
-
-
-@mark.parametrize("n_completions", [1, 2])
-def test_exception_handling(n_completions):
-    api_url = "http://localhost:8000/v1/"
-    suffix_url = "chat/completions"
-    llm_api = AsyncLLM(
-        n_concurrent_tasks=1, api_url=api_url, n_completions=n_completions
-    )
-
-    with respx.mock:
-        respx.post(api_url + suffix_url).mock(
-            side_effect=[
-                httpx.Response(404, json={"choices": [{}]}),
-            ]
-        )
-
-        response = run_async(
-            llm_api(
-                batch_messages=[
-                    [{"role": "user", "content": "your prompt here"}],
-                ]
-            )
-        )
-        if n_completions == 1:
-            assert response == [""]
-        else:
-            assert response == [[""] * n_completions]
-
-
-@mark.parametrize("errors", ["ignore", "raw"])
-def test_json_decode_error(errors):
-    raw_response = '{"biopsy";false}'
-    response_format = {
-        "type": "json_schema",
-        "json_schema": {
-            "name": "DateModel",
-            "schema": {
-                "properties": {"biopsy": {"title": "Biopsy", "type": "boolean"}},
-                "required": ["biopsy"],
-                "title": "DateModel",
-                "type": "object",
-            },
-        },
-    }
-
-    response = parse_json_response(raw_response, response_format, errors=errors)
-    if errors == "ignore":
-        assert response == {}
-    else:
-        assert response == raw_response
-
-
-def test_decode_no_format():
-    raw_response = '{"biopsy":false}'
-
-    response = parse_json_response(raw_response, response_format=None)
-
-    assert response == raw_response
-
-
-def test_multiple_completions(n_completions=2):
-    api_url = "http://localhost:8000/v1/"
-    suffix_url = "chat/completions"
-    llm_api = AsyncLLM(
-        n_concurrent_tasks=1, api_url=api_url, n_completions=n_completions
-    )
-    completion = create_fake_chat_completion(n_completions, content="false")
-    with respx.mock:
-        respx.post(api_url + suffix_url).mock(
-            side_effect=[
-                httpx.Response(200, json=completion.model_dump()),
-            ]
-        )
-
-        response = run_async(
-            llm_api(
-                batch_messages=[
-                    [{"role": "user", "content": "your prompt here"}],
-                ]
-            )
-        )
-        assert response == [["false", "false"]]