|
14 | 14 | # limitations under the License.
|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
| 17 | +from collections import defaultdict |
17 | 18 | from typing import (
|
18 | 19 | Any,
|
19 | 20 | ClassVar,
|
@@ -336,37 +337,40 @@ def _get_connections(self) -> list[ConnectionDefinition]:
|
336 | 337 | return connections
|
337 | 338 |
|
338 | 339 | def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
|
339 |
| - run_params = {} |
340 |
| - if self.lexical_graph_config: |
341 |
| - run_params["extractor"] = { |
342 |
| - "lexical_graph_config": self.lexical_graph_config, |
343 |
| - } |
344 |
| - run_params["writer"] = { |
345 |
| - "lexical_graph_config": self.lexical_graph_config, |
346 |
| - } |
347 |
| - run_params["pruner"] = { |
348 |
| - "lexical_graph_config": self.lexical_graph_config, |
349 |
| - } |
350 | 340 | text = user_input.get("text")
|
351 | 341 | file_path = user_input.get("file_path")
|
352 |
| - if not ((text is None) ^ (file_path is None)): |
353 |
| - # exactly one of text or user_input must be set |
| 342 | + if text is None and file_path is None: |
| 343 | + # user must provide either text or file_path or both |
354 | 344 | raise PipelineDefinitionError(
|
355 |
| - "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument." |
| 345 | + "At least one of `text` (when from_pdf=False) or `file_path` (when from_pdf=True) argument must be provided." |
356 | 346 | )
|
| 347 | + run_params: dict[str, dict[str, Any]] = defaultdict(dict) |
| 348 | + if self.lexical_graph_config: |
| 349 | + run_params["extractor"]["lexical_graph_config"] = self.lexical_graph_config |
| 350 | + run_params["writer"]["lexical_graph_config"] = self.lexical_graph_config |
| 351 | + run_params["pruner"]["lexical_graph_config"] = self.lexical_graph_config |
357 | 352 | if self.from_pdf:
|
358 | 353 | if not file_path:
|
359 | 354 | raise PipelineDefinitionError(
|
360 | 355 | "Expected 'file_path' argument when 'from_pdf' is True."
|
361 | 356 | )
|
362 |
| - run_params["pdf_loader"] = {"filepath": file_path} |
| 357 | + run_params["pdf_loader"]["filepath"] = file_path |
| 358 | + run_params["pdf_loader"]["metadata"] = user_input.get("document_metadata") |
363 | 359 | else:
|
364 | 360 | if not text:
|
365 | 361 | raise PipelineDefinitionError(
|
366 | 362 | "Expected 'text' argument when 'from_pdf' is False."
|
367 | 363 | )
|
368 |
| - run_params["splitter"] = {"text": text} |
| 364 | + run_params["splitter"]["text"] = text |
369 | 365 | # Add full text to schema component for automatic schema extraction
|
370 | 366 | if not self.has_user_provided_schema():
|
371 |
| - run_params["schema"] = {"text": text} |
| 367 | + run_params["schema"]["text"] = text |
| 368 | + run_params["extractor"]["document_info"] = dict( |
| 369 | + path=user_input.get( |
| 370 | + "file_path", |
| 371 | + ) |
| 372 | + or "document.txt", |
| 373 | + metadata=user_input.get("document_metadata"), |
| 374 | + document_type="inline_text", |
| 375 | + ) |
372 | 376 | return run_params
|
0 commit comments