diff --git a/src/praisonai-agents/praisonaiagents/tools/__init__.py b/src/praisonai-agents/praisonaiagents/tools/__init__.py index 95f1f1bcf..e35a6c179 100644 --- a/src/praisonai-agents/praisonaiagents/tools/__init__.py +++ b/src/praisonai-agents/praisonaiagents/tools/__init__.py @@ -2,6 +2,15 @@ from importlib import import_module from typing import Any +# Import advanced tools functionality +from .advanced_tools import ( + tool, cache, external, user_input, + Field, InputGroup, Choice, Range, Pattern, + ToolContext, Hook, CacheConfig, ExternalConfig, Priority, + set_global_hooks, clear_global_hooks, register_external_handler, + invalidate_cache, clear_all_caches, get_cache_stats +) + # Map of function names to their module and class (if any) TOOL_MAPPINGS = { # Direct functions @@ -199,4 +208,11 @@ def __getattr__(name: str) -> Any: method = getattr(_instances[class_name], name) return method -__all__ = list(TOOL_MAPPINGS.keys()) \ No newline at end of file +__all__ = list(TOOL_MAPPINGS.keys()) + [ + # Advanced tools functionality + 'tool', 'cache', 'external', 'user_input', + 'Field', 'InputGroup', 'Choice', 'Range', 'Pattern', + 'ToolContext', 'Hook', 'CacheConfig', 'ExternalConfig', 'Priority', + 'set_global_hooks', 'clear_global_hooks', 'register_external_handler', + 'invalidate_cache', 'clear_all_caches', 'get_cache_stats' +] \ No newline at end of file diff --git a/src/praisonai-agents/praisonaiagents/tools/advanced_tools.py b/src/praisonai-agents/praisonaiagents/tools/advanced_tools.py new file mode 100644 index 000000000..98c3a9af6 --- /dev/null +++ b/src/praisonai-agents/praisonaiagents/tools/advanced_tools.py @@ -0,0 +1,492 @@ +"""Advanced tools framework for PraisonAI Agents. + +This module provides advanced tool decorators and functionality including: +- Pre/Post execution hooks +- Tool-level caching with TTL +- External execution markers +- Structured user input fields + +Maintains backward compatibility with existing tools. +""" + +import asyncio +import functools +import inspect +import time +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from dataclasses import dataclass, field +from enum import Enum + + +class Priority(Enum): + """Hook execution priority levels.""" + HIGHEST = 1 + HIGH = 2 + MEDIUM = 3 + LOW = 4 + LOWEST = 5 + + +@dataclass +class ToolContext: + """Context object passed to hooks and handlers.""" + tool_name: str + args: tuple = field(default_factory=tuple) + kwargs: dict = field(default_factory=dict) + result: Any = None + error: Optional[Exception] = None + agent: Any = None # Will be set by agent at runtime + execution_time: float = 0.0 + metadata: dict = field(default_factory=dict) + start_time: float = field(default_factory=time.time) + + def set_result(self, result: Any): + """Set the tool result and calculate execution time.""" + self.result = result + self.execution_time = time.time() - self.start_time + + def set_error(self, error: Exception): + """Set the tool error and calculate execution time.""" + self.error = error + self.execution_time = time.time() - self.start_time + + +@dataclass +class Hook: + """Represents a hook function with priority.""" + func: Callable[[ToolContext], Any] + priority: Priority = Priority.MEDIUM + + def __call__(self, context: ToolContext) -> Any: + return self.func(context) + + +@dataclass +class CacheConfig: + """Cache configuration for tools.""" + enabled: bool = True + ttl: int = 300 # seconds + backend: str = 'memory' # 'memory', 'redis', etc. + key_func: Optional[Callable] = None + condition: Optional[Callable] = None + tags: List[str] = field(default_factory=list) + + +@dataclass +class ExternalConfig: + """External execution configuration.""" + enabled: bool = True + executor: str = 'default' + requirements: List[str] = field(default_factory=list) + estimated_time: int = 60 # seconds + when: Optional[Callable] = None # Conditional external execution + type: str = 'generic' # 'human_approval', 'webhook', 'generic' + endpoint: Optional[str] = None + auth_token: Optional[str] = None + + +@dataclass +class Field: + """Structured input field definition.""" + name: str + type: Any = str + description: str = "" + required: bool = True + default: Any = None + secret: bool = False + + def __post_init__(self): + if self.default is not None: + self.required = False + + +@dataclass +class InputGroup: + """Group of related input fields.""" + name: str + fields: List[Field] + + def __init__(self, name: str, *fields: Field): + self.name = name + self.fields = list(fields) + + +class Choice: + """Choice field type.""" + def __init__(self, choices: List[str]): + self.choices = choices + + +class Range: + """Range field type.""" + def __init__(self, min_val: float, max_val: float): + self.min = min_val + self.max = max_val + + +class Pattern: + """Pattern validation field type.""" + def __init__(self, pattern: str): + self.pattern = pattern + + +# Global hooks registry +_global_hooks = { + 'before': [], + 'after': [] +} + +# Cache storage +_cache_storage = {} +_cache_metadata = {} # For TTL tracking + +# External handlers registry +_external_handlers = {} + + +def set_global_hooks(before: Optional[Callable] = None, after: Optional[Callable] = None): + """Set global hooks that apply to all tools.""" + if before: + _global_hooks['before'].append(Hook(before)) + if after: + _global_hooks['after'].append(Hook(after)) + + +def clear_global_hooks(): + """Clear all global hooks.""" + _global_hooks['before'].clear() + _global_hooks['after'].clear() + + +def register_external_handler(name: str, handler: Callable): + """Register an external execution handler.""" + _external_handlers[name] = handler + + +def invalidate_cache(tags: Optional[List[str]] = None, tool_name: Optional[str] = None): + """Invalidate cache entries by tags or tool name.""" + if tool_name: + # Remove all cache entries for a specific tool + keys_to_remove = [k for k in _cache_storage.keys() if k.startswith(f"{tool_name}:")] + for key in keys_to_remove: + del _cache_storage[key] + if key in _cache_metadata: + del _cache_metadata[key] + + if tags: + # Remove cache entries with specific tags + keys_to_remove = [] + for key, metadata in _cache_metadata.items(): + if any(tag in metadata.get('tags', []) for tag in tags): + keys_to_remove.append(key) + + for key in keys_to_remove: + if key in _cache_storage: + del _cache_storage[key] + del _cache_metadata[key] + + +def clear_all_caches(): + """Clear all cache entries.""" + _cache_storage.clear() + _cache_metadata.clear() + + +def get_cache_stats(): + """Get cache statistics.""" + return { + 'total_entries': len(_cache_storage), + 'hits': sum(m.get('hits', 0) for m in _cache_metadata.values()), + 'misses': sum(m.get('misses', 0) for m in _cache_metadata.values()) + } + + +def _generate_cache_key(tool_name: str, args: tuple, kwargs: dict, key_func: Optional[Callable] = None) -> str: + """Generate cache key for tool execution.""" + if key_func: + return f"{tool_name}:{key_func(*args, **kwargs)}" + else: + # Simple hash-based key + import hashlib + content = f"{args}:{sorted(kwargs.items())}" + hash_key = hashlib.md5(content.encode()).hexdigest()[:16] + return f"{tool_name}:{hash_key}" + + +def _is_cache_valid(key: str) -> bool: + """Check if cache entry is still valid (not expired).""" + if key not in _cache_metadata: + return False + + metadata = _cache_metadata[key] + if 'expires_at' in metadata: + return time.time() < metadata['expires_at'] + return True + + +def _execute_hooks(hooks: List[Hook], context: ToolContext) -> bool: + """Execute hooks in priority order. Returns False if execution should be stopped.""" + # Sort hooks by priority + sorted_hooks = sorted(hooks, key=lambda h: h.priority.value) + + for hook in sorted_hooks: + try: + result = hook(context) + # If hook returns False, stop execution + if result is False: + return False + except Exception as e: + # Log hook error but continue execution + print(f"Hook error in {hook.func.__name__}: {e}") + + return True + + +def _handle_external_execution(func: Callable, context: ToolContext, external_config: ExternalConfig) -> Any: + """Handle external execution of a tool.""" + # Check if external execution is conditional + if external_config.when and not external_config.when(*context.args, **context.kwargs): + # Execute normally + return func(*context.args, **context.kwargs) + + # Look for registered external handler + handler = _external_handlers.get(external_config.executor) + if handler: + return handler(func, context, external_config) + + # Default external handling + if external_config.type == 'human_approval': + response = input(f"External approval required for {context.tool_name}. Proceed? (y/n): ") + if response.lower() != 'y': + raise Exception("External execution denied by user") + + # Execute the tool normally for now + return func(*context.args, **context.kwargs) + + +def tool( + name: Optional[str] = None, + description: Optional[str] = None, + before: Union[Callable, List[Union[Callable, Tuple[Callable, Priority]]], None] = None, + after: Union[Callable, List[Union[Callable, Tuple[Callable, Priority]]], None] = None, + cache: Union[bool, CacheConfig, None] = None, + external: Union[bool, ExternalConfig, None] = None, + inputs: Optional[List[Union[Field, InputGroup]]] = None, + require_approval: Optional[str] = None, + risk_level: Optional[str] = None +): + """ + Advanced tool decorator with hooks, caching, external execution, and input validation. + + Args: + name: Tool name (defaults to function name) + description: Tool description (defaults to function docstring) + before: Pre-execution hooks + after: Post-execution hooks + cache: Caching configuration + external: External execution configuration + inputs: Structured input field definitions + require_approval: Backward compatibility with existing approval system + risk_level: Risk level for approval system + """ + def decorator(func: Callable) -> Callable: + tool_name = name or func.__name__ + tool_description = description or func.__doc__ or "No description available" + + # Normalize hooks + before_hooks = [] + after_hooks = [] + + if before: + if callable(before): + before_hooks.append(Hook(before)) + elif isinstance(before, list): + for hook in before: + if callable(hook): + before_hooks.append(Hook(hook)) + elif isinstance(hook, tuple) and len(hook) == 2: + before_hooks.append(Hook(hook[0], hook[1])) + + if after: + if callable(after): + after_hooks.append(Hook(after)) + elif isinstance(after, list): + for hook in after: + if callable(hook): + after_hooks.append(Hook(hook)) + elif isinstance(hook, tuple) and len(hook) == 2: + after_hooks.append(Hook(hook[0], hook[1])) + + # Normalize cache config + cache_config = None + if cache is True: + cache_config = CacheConfig() + elif isinstance(cache, CacheConfig): + cache_config = cache + elif isinstance(cache, dict): + cache_config = CacheConfig(**cache) + + # Normalize external config + external_config = None + if external is True: + external_config = ExternalConfig() + elif isinstance(external, ExternalConfig): + external_config = external + elif isinstance(external, dict): + external_config = ExternalConfig(**external) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + context = ToolContext( + tool_name=tool_name, + args=args, + kwargs=kwargs + ) + + try: + # Execute global before hooks + if not _execute_hooks(_global_hooks['before'], context): + return context.result + + # Execute tool-specific before hooks + if not _execute_hooks(before_hooks, context): + return context.result + + # Check cache + if cache_config and cache_config.enabled: + cache_key = _generate_cache_key(tool_name, args, kwargs, cache_config.key_func) + + if cache_key in _cache_storage and _is_cache_valid(cache_key): + # Cache hit + result = _cache_storage[cache_key] + _cache_metadata[cache_key]['hits'] = _cache_metadata[cache_key].get('hits', 0) + 1 + context.set_result(result) + + # Execute after hooks with cached result + _execute_hooks(after_hooks, context) + _execute_hooks(_global_hooks['after'], context) + + return result + else: + # Cache miss + if cache_key in _cache_metadata: + _cache_metadata[cache_key]['misses'] = _cache_metadata[cache_key].get('misses', 0) + 1 + else: + _cache_metadata[cache_key] = {'misses': 1, 'hits': 0} + + # Execute tool + if external_config and external_config.enabled: + result = _handle_external_execution(func, context, external_config) + else: + result = func(*args, **kwargs) + + context.set_result(result) + + # Store in cache if configured + if cache_config and cache_config.enabled: + # Check condition if specified + if not cache_config.condition or cache_config.condition(result): + cache_key = _generate_cache_key(tool_name, args, kwargs, cache_config.key_func) + _cache_storage[cache_key] = result + + # Set expiration + if cache_key not in _cache_metadata: + _cache_metadata[cache_key] = {'hits': 0, 'misses': 0} + + _cache_metadata[cache_key].update({ + 'expires_at': time.time() + cache_config.ttl, + 'tags': cache_config.tags + }) + + # Execute after hooks + _execute_hooks(after_hooks, context) + _execute_hooks(_global_hooks['after'], context) + + return result + + except Exception as e: + context.set_error(e) + + # Execute error handling hooks + _execute_hooks(after_hooks, context) + _execute_hooks(_global_hooks['after'], context) + + # If error was cleared by hooks, return the result + if context.error is None: + return context.result + + # Re-raise the error + raise e + + # Add metadata to the function + wrapper._tool_metadata = { + 'name': tool_name, + 'description': tool_description, + 'cache_config': cache_config, + 'external_config': external_config, + 'inputs': inputs, + 'before_hooks': before_hooks, + 'after_hooks': after_hooks, + 'require_approval': require_approval, + 'risk_level': risk_level + } + + # Backward compatibility with existing approval system + if require_approval or risk_level: + try: + from .approval import require_approval as approval_decorator + if risk_level: + wrapper = approval_decorator(risk_level)(wrapper) + elif require_approval: + wrapper = approval_decorator(require_approval)(wrapper) + except ImportError: + # Approval system not available + pass + + return wrapper + + return decorator + + +# Convenience decorators for common patterns +def cache(ttl: int = 300, backend: str = 'memory', key: Optional[Callable] = None, + condition: Optional[Callable] = None, tags: Optional[List[str]] = None): + """Convenience decorator for caching.""" + config = CacheConfig( + ttl=ttl, + backend=backend, + key_func=key, + condition=condition, + tags=tags or [] + ) + return lambda func: tool(cache=config)(func) + + +def external(executor: str = 'default', requirements: Optional[List[str]] = None, + estimated_time: int = 60, when: Optional[Callable] = None, + type: str = 'generic', endpoint: Optional[str] = None): + """Convenience decorator for external execution.""" + config = ExternalConfig( + executor=executor, + requirements=requirements or [], + estimated_time=estimated_time, + when=when, + type=type, + endpoint=endpoint + ) + return lambda func: tool(external=config)(func) + + +def user_input(*fields: Union[Field, InputGroup]): + """Convenience decorator for user input validation.""" + return lambda func: tool(inputs=list(fields))(func) + + +# Export all public classes and functions +__all__ = [ + 'tool', 'cache', 'external', 'user_input', + 'Field', 'InputGroup', 'Choice', 'Range', 'Pattern', + 'ToolContext', 'Hook', 'CacheConfig', 'ExternalConfig', 'Priority', + 'set_global_hooks', 'clear_global_hooks', 'register_external_handler', + 'invalidate_cache', 'clear_all_caches', 'get_cache_stats' +] \ No newline at end of file diff --git a/src/praisonai-agents/praisonaiagents/tools/examples.py b/src/praisonai-agents/praisonaiagents/tools/examples.py new file mode 100644 index 000000000..8fda1987a --- /dev/null +++ b/src/praisonai-agents/praisonaiagents/tools/examples.py @@ -0,0 +1,346 @@ +""" +Examples of advanced tools functionality for PraisonAI Agents. + +This module demonstrates how to use the new advanced tool features: +- Pre/Post execution hooks +- Tool-level caching with TTL +- External execution markers +- Structured user input fields +""" + +from praisonaiagents.tools import tool, cache, external, user_input +from praisonaiagents.tools import Field, InputGroup, Choice, Range, Pattern +from praisonaiagents.tools import ToolContext, set_global_hooks, register_external_handler + +import time +import requests +from typing import List, Dict, Any + + +# Example 1: Pre/Post Execution Hooks +def log_start(context: ToolContext): + """Log when a tool starts executing.""" + print(f"๐Ÿš€ Starting {context.tool_name} with args: {context.args}") + +def log_end(context: ToolContext): + """Log when a tool finishes executing.""" + if context.error: + print(f"โŒ {context.tool_name} failed: {context.error}") + else: + print(f"โœ… {context.tool_name} completed in {context.execution_time:.2f}s") + +@tool( + before=log_start, + after=log_end +) +def calculate_metrics(data: List[float]) -> Dict[str, float]: + """Calculate basic statistics for a list of numbers.""" + if not data: + return {"error": "No data provided"} + + return { + "mean": sum(data) / len(data), + "min": min(data), + "max": max(data), + "count": len(data) + } + + +# Example 2: Multiple hooks with priority +def validate_input(context: ToolContext): + """Validate input data.""" + if context.args and len(context.args[0]) == 0: + raise ValueError("Input data cannot be empty") + +from praisonaiagents.tools import Priority + +@tool( + before=[ + (validate_input, Priority.HIGHEST), # Runs first + (log_start, Priority.MEDIUM) # Runs second + ], + after=log_end +) +def process_data(input_data: str) -> str: + """Process input data with validation.""" + return input_data.upper() + + +# Example 3: Error handling hook +def error_handler(context: ToolContext): + """Handle tool errors gracefully.""" + if context.error: + print(f"๐Ÿ”ง Handling error in {context.tool_name}: {context.error}") + # Can modify the error or suppress it + if "network" in str(context.error).lower(): + context.error = None + context.result = {"status": "offline", "message": "Network unavailable"} + +@tool(after=error_handler) +def risky_operation(fail: bool = False): + """An operation that might fail.""" + if fail: + raise Exception("Network connection failed") + return {"status": "success"} + + +# Example 4: Simple caching +@tool +@cache(ttl=300) # 5 minutes +def fetch_weather(city: str) -> Dict[str, Any]: + """Fetch weather data with caching.""" + # Simulate API call + time.sleep(1) # Simulate network delay + return { + "city": city, + "temperature": 22, + "condition": "sunny", + "timestamp": time.time() + } + + +# Example 5: Advanced caching with custom key and condition +@tool +@cache( + ttl=3600, # 1 hour + key=lambda city, date: f"{city}:{date}", # Custom cache key + condition=lambda result: result.get('status') == 'success', # Only cache successful results + tags=['weather', 'historical'] +) +def get_historical_weather(city: str, date: str) -> Dict[str, Any]: + """Get historical weather data with advanced caching.""" + # Simulate API call + return { + "city": city, + "date": date, + "temperature": 18, + "status": "success" + } + + +# Example 6: External execution markers +@tool +@external +def run_on_gpu(model_path: str, data: List[float]) -> Dict[str, Any]: + """Run model inference on GPU (marked for external execution).""" + # This would pause execution and return control to handler + return {"predictions": [x * 2 for x in data]} + + +# Example 7: External with metadata +@tool +@external( + executor="gpu_cluster", + requirements=["cuda>=11.0", "torch>=2.0"], + estimated_time=300 # 5 minutes +) +def train_model(dataset: str, hyperparams: Dict[str, Any]) -> Dict[str, Any]: + """Train a model on GPU cluster.""" + return {"model_id": "model_123", "accuracy": 0.95} + + +# Example 8: Conditional external execution +@tool +@external(when=lambda args: len(args[0]) > 1000) # Only external if data is large +def process_large_data(data: List[Any], threshold: int = 1000) -> Dict[str, Any]: + """Process data, using external execution for large datasets.""" + return {"processed_count": len(data), "external": len(data) > threshold} + + +# Example 9: Structured user input fields +@tool +@user_input( + Field(name="confirm", type=bool, description="Proceed with deletion?"), + Field(name="reason", type=str, description="Reason for deletion", required=False) +) +def delete_records(confirm: bool, reason: str = None) -> Dict[str, Any]: + """Delete records with user confirmation.""" + if not confirm: + return {"status": "cancelled", "message": "Deletion cancelled by user"} + + return { + "status": "deleted", + "reason": reason or "No reason provided", + "deleted_count": 10 + } + + +# Example 10: Advanced field types +@tool +@user_input( + Field( + name="priority", + type=Choice(["low", "medium", "high"]), + description="Task priority", + default="medium" + ), + Field( + name="budget", + type=Range(min=0, max=10000), + description="Budget in USD" + ), + Field( + name="email", + type=Pattern(r"^[\w\.-]+@[\w\.-]+\.\w+$"), + description="Contact email" + ) +) +def create_project(priority: str, budget: float, email: str) -> Dict[str, Any]: + """Create a new project with validated inputs.""" + return { + "project_id": "proj_123", + "priority": priority, + "budget": budget, + "contact": email, + "status": "created" + } + + +# Example 11: Input groups +@tool +@user_input( + InputGroup( + "Personal Information", + Field(name="first_name", type=str), + Field(name="last_name", type=str), + Field(name="age", type=int, required=False) + ), + InputGroup( + "Preferences", + Field(name="newsletter", type=bool, default=True), + Field(name="language", type=Choice(["en", "es", "fr"])) + ) +) +def register_user(**kwargs) -> Dict[str, Any]: + """Register a new user with grouped input fields.""" + return { + "user_id": "user_123", + "profile": kwargs, + "status": "registered" + } + + +# Example 12: Integration with existing approval system +try: + from praisonaiagents.tools.approval import require_approval + + @require_approval(risk_level="high") + @tool( + before=validate_input, + after=log_end + ) + def delete_production_data(table: str) -> Dict[str, Any]: + """Delete production data with approval and hooks.""" + return {"table": table, "status": "deleted", "rows": 1000} + +except ImportError: + # Approval system not available, create simple version + @tool( + before=validate_input, + after=log_end + ) + def delete_production_data(table: str) -> Dict[str, Any]: + """Delete production data with hooks (approval not available).""" + return {"table": table, "status": "deleted", "rows": 1000} + + +# Example 13: Global hooks setup +def global_logger(context: ToolContext): + """Global logging for all tools.""" + print(f"๐Ÿ”ง Tool executed: {context.tool_name}") + +def global_metrics(context: ToolContext): + """Global metrics collection.""" + # In a real implementation, this would send to a metrics system + print(f"๐Ÿ“Š Metrics: {context.tool_name} took {context.execution_time:.2f}s") + +# Set up global hooks +set_global_hooks( + before=global_logger, + after=global_metrics +) + + +# Example 14: External handler registration +async def gpu_cluster_handler(func, context: ToolContext, external_config): + """Handle GPU cluster execution.""" + print(f"๐Ÿ“ก Submitting {context.tool_name} to GPU cluster...") + + # Simulate external execution + import asyncio + await asyncio.sleep(1) # Simulate processing time + + # Execute the function normally for this example + result = func(*context.args, **context.kwargs) + print(f"๐Ÿ GPU cluster execution completed") + return result + +# Register the external handler +register_external_handler("gpu_cluster", gpu_cluster_handler) + + +# Example 15: Comprehensive tool with all features +@tool( + name="comprehensive_analysis", + description="A comprehensive data analysis tool demonstrating all advanced features", + before=[(validate_input, Priority.HIGHEST), (log_start, Priority.MEDIUM)], + after=[log_end, global_metrics], + cache={"ttl": 600, "tags": ["analysis"], "condition": lambda r: r.get("success", True)}, + external={"executor": "gpu_cluster", "when": lambda data: len(data) > 100}, + inputs=[ + InputGroup( + "Data Configuration", + Field(name="dataset", type=str, description="Dataset name"), + Field(name="algorithm", type=Choice(["linear", "svm", "neural"]), default="linear") + ), + InputGroup( + "Output Options", + Field(name="save_results", type=bool, default=True), + Field(name="output_format", type=Choice(["json", "csv", "xlsx"]), default="json") + ) + ] +) +def comprehensive_analysis(data: List[float], **config) -> Dict[str, Any]: + """Demonstrate all advanced tool features in one comprehensive example.""" + return { + "data_points": len(data), + "algorithm": config.get("algorithm", "linear"), + "results": {"accuracy": 0.92, "precision": 0.89}, + "config": config, + "success": True + } + + +if __name__ == "__main__": + # Example usage demonstrations + print("๐Ÿ”ง Advanced Tools Examples") + print("=" * 50) + + # Test basic tool with hooks + print("\n1. Basic tool with hooks:") + result = calculate_metrics([1, 2, 3, 4, 5]) + print(f"Result: {result}") + + # Test caching + print("\n2. Caching example:") + print("First call (will be slow):") + weather1 = fetch_weather("New York") + print(f"Result: {weather1}") + + print("Second call (cached, should be fast):") + weather2 = fetch_weather("New York") + print(f"Result: {weather2}") + + # Test error handling + print("\n3. Error handling:") + try: + risky_operation(fail=True) + except Exception as e: + print(f"Error caught: {e}") + + # Test external execution + print("\n4. External execution:") + gpu_result = run_on_gpu("/models/test", [1, 2, 3]) + print(f"GPU Result: {gpu_result}") + + print("\nโœ… All examples completed!") \ No newline at end of file diff --git a/src/praisonai-agents/test_advanced_tools.py b/src/praisonai-agents/test_advanced_tools.py new file mode 100644 index 000000000..0a444c3ad --- /dev/null +++ b/src/praisonai-agents/test_advanced_tools.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 +""" +Test script for advanced tools functionality. +This script validates that the new advanced tools features work correctly. +""" + +import sys +import os + +# Add the package to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__))) + +def test_imports(): + """Test that all advanced tools components can be imported.""" + try: + from praisonaiagents.tools import ( + tool, cache, external, user_input, + Field, InputGroup, Choice, Range, Pattern, + ToolContext, Hook, CacheConfig, ExternalConfig, Priority, + set_global_hooks, clear_global_hooks, register_external_handler, + invalidate_cache, clear_all_caches, get_cache_stats + ) + print("โœ… All advanced tools imports successful") + return True + except ImportError as e: + print(f"โŒ Import error: {e}") + return False + +def test_basic_tool_decorator(): + """Test basic @tool decorator functionality.""" + try: + from praisonaiagents.tools import tool + + @tool + def simple_tool(x: int) -> int: + """A simple test tool.""" + return x * 2 + + # Test that the tool works + result = simple_tool(5) + assert result == 10, f"Expected 10, got {result}" + + # Test that metadata is attached + assert hasattr(simple_tool, '_tool_metadata'), "Tool metadata not attached" + metadata = simple_tool._tool_metadata + assert metadata['name'] == 'simple_tool', f"Wrong tool name: {metadata['name']}" + + print("โœ… Basic @tool decorator works") + return True + except Exception as e: + print(f"โŒ Basic tool decorator error: {e}") + return False + +def test_hooks(): + """Test pre/post execution hooks.""" + try: + from praisonaiagents.tools import tool, ToolContext + + # Track hook execution + hook_calls = [] + + def before_hook(context: ToolContext): + hook_calls.append(f"before_{context.tool_name}") + + def after_hook(context: ToolContext): + hook_calls.append(f"after_{context.tool_name}") + + @tool(before=before_hook, after=after_hook) + def hooked_tool(x: int) -> int: + """Tool with hooks.""" + return x + 1 + + # Execute the tool + result = hooked_tool(5) + assert result == 6, f"Expected 6, got {result}" + + # Check that hooks were called + assert "before_hooked_tool" in hook_calls, "Before hook not called" + assert "after_hooked_tool" in hook_calls, "After hook not called" + + print("โœ… Hooks functionality works") + return True + except Exception as e: + print(f"โŒ Hooks error: {e}") + return False + +def test_caching(): + """Test caching functionality.""" + try: + from praisonaiagents.tools import tool, cache + import time + + call_count = 0 + + @tool + @cache(ttl=60) # 1 minute cache + def cached_tool(x: int) -> dict: + """Tool with caching.""" + nonlocal call_count + call_count += 1 + return {"value": x * 2, "call_count": call_count} + + # First call + result1 = cached_tool(5) + assert result1["value"] == 10, f"Expected 10, got {result1['value']}" + assert call_count == 1, f"Expected 1 call, got {call_count}" + + # Second call should be cached + result2 = cached_tool(5) + assert result2["value"] == 10, f"Expected 10, got {result2['value']}" + assert call_count == 1, f"Expected 1 call (cached), got {call_count}" + + print("โœ… Caching functionality works") + return True + except Exception as e: + print(f"โŒ Caching error: {e}") + return False + +def test_external_markers(): + """Test external execution markers.""" + try: + from praisonaiagents.tools import tool, external + + @tool + @external + def external_tool(x: int) -> int: + """Tool marked for external execution.""" + return x * 3 + + # For now, external tools execute normally + result = external_tool(4) + assert result == 12, f"Expected 12, got {result}" + + # Check metadata + metadata = external_tool._tool_metadata + assert metadata['external_config'] is not None, "External config not set" + + print("โœ… External execution markers work") + return True + except Exception as e: + print(f"โŒ External markers error: {e}") + return False + +def test_user_input_fields(): + """Test structured user input fields.""" + try: + from praisonaiagents.tools import tool, user_input, Field, Choice + + @tool + @user_input( + Field(name="name", type=str, description="User name"), + Field(name="priority", type=Choice(["low", "high"]), default="low") + ) + def input_tool(name: str, priority: str = "low") -> dict: + """Tool with structured input.""" + return {"name": name, "priority": priority} + + result = input_tool("test", "high") + assert result["name"] == "test", f"Expected 'test', got {result['name']}" + assert result["priority"] == "high", f"Expected 'high', got {result['priority']}" + + # Check metadata + metadata = input_tool._tool_metadata + assert metadata['inputs'] is not None, "Inputs not set" + assert len(metadata['inputs']) == 2, f"Expected 2 inputs, got {len(metadata['inputs'])}" + + print("โœ… User input fields work") + return True + except Exception as e: + print(f"โŒ User input fields error: {e}") + return False + +def test_backward_compatibility(): + """Test that existing tools still work.""" + try: + # Test that we can still import existing tools + from praisonaiagents.tools import TOOL_MAPPINGS + + # Check that tool mappings still exist + assert len(TOOL_MAPPINGS) > 0, "Tool mappings empty" + assert 'internet_search' in TOOL_MAPPINGS, "internet_search tool missing" + + print("โœ… Backward compatibility maintained") + return True + except Exception as e: + print(f"โŒ Backward compatibility error: {e}") + return False + +def test_comprehensive_example(): + """Test a comprehensive tool with multiple features.""" + try: + from praisonaiagents.tools import tool, cache, Field, Choice, ToolContext, Priority + + hook_calls = [] + + def validator(context: ToolContext): + hook_calls.append("validator") + + def logger(context: ToolContext): + hook_calls.append("logger") + + @tool( + name="comprehensive_test", + description="A comprehensive test tool", + before=[(validator, Priority.HIGHEST), (logger, Priority.MEDIUM)], + cache={"ttl": 300, "tags": ["test"]}, + inputs=[ + Field(name="data", type=str), + Field(name="format", type=Choice(["json", "xml"]), default="json") + ] + ) + def comprehensive_tool(data: str, format: str = "json") -> dict: + """Comprehensive test tool.""" + return {"data": data, "format": format, "processed": True} + + result = comprehensive_tool("test_data", "xml") + assert result["data"] == "test_data", "Data not preserved" + assert result["format"] == "xml", "Format not preserved" + assert result["processed"] is True, "Processed flag not set" + + # Check hooks were called + assert "validator" in hook_calls, "Validator hook not called" + assert "logger" in hook_calls, "Logger hook not called" + + print("โœ… Comprehensive example works") + return True + except Exception as e: + print(f"โŒ Comprehensive example error: {e}") + return False + +def run_all_tests(): + """Run all tests and report results.""" + print("๐Ÿงช Testing Advanced Tools Implementation") + print("=" * 50) + + tests = [ + test_imports, + test_backward_compatibility, + test_basic_tool_decorator, + test_hooks, + test_caching, + test_external_markers, + test_user_input_fields, + test_comprehensive_example + ] + + passed = 0 + total = len(tests) + + for test in tests: + try: + if test(): + passed += 1 + except Exception as e: + print(f"โŒ Test {test.__name__} failed with exception: {e}") + + print("\n" + "=" * 50) + print(f"๐Ÿ“Š Test Results: {passed}/{total} passed") + + if passed == total: + print("๐ŸŽ‰ All tests passed! Advanced tools implementation is working correctly.") + return True + else: + print(f"โš ๏ธ {total - passed} tests failed. Implementation needs review.") + return False + +if __name__ == "__main__": + success = run_all_tests() + sys.exit(0 if success else 1) \ No newline at end of file