diff --git a/deploy/.env.example b/deploy/.env.example new file mode 100644 index 000000000..6b0f4de4f --- /dev/null +++ b/deploy/.env.example @@ -0,0 +1,3 @@ +AWS_ACCESS_KEY_ID= +AWS_SECRET_ACCESS_KEY= +AWS_REGION= diff --git a/deploy/README.md b/deploy/README.md new file mode 100644 index 000000000..2e53469ce --- /dev/null +++ b/deploy/README.md @@ -0,0 +1,10 @@ +``` +# First time setup +cd deploy +uv venv +source .venv/bin/activate +uv pip install -e . + +# Subsequent usage +python deploy/models/omniparser/deploy.py start +``` diff --git a/deploy/deploy/models/omniparser/.dockerignore b/deploy/deploy/models/omniparser/.dockerignore new file mode 100644 index 000000000..213bee701 --- /dev/null +++ b/deploy/deploy/models/omniparser/.dockerignore @@ -0,0 +1,20 @@ +__pycache__ +*.pyc +*.pyo +*.pyd +.Python +env +pip-log.txt +pip-delete-this-directory.txt +.tox +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.log +.pytest_cache +.env +.venv +.DS_Store diff --git a/deploy/deploy/models/omniparser/Dockerfile b/deploy/deploy/models/omniparser/Dockerfile new file mode 100644 index 000000000..f14ea7ac8 --- /dev/null +++ b/deploy/deploy/models/omniparser/Dockerfile @@ -0,0 +1,59 @@ +FROM nvidia/cuda:12.3.1-devel-ubuntu22.04 + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + git-lfs \ + wget \ + libgl1 \ + libglib2.0-0 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* \ + && git lfs install + +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \ + bash miniconda.sh -b -p /opt/conda && \ + rm miniconda.sh +ENV PATH="/opt/conda/bin:$PATH" + +RUN conda create -n omni python=3.12 && \ + echo "source activate omni" > ~/.bashrc +ENV CONDA_DEFAULT_ENV=omni +ENV PATH="/opt/conda/envs/omni/bin:$PATH" + +WORKDIR /app + +RUN git clone https://github.com/microsoft/OmniParser.git && \ + cd OmniParser && \ + git lfs install && \ + git lfs pull + +WORKDIR /app/OmniParser + +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + pip uninstall -y opencv-python opencv-python-headless && \ + pip install --no-cache-dir opencv-python-headless==4.8.1.78 && \ + pip install -r requirements.txt && \ + pip install huggingface_hub fastapi uvicorn + +# Download V2 weights +RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + mkdir -p /app/OmniParser/weights && \ + cd /app/OmniParser && \ + rm -rf weights/icon_detect weights/icon_caption weights/icon_caption_florence && \ + for folder in icon_caption icon_detect; do \ + huggingface-cli download microsoft/OmniParser-v2.0 --local-dir weights --repo-type model --include "$folder/*"; \ + done && \ + mv weights/icon_caption weights/icon_caption_florence + +# Pre-download OCR models during build +RUN . 
/opt/conda/etc/profile.d/conda.sh && conda activate omni && \ + cd /app/OmniParser && \ + python3 -c "import easyocr; reader = easyocr.Reader(['en']); print('Downloaded EasyOCR model')" && \ + python3 -c "from paddleocr import PaddleOCR; ocr = PaddleOCR(lang='en', use_angle_cls=False, use_gpu=False, show_log=False); print('Downloaded PaddleOCR model')" + +CMD ["python3", "/app/OmniParser/omnitool/omniparserserver/omniparserserver.py", \ + "--som_model_path", "/app/OmniParser/weights/icon_detect/model.pt", \ + "--caption_model_path", "/app/OmniParser/weights/icon_caption_florence", \ + "--device", "cuda", \ + "--BOX_TRESHOLD", "0.05", \ + "--host", "0.0.0.0", \ + "--port", "8000"] diff --git a/deploy/deploy/models/omniparser/client.py b/deploy/deploy/models/omniparser/client.py new file mode 100644 index 000000000..c0cac4f49 --- /dev/null +++ b/deploy/deploy/models/omniparser/client.py @@ -0,0 +1,128 @@ +"""Client module for interacting with the OmniParser server.""" + +import base64 +import fire +import requests + +from loguru import logger +from PIL import Image, ImageDraw + + +def image_to_base64(image_path: str) -> str: + """Convert an image file to base64 string. + + Args: + image_path: Path to the image file + + Returns: + str: Base64 encoded string of the image + """ + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def plot_results( + original_image_path: str, + som_image_base64: str, + parsed_content_list: list[dict[str, list[float]]], +) -> None: + """Plot parsing results on the original image. + + Args: + original_image_path: Path to the original image + som_image_base64: Base64 encoded SOM image + parsed_content_list: List of parsed content with bounding boxes + """ + # Open original image + image = Image.open(original_image_path) + width, height = image.size + + # Create drawable image + draw = ImageDraw.Draw(image) + + # Draw bounding boxes and labels + for item in parsed_content_list: + # Get normalized coordinates and convert to pixel coordinates + x1, y1, x2, y2 = item["bbox"] + x1 = int(x1 * width) + y1 = int(y1 * height) + x2 = int(x2 * width) + y2 = int(y2 * height) + + label = item["content"] + + # Draw rectangle + draw.rectangle([(x1, y1), (x2, y2)], outline="red", width=2) + + # Draw label background + text_bbox = draw.textbbox((x1, y1), label) + draw.rectangle( + [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2], + fill="white", + ) + + # Draw label text + draw.text((x1, y1), label, fill="red") + + # Show image + image.show() + + +def parse_image( + image_path: str, + server_url: str, +) -> None: + """Parse an image using the OmniParser server. 
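+
+    One way to invoke this through the fire CLI (image path and server URL are
+    illustrative; use the address printed by deploy.py):
+
+        python client.py path/to/screenshot.png http://<instance-ip>:8000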
+ + Args: + image_path: Path to the image file + server_url: URL of the OmniParser server + """ + # Remove trailing slash from server_url if present + server_url = server_url.rstrip("/") + + # Convert image to base64 + base64_image = image_to_base64(image_path) + + # Prepare request + url = f"{server_url}/parse/" + payload = {"base64_image": base64_image} + + try: + # First, check if the server is available + probe_url = f"{server_url}/probe/" + probe_response = requests.get(probe_url) + probe_response.raise_for_status() + logger.info("Server is available") + + # Make request to API + response = requests.post(url, json=payload) + response.raise_for_status() + + # Parse response + result = response.json() + som_image_base64 = result["som_image_base64"] + parsed_content_list = result["parsed_content_list"] + + # Plot results + plot_results(image_path, som_image_base64, parsed_content_list) + + # Print latency + logger.info(f"API Latency: {result['latency']:.2f} seconds") + + except requests.exceptions.ConnectionError: + logger.error(f"Error: Could not connect to server at {server_url}") + logger.error("Please check if the server is running and the URL is correct") + except requests.exceptions.RequestException as e: + logger.error(f"Error making request to API: {e}") + except Exception as e: + logger.error(f"Error: {e}") + + +def main() -> None: + """Main entry point for the client application.""" + fire.Fire(parse_image) + + +if __name__ == "__main__": + main() diff --git a/deploy/deploy/models/omniparser/deploy.py b/deploy/deploy/models/omniparser/deploy.py new file mode 100644 index 000000000..b951378bb --- /dev/null +++ b/deploy/deploy/models/omniparser/deploy.py @@ -0,0 +1,785 @@ +"""Deployment module for OmniParser on AWS EC2.""" + +import os +import subprocess +import time + +from botocore.exceptions import ClientError +from loguru import logger +from pydantic_settings import BaseSettings +import boto3 +import fire +import paramiko + + +CLEANUP_ON_FAILURE = False + + +class Config(BaseSettings): + """Configuration settings for deployment.""" + + AWS_ACCESS_KEY_ID: str + AWS_SECRET_ACCESS_KEY: str + AWS_REGION: str + + PROJECT_NAME: str = "omniparser" + REPO_URL: str = "https://github.com/microsoft/OmniParser.git" + AWS_EC2_AMI: str = "ami-06835d15c4de57810" + AWS_EC2_DISK_SIZE: int = 128 # GB + AWS_EC2_INSTANCE_TYPE: str = "g4dn.xlarge" # (T4 16GB $0.526/hr x86_64) + AWS_EC2_USER: str = "ubuntu" + PORT: int = 8000 # FastAPI port + COMMAND_TIMEOUT: int = 600 # 10 minutes + + class Config: + """Pydantic configuration class.""" + + env_file = ".env" + env_file_encoding = "utf-8" + + @property + def CONTAINER_NAME(self) -> str: + """Get the container name.""" + return f"{self.PROJECT_NAME}-container" + + @property + def AWS_EC2_KEY_NAME(self) -> str: + """Get the EC2 key pair name.""" + return f"{self.PROJECT_NAME}-key" + + @property + def AWS_EC2_KEY_PATH(self) -> str: + """Get the path to the EC2 key file.""" + return f"./{self.AWS_EC2_KEY_NAME}.pem" + + @property + def AWS_EC2_SECURITY_GROUP(self) -> str: + """Get the EC2 security group name.""" + return f"{self.PROJECT_NAME}-SecurityGroup" + + +config = Config() + + +def create_key_pair( + key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = config.AWS_EC2_KEY_PATH +) -> str | None: + """Create an EC2 key pair. 
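+
+    AWS returns the private key material only once, at creation time, so it is
+    written to key_path immediately and restricted to owner read (0o400).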
+ + Args: + key_name: Name of the key pair + key_path: Path where to save the key file + + Returns: + str | None: Key name if successful, None otherwise + """ + ec2_client = boto3.client("ec2", region_name=config.AWS_REGION) + try: + key_pair = ec2_client.create_key_pair(KeyName=key_name) + private_key = key_pair["KeyMaterial"] + + with open(key_path, "w") as key_file: + key_file.write(private_key) + os.chmod(key_path, 0o400) # Set read-only permissions + + logger.info(f"Key pair {key_name} created and saved to {key_path}") + return key_name + except ClientError as e: + logger.error(f"Error creating key pair: {e}") + return None + + +def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None: + """Get existing security group or create a new one. + + Args: + ports: List of ports to open in the security group + + Returns: + str | None: Security group ID if successful, None otherwise + """ + ec2 = boto3.client("ec2", region_name=config.AWS_REGION) + + ip_permissions = [ + { + "IpProtocol": "tcp", + "FromPort": port, + "ToPort": port, + "IpRanges": [{"CidrIp": "0.0.0.0/0"}], + } + for port in ports + ] + + try: + response = ec2.describe_security_groups( + GroupNames=[config.AWS_EC2_SECURITY_GROUP] + ) + security_group_id = response["SecurityGroups"][0]["GroupId"] + logger.info( + f"Security group '{config.AWS_EC2_SECURITY_GROUP}' already exists: " + f"{security_group_id}" + ) + + for ip_permission in ip_permissions: + try: + ec2.authorize_security_group_ingress( + GroupId=security_group_id, IpPermissions=[ip_permission] + ) + logger.info(f"Added inbound rule for port {ip_permission['FromPort']}") + except ClientError as e: + if e.response["Error"]["Code"] == "InvalidPermission.Duplicate": + logger.info( + f"Rule for port {ip_permission['FromPort']} already exists" + ) + else: + logger.error( + f"Error adding rule for port {ip_permission['FromPort']}: {e}" + ) + + return security_group_id + except ClientError as e: + if e.response["Error"]["Code"] == "InvalidGroup.NotFound": + try: + response = ec2.create_security_group( + GroupName=config.AWS_EC2_SECURITY_GROUP, + Description="Security group for OmniParser deployment", + TagSpecifications=[ + { + "ResourceType": "security-group", + "Tags": [{"Key": "Name", "Value": config.PROJECT_NAME}], + } + ], + ) + security_group_id = response["GroupId"] + logger.info( + f"Created security group '{config.AWS_EC2_SECURITY_GROUP}' " + f"with ID: {security_group_id}" + ) + + ec2.authorize_security_group_ingress( + GroupId=security_group_id, IpPermissions=ip_permissions + ) + logger.info(f"Added inbound rules for ports {ports}") + + return security_group_id + except ClientError as e: + logger.error(f"Error creating security group: {e}") + return None + else: + logger.error(f"Error describing security groups: {e}") + return None + + +def deploy_ec2_instance( + ami: str = config.AWS_EC2_AMI, + instance_type: str = config.AWS_EC2_INSTANCE_TYPE, + project_name: str = config.PROJECT_NAME, + key_name: str = config.AWS_EC2_KEY_NAME, + disk_size: int = config.AWS_EC2_DISK_SIZE, +) -> tuple[str | None, str | None]: + """Deploy a new EC2 instance or return existing one. 
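+
+    Any instance already tagged with the project name is reused: a running
+    instance is returned as-is, and a stopped one is started, so at most one
+    instance exists per project.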
+ + Args: + ami: AMI ID to use for the instance + instance_type: EC2 instance type + project_name: Name tag for the instance + key_name: Name of the key pair to use + disk_size: Size of the root volume in GB + + Returns: + tuple[str | None, str | None]: Instance ID and public IP if successful + """ + ec2 = boto3.resource("ec2") + ec2_client = boto3.client("ec2") + + # Check for existing instances first + instances = ec2.instances.filter( + Filters=[ + {"Name": "tag:Name", "Values": [config.PROJECT_NAME]}, + { + "Name": "instance-state-name", + "Values": ["running", "pending", "stopped"], + }, + ] + ) + + existing_instance = None + for instance in instances: + existing_instance = instance + if instance.state["Name"] == "running": + logger.info( + f"Instance already running: ID - {instance.id}, " + f"IP - {instance.public_ip_address}" + ) + break + elif instance.state["Name"] == "stopped": + logger.info(f"Starting existing stopped instance: ID - {instance.id}") + ec2_client.start_instances(InstanceIds=[instance.id]) + instance.wait_until_running() + instance.reload() + logger.info( + f"Instance started: ID - {instance.id}, " + f"IP - {instance.public_ip_address}" + ) + break + + # If we found an existing instance, ensure we have its key + if existing_instance: + if not os.path.exists(config.AWS_EC2_KEY_PATH): + logger.warning( + f"Key file {config.AWS_EC2_KEY_PATH} not found for existing instance." + ) + logger.warning( + "You'll need to use the original key file to connect to this instance." + ) + logger.warning( + "Consider terminating the instance with 'deploy.py stop' and starting " + "fresh." + ) + return None, None + return existing_instance.id, existing_instance.public_ip_address + + # No existing instance found, create new one with new key pair + security_group_id = get_or_create_security_group_id() + if not security_group_id: + logger.error( + "Unable to retrieve security group ID. Instance deployment aborted." 
+ ) + return None, None + + # Create new key pair + try: + if os.path.exists(config.AWS_EC2_KEY_PATH): + logger.info(f"Removing existing key file {config.AWS_EC2_KEY_PATH}") + os.remove(config.AWS_EC2_KEY_PATH) + + try: + ec2_client.delete_key_pair(KeyName=key_name) + logger.info(f"Deleted existing key pair {key_name}") + except ClientError: + pass # Key pair doesn't exist, which is fine + + if not create_key_pair(key_name): + logger.error("Failed to create key pair") + return None, None + except Exception as e: + logger.error(f"Error managing key pair: {e}") + return None, None + + # Create new instance + ebs_config = { + "DeviceName": "/dev/sda1", + "Ebs": { + "VolumeSize": disk_size, + "VolumeType": "gp3", + "DeleteOnTermination": True, + }, + } + + new_instance = ec2.create_instances( + ImageId=ami, + MinCount=1, + MaxCount=1, + InstanceType=instance_type, + KeyName=key_name, + SecurityGroupIds=[security_group_id], + BlockDeviceMappings=[ebs_config], + TagSpecifications=[ + { + "ResourceType": "instance", + "Tags": [{"Key": "Name", "Value": project_name}], + }, + ], + )[0] + + new_instance.wait_until_running() + new_instance.reload() + logger.info( + f"New instance created: ID - {new_instance.id}, " + f"IP - {new_instance.public_ip_address}" + ) + return new_instance.id, new_instance.public_ip_address + + +def configure_ec2_instance( + instance_id: str | None = None, + instance_ip: str | None = None, + max_ssh_retries: int = 20, + ssh_retry_delay: int = 20, + max_cmd_retries: int = 20, + cmd_retry_delay: int = 30, +) -> tuple[str | None, str | None]: + """Configure an EC2 instance with necessary dependencies and Docker setup. + + This function either configures an existing EC2 instance specified by instance_id + and instance_ip, or deploys and configures a new instance. It installs Docker and + other required dependencies, and sets up the environment for running containers. + + Args: + instance_id: Optional ID of an existing EC2 instance to configure. + If None, a new instance will be deployed. + instance_ip: Optional IP address of an existing EC2 instance. + Required if instance_id is provided. + max_ssh_retries: Maximum number of SSH connection attempts. + Defaults to 20 attempts. + ssh_retry_delay: Delay in seconds between SSH connection attempts. + Defaults to 20 seconds. + max_cmd_retries: Maximum number of command execution retries. + Defaults to 20 attempts. + cmd_retry_delay: Delay in seconds between command execution retries. + Defaults to 30 seconds. 
+
+    Returns:
+        tuple[str | None, str | None]: A tuple containing:
+            - The instance ID (str) or None if configuration failed
+            - The instance's public IP address (str) or None if configuration failed
+
+    Raises:
+        RuntimeError: If command execution fails
+        paramiko.SSHException: If SSH connection fails
+        Exception: For other unexpected errors during configuration
+    """
+    if not instance_id:
+        ec2_instance_id, ec2_instance_ip = deploy_ec2_instance()
+    else:
+        ec2_instance_id = instance_id
+        ec2_instance_ip = instance_ip
+
+    if not ec2_instance_ip:
+        logger.error("No instance IP available; aborting configuration.")
+        return None, None
+
+    key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH)
+    ssh_client = paramiko.SSHClient()
+    ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+    ssh_retries = 0
+    while ssh_retries < max_ssh_retries:
+        try:
+            ssh_client.connect(
+                hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key
+            )
+            break
+        except Exception as e:
+            ssh_retries += 1
+            logger.error(f"SSH connection attempt {ssh_retries} failed: {e}")
+            if ssh_retries < max_ssh_retries:
+                logger.info(f"Retrying SSH connection in {ssh_retry_delay} seconds...")
+                time.sleep(ssh_retry_delay)
+            else:
+                logger.error("Maximum SSH connection attempts reached. Aborting.")
+                return None, None
+
+    commands = [
+        "sudo apt-get update",
+        "sudo apt-get install -y ca-certificates curl gnupg",
+        "sudo install -m 0755 -d /etc/apt/keyrings",
+        (
+            "curl -fsSL https://download.docker.com/linux/ubuntu/gpg | "
+            "sudo dd of=/etc/apt/keyrings/docker.gpg"
+        ),
+        "sudo chmod a+r /etc/apt/keyrings/docker.gpg",
+        (
+            'echo "deb [arch="$(dpkg --print-architecture)" '
+            "signed-by=/etc/apt/keyrings/docker.gpg] "
+            "https://download.docker.com/linux/ubuntu "
+            '"$(. /etc/os-release && echo "$VERSION_CODENAME")" stable" | '
+            "sudo tee /etc/apt/sources.list.d/docker.list > /dev/null"
+        ),
+        "sudo apt-get update",
+        (
+            "sudo apt-get install -y docker-ce docker-ce-cli containerd.io "
+            "docker-buildx-plugin docker-compose-plugin"
+        ),
+        "sudo systemctl start docker",
+        "sudo systemctl enable docker",
+        "sudo usermod -a -G docker ${USER}",
+        "sudo docker system prune -af --volumes",
+        f"sudo docker rm -f {config.PROJECT_NAME}-container || true",
+    ]
+
+    for command in commands:
+        logger.info(f"Executing command: {command}")
+        cmd_retries = 0
+        while cmd_retries < max_cmd_retries:
+            stdin, stdout, stderr = ssh_client.exec_command(command)
+            exit_status = stdout.channel.recv_exit_status()
+
+            if exit_status == 0:
+                logger.info("Command executed successfully")
+                break
+            else:
+                error_message = stderr.read()
+                if "Could not get lock" in str(error_message):
+                    cmd_retries += 1
+                    logger.warning(
+                        f"dpkg is locked, retrying in {cmd_retry_delay} seconds...
" + f"Attempt {cmd_retries}/{max_cmd_retries}" + ) + time.sleep(cmd_retry_delay) + else: + logger.error( + f"Error in command: {command}, Exit Status: {exit_status}, " + f"Error: {error_message}" + ) + break + + ssh_client.close() + return ec2_instance_id, ec2_instance_ip + + +def execute_command(ssh_client: paramiko.SSHClient, command: str) -> None: + """Execute a command and handle its output safely.""" + logger.info(f"Executing: {command}") + stdin, stdout, stderr = ssh_client.exec_command( + command, + timeout=config.COMMAND_TIMEOUT, + # get_pty=True + ) + + # Stream output in real-time + while not stdout.channel.exit_status_ready(): + if stdout.channel.recv_ready(): + try: + line = stdout.channel.recv(1024).decode("utf-8", errors="replace") + if line.strip(): # Only log non-empty lines + logger.info(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stdout: {e}") + + if stdout.channel.recv_stderr_ready(): + try: + line = stdout.channel.recv_stderr(1024).decode( + "utf-8", errors="replace" + ) + if line.strip(): # Only log non-empty lines + logger.error(line.strip()) + except Exception as e: + logger.warning(f"Error decoding stderr: {e}") + + exit_status = stdout.channel.recv_exit_status() + + # Capture any remaining output + try: + remaining_stdout = stdout.read().decode("utf-8", errors="replace") + if remaining_stdout.strip(): + logger.info(remaining_stdout.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stdout: {e}") + + try: + remaining_stderr = stderr.read().decode("utf-8", errors="replace") + if remaining_stderr.strip(): + logger.error(remaining_stderr.strip()) + except Exception as e: + logger.warning(f"Error decoding remaining stderr: {e}") + + if exit_status != 0: + error_msg = f"Command failed with exit status {exit_status}: {command}" + logger.error(error_msg) + raise RuntimeError(error_msg) + + logger.info(f"Successfully executed: {command}") + + +class Deploy: + """Class handling deployment operations for OmniParser.""" + + @staticmethod + def start() -> None: + """Start a new deployment of OmniParser on EC2.""" + try: + instance_id, instance_ip = configure_ec2_instance() + assert instance_ip, f"invalid {instance_ip=}" + + # Trigger driver installation via login shell + Deploy.ssh(non_interactive=True) + + # Get the directory containing deploy.py + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Define files to copy + files_to_copy = { + "Dockerfile": os.path.join(current_dir, "Dockerfile"), + ".dockerignore": os.path.join(current_dir, ".dockerignore"), + } + + # Copy files to instance + for filename, filepath in files_to_copy.items(): + if os.path.exists(filepath): + logger.info(f"Copying {filename} to instance...") + subprocess.run( + [ + "scp", + "-i", + config.AWS_EC2_KEY_PATH, + "-o", + "StrictHostKeyChecking=no", + filepath, + f"{config.AWS_EC2_USER}@{instance_ip}:~/{filename}", + ], + check=True, + ) + else: + logger.warning(f"File not found: {filepath}") + + # Connect to instance and execute commands + key = paramiko.RSAKey.from_private_key_file(config.AWS_EC2_KEY_PATH) + ssh_client = paramiko.SSHClient() + ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + try: + logger.info(f"Connecting to {instance_ip}...") + ssh_client.connect( + hostname=instance_ip, + username=config.AWS_EC2_USER, + pkey=key, + timeout=30, + ) + + setup_commands = [ + "rm -rf OmniParser", # Clean up any existing repo + f"git clone {config.REPO_URL}", + "cp Dockerfile .dockerignore OmniParser/", + ] + + # 
Execute setup commands
+                for command in setup_commands:
+                    logger.info(f"Executing setup command: {command}")
+                    execute_command(ssh_client, command)
+
+                # Build and run Docker container
+                docker_commands = [
+                    # Remove any existing container
+                    f"sudo docker rm -f {config.CONTAINER_NAME} || true",
+                    # Remove any existing image
+                    f"sudo docker rmi {config.PROJECT_NAME} || true",
+                    # Build new image
+                    (
+                        "cd OmniParser && sudo docker build --progress=plain "
+                        f"-t {config.PROJECT_NAME} ."
+                    ),
+                    # Run new container
+                    (
+                        f"sudo docker run -d -p {config.PORT}:8000 --gpus all --name "
+                        f"{config.CONTAINER_NAME} {config.PROJECT_NAME}"
+                    ),
+                ]
+
+                # Execute Docker commands
+                for command in docker_commands:
+                    logger.info(f"Executing Docker command: {command}")
+                    execute_command(ssh_client, command)
+
+                # Wait for container to start and check its logs
+                logger.info("Waiting for container to start...")
+                time.sleep(10)  # Give container time to start
+                execute_command(
+                    ssh_client, f"sudo docker logs {config.CONTAINER_NAME}"
+                )
+
+                # Wait for server to become responsive
+                logger.info("Waiting for server to become responsive...")
+                max_retries = 30
+                retry_delay = 10
+                server_ready = False
+
+                for attempt in range(max_retries):
+                    try:
+                        # Check if server is responding
+                        check_command = f"curl -s http://localhost:{config.PORT}/probe/"
+                        execute_command(ssh_client, check_command)
+                        server_ready = True
+                        break
+                    except Exception as e:
+                        logger.warning(
+                            f"Server not ready (attempt {attempt + 1}/{max_retries}): "
+                            f"{e}"
+                        )
+                        if attempt < max_retries - 1:
+                            logger.info(
+                                f"Waiting {retry_delay} seconds before next attempt..."
+                            )
+                            time.sleep(retry_delay)
+
+                if not server_ready:
+                    raise RuntimeError("Server failed to start properly")
+
+                # Final status check
+                execute_command(
+                    ssh_client, f"sudo docker ps | grep {config.CONTAINER_NAME}"
+                )
+
+                server_url = f"http://{instance_ip}:{config.PORT}"
+                logger.info(f"Deployment complete. Server running at: {server_url}")
+
+                # Verify server is accessible from outside
+                try:
+                    import requests
+
+                    response = requests.get(f"{server_url}/probe/", timeout=10)
+                    if response.status_code == 200:
+                        logger.info("Server is accessible from outside!")
+                    else:
+                        logger.warning(
+                            f"Server responded with status code: {response.status_code}"
+                        )
+                except Exception as e:
+                    logger.warning(f"Could not verify external access: {e}")
+
+            except Exception as e:
+                logger.error(f"Error during deployment: {e}")
+                # Get container logs for debugging
+                try:
+                    execute_command(
+                        ssh_client, f"sudo docker logs {config.CONTAINER_NAME}"
+                    )
+                except Exception as exc:
+                    logger.warning(f"{exc=}")
+                raise
+
+            finally:
+                ssh_client.close()
+
+        except Exception as e:
+            logger.error(f"Deployment failed: {e}")
+            if CLEANUP_ON_FAILURE:
+                # Attempt cleanup on failure
+                try:
+                    Deploy.stop()
+                except Exception as cleanup_error:
+                    logger.error(f"Cleanup after failure also failed: {cleanup_error}")
+            raise
+
+        logger.info("Deployment completed successfully!")
+
+    @staticmethod
+    def status() -> None:
+        """Check the status of deployed instances."""
+        ec2 = boto3.resource("ec2")
+        instances = ec2.instances.filter(
+            Filters=[{"Name": "tag:Name", "Values": [config.PROJECT_NAME]}]
+        )
+
+        for instance in instances:
+            public_ip = instance.public_ip_address
+            if public_ip:
+                server_url = f"http://{public_ip}:{config.PORT}"
+                logger.info(
+                    f"Instance ID: {instance.id}, State: {instance.state['Name']}, "
+                    f"URL: {server_url}"
+                )
+            else:
+                logger.info(
+                    f"Instance ID: {instance.id}, State: {instance.state['Name']}, "
+                    f"URL: Not available (no public IP)"
+                )
+
+    @staticmethod
+    def ssh(non_interactive: bool = False) -> None:
+        """SSH into the running instance.
+
+        Args:
+            non_interactive: If True, run in non-interactive mode
+        """
+        # Get instance IP
+        ec2 = boto3.resource("ec2")
+        instances = ec2.instances.filter(
+            Filters=[
+                {"Name": "tag:Name", "Values": [config.PROJECT_NAME]},
+                {"Name": "instance-state-name", "Values": ["running"]},
+            ]
+        )
+
+        instance = next(iter(instances), None)
+        if not instance:
+            logger.error("No running instance found")
+            return
+
+        ip = instance.public_ip_address
+        if not ip:
+            logger.error("Instance has no public IP")
+            return
+
+        # Check if key file exists
+        if not os.path.exists(config.AWS_EC2_KEY_PATH):
+            logger.error(f"Key file not found: {config.AWS_EC2_KEY_PATH}")
+            return
+
+        if non_interactive:
+            # Simulate full login by forcing all initialization scripts
+            ssh_command = [
+                "ssh",
+                "-o",
+                "StrictHostKeyChecking=no",  # Automatically accept new host keys
+                "-o",
+                "UserKnownHostsFile=/dev/null",  # Prevent writing to known_hosts
+                "-i",
+                config.AWS_EC2_KEY_PATH,
+                f"{config.AWS_EC2_USER}@{ip}",
+                "-t",  # Allocate a pseudo-terminal
+                "-tt",  # Force pseudo-terminal allocation
+                "bash --login -c 'exit'",  # Force full login shell and exit immediately
+            ]
+        else:
+            # Build and execute SSH command
+            ssh_command = (
+                f"ssh -i {config.AWS_EC2_KEY_PATH} -o StrictHostKeyChecking=no "
+                f"{config.AWS_EC2_USER}@{ip}"
+            )
+            logger.info(f"Connecting with: {ssh_command}")
+            os.system(ssh_command)
+            return
+
+        # Execute the SSH command for non-interactive mode
+        try:
+            subprocess.run(ssh_command, check=True)
+        except subprocess.CalledProcessError as e:
+            logger.error(f"SSH connection failed: {e}")
+
+    @staticmethod
+    def stop(
+        project_name: str = config.PROJECT_NAME,
+        security_group_name: str = config.AWS_EC2_SECURITY_GROUP,
+    ) -> None:
+        """Terminates the EC2 instance and deletes the associated
security group. + + Args: + project_name (str): The project name used to tag the instance. + Defaults to config.PROJECT_NAME. + security_group_name (str): The name of the security group to delete. + Defaults to config.AWS_EC2_SECURITY_GROUP. + """ + ec2_resource = boto3.resource("ec2") + ec2_client = boto3.client("ec2") + + # Terminate EC2 instances + instances = ec2_resource.instances.filter( + Filters=[ + {"Name": "tag:Name", "Values": [project_name]}, + { + "Name": "instance-state-name", + "Values": [ + "pending", + "running", + "shutting-down", + "stopped", + "stopping", + ], + }, + ] + ) + + for instance in instances: + logger.info(f"Terminating instance: ID - {instance.id}") + instance.terminate() + instance.wait_until_terminated() + logger.info(f"Instance {instance.id} terminated successfully.") + + # Delete security group + try: + ec2_client.delete_security_group(GroupName=security_group_name) + logger.info(f"Deleted security group: {security_group_name}") + except ClientError as e: + if e.response["Error"]["Code"] == "InvalidGroup.NotFound": + logger.info( + f"Security group {security_group_name} does not exist or already " + "deleted." + ) + else: + logger.error(f"Error deleting security group: {e}") + + +if __name__ == "__main__": + fire.Fire(Deploy) diff --git a/deploy/pyproject.toml b/deploy/pyproject.toml new file mode 100644 index 000000000..835b62424 --- /dev/null +++ b/deploy/pyproject.toml @@ -0,0 +1,22 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "deploy" +version = "0.1.0" +authors = [ + { name="Richard Abrich", email="richard@openadapt.ai" }, +] +description = "Deployment tools for OpenAdapt models" +requires-python = ">=3.10" +dependencies = [ + "boto3>=1.36.22", + "fire>=0.7.0", + "loguru>=0.7.0", + "paramiko>=3.5.1", + "pillow>=11.1.0", + "pydantic>=2.10.6", + "pydantic-settings>=2.7.1", + "requests>=2.32.3", +] diff --git a/openadapt/adapters/omniparser.py b/openadapt/adapters/omniparser.py new file mode 100644 index 000000000..0cd3e4f94 --- /dev/null +++ b/openadapt/adapters/omniparser.py @@ -0,0 +1,165 @@ +"""Adapter for interacting with the OmniParser server. + +This module provides a client for the OmniParser API deployed on AWS. +""" + +import base64 +import io +from typing import Dict, List, Any, Optional + +import requests +from PIL import Image + +from openadapt.custom_logger import logger + + +class OmniParserClient: + """Client for the OmniParser API.""" + + def __init__(self, server_url: str): + """Initialize the OmniParser client. + + Args: + server_url: URL of the OmniParser server + """ + self.server_url = server_url.rstrip("/") # Remove trailing slash if present + + def check_server_available(self) -> bool: + """Check if the OmniParser server is available. + + Returns: + bool: True if server is available, False otherwise + """ + try: + probe_url = f"{self.server_url}/probe/" + response = requests.get(probe_url, timeout=5) + response.raise_for_status() + logger.info("OmniParser server is available") + return True + except requests.exceptions.RequestException as e: + logger.error(f"OmniParser server not available: {e}") + return False + + def image_to_base64(self, image: Image.Image) -> str: + """Convert a PIL Image to base64 string. 
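+
+        The image is re-encoded as PNG before base64 encoding; parse_image
+        sends the result to the server in the base64_image field.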
+ + Args: + image: PIL Image to convert + + Returns: + str: Base64 encoded string of the image + """ + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format='PNG') + return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8") + + def parse_image(self, image: Image.Image) -> Dict[str, Any]: + """Parse an image using the OmniParser service. + + Args: + image: PIL Image to parse + + Returns: + Dict[str, Any]: Parsed results including UI elements + """ + if not self.check_server_available(): + return {"error": "Server not available", "parsed_content_list": []} + + # Convert image to base64 + base64_image = self.image_to_base64(image) + + # Prepare request + url = f"{self.server_url}/parse/" + payload = {"base64_image": base64_image} + + try: + # Make request to API + response = requests.post(url, json=payload, timeout=30) + response.raise_for_status() + + # Parse response + result = response.json() + logger.info(f"OmniParser latency: {result.get('latency', 0):.2f} seconds") + return result + except requests.exceptions.RequestException as e: + logger.error(f"Error making request to OmniParser API: {e}") + return {"error": str(e), "parsed_content_list": []} + except Exception as e: + logger.error(f"Error parsing image with OmniParser: {e}") + return {"error": str(e), "parsed_content_list": []} + + +class OmniParserProvider: + """Provider for OmniParser services.""" + + def __init__(self, server_url: Optional[str] = None): + """Initialize OmniParser provider. + + Args: + server_url: URL of the OmniParser server (optional) + """ + self.server_url = server_url or "http://localhost:8000" + self.client = OmniParserClient(self.server_url) + + def is_available(self) -> bool: + """Check if the OmniParser service is available. + + Returns: + bool: True if service is available, False otherwise + """ + return self.client.check_server_available() + + def status(self) -> Dict[str, Any]: + """Check the status of the OmniParser service. + + Returns: + Dict[str, Any]: Status information + """ + is_available = self.is_available() + return { + "services": [ + { + "name": "omniparser", + "status": "running" if is_available else "stopped", + "url": self.server_url + } + ], + "is_available": is_available + } + + def deploy(self) -> bool: + """Deploy the OmniParser service if not already running. + + Returns: + bool: True if successfully deployed or already running, False otherwise + """ + # Check if already running + if self.status()["is_available"]: + logger.info("OmniParser service is already running") + return True + + # Try to deploy using the deployment script + try: + from deploy.deploy.models.omniparser.deploy import Deploy + logger.info("Deploying OmniParser service...") + Deploy.start() + return self.status()["is_available"] + except Exception as e: + logger.error(f"Failed to deploy OmniParser service: {e}") + return False + + def parse_screenshot(self, image_data: bytes) -> Dict[str, Any]: + """Parse a screenshot using OmniParser. 
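+
+        A minimal usage sketch (URL and file name are illustrative):
+
+            provider = OmniParserProvider("http://localhost:8000")
+            with open("screenshot.png", "rb") as f:
+                result = provider.parse_screenshot(f.read())
+            for item in result.get("parsed_content_list", []):
+                print(item)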
+ + Args: + image_data: Raw image data in bytes + + Returns: + Dict[str, Any]: Parsed content with UI elements + """ + try: + image = Image.open(io.BytesIO(image_data)) + return self.client.parse_image(image) + except Exception as e: + logger.error(f"Error processing image data: {e}") + return {"error": str(e), "parsed_content_list": []} \ No newline at end of file diff --git a/openadapt/mcp/__init__.py b/openadapt/mcp/__init__.py new file mode 100644 index 000000000..247248fe5 --- /dev/null +++ b/openadapt/mcp/__init__.py @@ -0,0 +1 @@ +"""Model Control Protocol (MCP) implementation for OpenAdapt.""" \ No newline at end of file diff --git a/openadapt/mcp/server.py b/openadapt/mcp/server.py new file mode 100644 index 000000000..99f580088 --- /dev/null +++ b/openadapt/mcp/server.py @@ -0,0 +1,327 @@ +"""MCP server implementation for OmniMCP. + +This module implements a Model Control Protocol server that exposes +UI automation capabilities to Claude through a standardized interface. + +Usage: + # Import and create server instance + from openadapt.mcp.server import create_omnimcp_server + from openadapt.omnimcp import OmniMCP + + # Create OmniMCP instance + omnimcp = OmniMCP() + + # Create and run server + server = create_omnimcp_server(omnimcp) + server.run() +""" + +import datetime +import io +import json +import os +from typing import Any, Dict, List, Optional + +from mcp.server.fastmcp import FastMCP + +from openadapt.custom_logger import logger + + +def create_debug_directory() -> str: + """Create a timestamped directory for debug outputs. + + Returns: + str: Path to debug directory + """ + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + debug_dir = os.path.join( + os.path.expanduser("~"), + "omnimcp_debug", + f"session_{timestamp}" + ) + os.makedirs(debug_dir, exist_ok=True) + logger.info(f"Created debug directory: {debug_dir}") + return debug_dir + + +def create_omnimcp_server(omnimcp_instance) -> FastMCP: + """Create an MCP server for the given OmniMCP instance. + + Args: + omnimcp_instance: An instance of the OmniMCP class + + Returns: + FastMCP: The MCP server instance + """ + # Initialize FastMCP server + server = FastMCP("omnimcp") + + # Create debug directory + debug_dir = create_debug_directory() + + @server.tool() + async def get_screen_state() -> Dict[str, Any]: + """Get the current state of the screen with UI elements. + + Returns a structured representation of all UI elements detected on screen, + including their positions, descriptions, and other metadata. + """ + # Update visual state + omnimcp_instance.update_visual_state() + + # Save screenshot with timestamp for debugging + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + debug_path = os.path.join(debug_dir, f"screen_state_{timestamp}.png") + omnimcp_instance.save_visual_debug(debug_path) + + # Get structured description and parse into JSON + mcp_description = omnimcp_instance.visual_state.to_mcp_description( + omnimcp_instance.use_normalized_coordinates + ) + + return json.loads(mcp_description) + + @server.tool() + async def find_ui_element(descriptor: str, partial_match: bool = True) -> Dict[str, Any]: + """Find a UI element by its descriptor. 
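+
+        Matching is delegated to VisualState.find_element_by_content: it is
+        case-insensitive, and with partial_match=True a substring match on the
+        element's content is enough; the first matching element wins.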
+ + Args: + descriptor: Descriptive text to search for in element content + partial_match: Whether to allow partial matching + + Returns: + Information about the matched element or error if not found + """ + # Update visual state + omnimcp_instance.update_visual_state() + + # Find element + element = omnimcp_instance.visual_state.find_element_by_content( + descriptor, + partial_match + ) + + if not element: + return { + "found": False, + "error": f"No UI element matching '{descriptor}' was found", + "possible_elements": [ + el.content for el in omnimcp_instance.visual_state.elements[:10] + ] + } + + # Return element details + return { + "found": True, + "content": element.content, + "type": element.type, + "confidence": element.confidence, + "bounds": { + "x1": element.x1, + "y1": element.y1, + "x2": element.x2, + "y2": element.y2, + "width": element.width, + "height": element.height + }, + "center": { + "x": element.center_x, + "y": element.center_y + }, + "normalized": { + "bounds": element.bbox, + "center": { + "x": element.normalized_center_x, + "y": element.normalized_center_y + } + } + } + + @server.tool() + async def click_element( + descriptor: str, + button: str = "left", + partial_match: bool = True + ) -> Dict[str, Any]: + """Click on a UI element by its descriptor. + + Args: + descriptor: Descriptive text to identify the element + button: Mouse button to use (left, right, middle) + partial_match: Whether to allow partial matching + + Returns: + Result of the click operation + """ + # Find and click the element + success = omnimcp_instance.click_element(descriptor, button, partial_match) + + if success: + # Save debug screenshot after clicking + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + debug_path = os.path.join(debug_dir, f"click_{descriptor}_{timestamp}.png") + omnimcp_instance.save_visual_debug(debug_path) + + return { + "success": True, + "message": f"Successfully clicked element: {descriptor}" + } + else: + return { + "success": False, + "message": f"Failed to find element: {descriptor}", + "possible_elements": [ + el.content for el in omnimcp_instance.visual_state.elements[:10] + ] + } + + @server.tool() + async def click_coordinates( + x: float, + y: float, + button: str = "left" + ) -> Dict[str, Any]: + """Click at specific coordinates on the screen. + + Args: + x: X coordinate (absolute or normalized based on settings) + y: Y coordinate (absolute or normalized based on settings) + button: Mouse button to use (left, right, middle) + + Returns: + Result of the click operation + """ + try: + # Perform click + omnimcp_instance.click(x, y, button) + + # Save debug screenshot after clicking + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + debug_path = os.path.join(debug_dir, f"click_coords_{x}_{y}_{timestamp}.png") + omnimcp_instance.save_visual_debug(debug_path) + + # Determine coordinate format for message + format_type = "normalized" if omnimcp_instance.use_normalized_coordinates else "absolute" + + return { + "success": True, + "message": f"Successfully clicked at {format_type} coordinates ({x}, {y})" + } + except Exception as e: + return { + "success": False, + "message": f"Failed to click: {str(e)}" + } + + @server.tool() + async def type_text(text: str) -> Dict[str, Any]: + """Type text using the keyboard. 
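+
+        The text is forwarded to OmniMCP.type_text (pynput's keyboard
+        controller), so focus the target field first, e.g. via click_element.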
+ + Args: + text: Text to type + + Returns: + Result of the typing operation + """ + try: + omnimcp_instance.type_text(text) + return { + "success": True, + "message": f"Successfully typed: {text}" + } + except Exception as e: + return { + "success": False, + "message": f"Failed to type text: {str(e)}" + } + + @server.tool() + async def press_key(key: str) -> Dict[str, Any]: + """Press a single key on the keyboard. + + Args: + key: Key to press (e.g., enter, tab, escape) + + Returns: + Result of the key press operation + """ + try: + omnimcp_instance.press_key(key) + return { + "success": True, + "message": f"Successfully pressed key: {key}" + } + except Exception as e: + return { + "success": False, + "message": f"Failed to press key: {str(e)}" + } + + @server.tool() + async def list_ui_elements() -> List[Dict[str, Any]]: + """List all detected UI elements on the current screen. + + Returns: + List of all UI elements with basic information + """ + # Update visual state + omnimcp_instance.update_visual_state() + + # Extract basic info for each element + elements = [] + for element in omnimcp_instance.visual_state.elements: + elements.append({ + "content": element.content, + "type": element.type, + "confidence": element.confidence, + "center": { + "x": element.center_x, + "y": element.center_y + }, + "dimensions": { + "width": element.width, + "height": element.height + } + }) + + return elements + + @server.tool() + async def save_debug_screenshot(description: str = "debug") -> Dict[str, Any]: + """Save a debug screenshot with an optional description. + + The description is used to name the screenshot file, making it easier to identify + the purpose of the screenshot (e.g., "before_clicking_submit_button"). + + Args: + description: Description to include in the filename + + Returns: + Result of the save operation + """ + try: + # Create sanitized description for filename + safe_description = "".join(c if c.isalnum() else "_" for c in description) + + # Generate timestamped filename + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + output_path = os.path.join( + debug_dir, + f"{safe_description}_{timestamp}.png" + ) + + # Save the debug visualization + omnimcp_instance.save_visual_debug(output_path) + + return { + "success": True, + "message": f"Debug screenshot saved to {output_path}", + "path": output_path + } + except Exception as e: + return { + "success": False, + "message": f"Failed to save debug screenshot: {str(e)}" + } + + return server \ No newline at end of file diff --git a/openadapt/omnimcp.py b/openadapt/omnimcp.py new file mode 100644 index 000000000..f3ef9890a --- /dev/null +++ b/openadapt/omnimcp.py @@ -0,0 +1,932 @@ +"""OmniMCP: Model Context Protocol implementation with OmniParser. + +This module enables Claude to understand screen content via OmniParser and +take actions through keyboard and mouse primitives based on natural language requests. 
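+
+The loop: capture a screenshot, parse it into structured UI elements with
+OmniParser, expose that state to Claude through MCP, and execute the chosen
+keyboard or mouse action with pynput.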
+ +Usage: + # Basic usage with MCP server + from openadapt.omnimcp import OmniMCP + from openadapt.mcp.server import create_omnimcp_server + + # Create OmniMCP instance + omnimcp = OmniMCP() + + # Create and run MCP server + server = create_omnimcp_server(omnimcp) + server.run() + + # Alternatively, run interactively (no MCP) + omnimcp = OmniMCP() + omnimcp.run_interactive() +""" + +import asyncio +import base64 +import datetime +import io +import json +import os +import time +from typing import Dict, List, Any, Optional, Tuple, Union, Callable + +from PIL import Image, ImageDraw +import fire +from pynput import keyboard, mouse + +from openadapt import utils +from openadapt.adapters.omniparser import OmniParserProvider +from openadapt.config import config +from openadapt.custom_logger import logger +from openadapt.drivers import anthropic + + +class ScreenElement: + """Represents a UI element on the screen with bounding box and description.""" + + def __init__(self, element_data: Dict[str, Any]): + """Initialize from OmniParser element data. + + Args: + element_data: Element data from OmniParser + """ + self.content = element_data.get("content", "") + self.bbox = element_data.get("bbox", [0, 0, 0, 0]) # Normalized coordinates + self.confidence = element_data.get("confidence", 0.0) + self.type = element_data.get("type", "unknown") + self.screen_width = 0 + self.screen_height = 0 + + def set_screen_dimensions(self, width: int, height: int): + """Set screen dimensions for coordinate calculations. + + Args: + width: Screen width in pixels + height: Screen height in pixels + """ + self.screen_width = width + self.screen_height = height + + @property + def x1(self) -> int: + """Get left coordinate in pixels.""" + return int(self.bbox[0] * self.screen_width) + + @property + def y1(self) -> int: + """Get top coordinate in pixels.""" + return int(self.bbox[1] * self.screen_height) + + @property + def x2(self) -> int: + """Get right coordinate in pixels.""" + return int(self.bbox[2] * self.screen_width) + + @property + def y2(self) -> int: + """Get bottom coordinate in pixels.""" + return int(self.bbox[3] * self.screen_height) + + @property + def center_x(self) -> int: + """Get center x coordinate in pixels.""" + return (self.x1 + self.x2) // 2 + + @property + def center_y(self) -> int: + """Get center y coordinate in pixels.""" + return (self.y1 + self.y2) // 2 + + @property + def width(self) -> int: + """Get width in pixels.""" + return self.x2 - self.x1 + + @property + def height(self) -> int: + """Get height in pixels.""" + return self.y2 - self.y1 + + @property + def normalized_center_x(self) -> float: + """Get normalized center x coordinate (0-1).""" + if self.screen_width == 0: + return 0.5 + return (self.x1 + self.x2) / (2 * self.screen_width) + + @property + def normalized_center_y(self) -> float: + """Get normalized center y coordinate (0-1).""" + if self.screen_height == 0: + return 0.5 + return (self.y1 + self.y2) / (2 * self.screen_height) + + def __str__(self) -> str: + """String representation with content and coordinates.""" + return f"{self.content} at ({self.x1},{self.y1},{self.x2},{self.y2})" + + +class VisualState: + """Represents the current visual state of the screen with UI elements.""" + + def __init__(self): + """Initialize empty visual state.""" + self.elements: List[ScreenElement] = [] + self.screenshot: Optional[Image.Image] = None + self.timestamp: float = time.time() + + def update_from_omniparser(self, omniparser_result: Dict[str, Any], screenshot: Image.Image): + 
"""Update visual state from OmniParser result. + + Args: + omniparser_result: Result from OmniParser + screenshot: Screenshot image + """ + self.screenshot = screenshot + self.timestamp = time.time() + + # Extract parsed content + parsed_content = omniparser_result.get("parsed_content_list", []) + + # Create screen elements + self.elements = [] + for content in parsed_content: + element = ScreenElement(content) + element.set_screen_dimensions(screenshot.width, screenshot.height) + self.elements.append(element) + + def find_element_by_content(self, content: str, partial_match: bool = True) -> Optional[ScreenElement]: + """Find element by content text. + + Args: + content: Text to search for + partial_match: If True, match substrings + + Returns: + ScreenElement if found, None otherwise + """ + for element in self.elements: + if partial_match and content.lower() in element.content.lower(): + return element + elif element.content.lower() == content.lower(): + return element + return None + + def find_element_by_position(self, x: int, y: int) -> Optional[ScreenElement]: + """Find element at position. + + Args: + x: X coordinate + y: Y coordinate + + Returns: + ScreenElement if found, None otherwise + """ + for element in self.elements: + if element.x1 <= x <= element.x2 and element.y1 <= y <= element.y2: + return element + return None + + def to_mcp_description(self, use_normalized_coordinates: bool = False) -> str: + """Convert visual state to MCP description format. + + Args: + use_normalized_coordinates: If True, use normalized (0-1) coordinates + + Returns: + str: JSON string with structured description + """ + ui_elements = [] + for element in self.elements: + if use_normalized_coordinates: + ui_elements.append({ + "type": element.type, + "text": element.content, + "bounds": { + "x": element.bbox[0], + "y": element.bbox[1], + "width": element.bbox[2] - element.bbox[0], + "height": element.bbox[3] - element.bbox[1] + }, + "center": { + "x": element.normalized_center_x, + "y": element.normalized_center_y + }, + "confidence": element.confidence + }) + else: + ui_elements.append({ + "type": element.type, + "text": element.content, + "bounds": { + "x": element.x1, + "y": element.y1, + "width": element.width, + "height": element.height + }, + "center": { + "x": element.center_x, + "y": element.center_y + }, + "confidence": element.confidence + }) + + visual_state = { + "ui_elements": ui_elements, + "screenshot_timestamp": self.timestamp, + "screen_width": self.screenshot.width if self.screenshot else 0, + "screen_height": self.screenshot.height if self.screenshot else 0, + "element_count": len(self.elements), + "coordinates": "normalized" if use_normalized_coordinates else "absolute" + } + + return json.dumps(visual_state, indent=2) + + def visualize(self) -> Image.Image: + """Create visualization of elements on screenshot. 
+ + Returns: + Image: Annotated screenshot with bounding boxes + """ + if not self.screenshot: + return Image.new('RGB', (800, 600), color='white') + + # Create a copy of the screenshot + img = self.screenshot.copy() + draw = ImageDraw.Draw(img) + + # Draw bounding boxes + for i, element in enumerate(self.elements): + # Generate a different color for each element based on its index + r = (i * 50) % 255 + g = (i * 100) % 255 + b = (i * 150) % 255 + color = (r, g, b) + + # Draw rectangle + draw.rectangle( + [(element.x1, element.y1), (element.x2, element.y2)], + outline=color, + width=2 + ) + + # Draw element identifier + identifier = f"{i}: {element.content[:15]}" + + # Create text background + text_bg_padding = 2 + text_position = (element.x1, element.y1 - 20) + draw.rectangle( + [ + (text_position[0] - text_bg_padding, text_position[1] - text_bg_padding), + (text_position[0] + len(identifier) * 7, text_position[1] + 15) + ], + fill=(255, 255, 255, 180) + ) + + # Draw text + draw.text( + text_position, + identifier, + fill=color + ) + + return img + + +class OmniMCP: + """Main OmniMCP class implementing Model Context Protocol.""" + + def __init__( + self, + server_url: Optional[str] = None, + claude_api_key: Optional[str] = None, + use_normalized_coordinates: bool = False + ): + """Initialize OmniMCP. + + Args: + server_url: URL of OmniParser server + claude_api_key: API key for Claude (overrides config) + use_normalized_coordinates: If True, use normalized (0-1) coordinates + """ + self.omniparser = OmniParserProvider(server_url) + self.visual_state = VisualState() + self.claude_api_key = claude_api_key or config.ANTHROPIC_API_KEY + self.use_normalized_coordinates = use_normalized_coordinates + + # Initialize controllers for keyboard and mouse + self.keyboard_controller = keyboard.Controller() + self.mouse_controller = mouse.Controller() + + # Get screen dimensions from a screenshot + initial_screenshot = utils.take_screenshot() + self.screen_width, self.screen_height = initial_screenshot.size + logger.info(f"Screen dimensions: {self.screen_width}x{self.screen_height}") + + # Ensure OmniParser is running + if not self.omniparser.is_available(): + logger.info("OmniParser not available, attempting to deploy...") + self.omniparser.deploy() + + def update_visual_state(self) -> VisualState: + """Take screenshot and update visual state using OmniParser. + + Returns: + VisualState: Updated visual state + """ + # Take screenshot + screenshot = utils.take_screenshot() + + # Convert to bytes + img_byte_arr = io.BytesIO() + screenshot.save(img_byte_arr, format='PNG') + img_bytes = img_byte_arr.getvalue() + + # Parse with OmniParser + result = self.omniparser.parse_screenshot(img_bytes) + + # Update visual state + self.visual_state.update_from_omniparser(result, screenshot) + + return self.visual_state + + def click(self, x: Union[int, float], y: Union[int, float], button: str = "left") -> None: + """Click at specific coordinates. 
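+
+        Example (illustrative): with use_normalized_coordinates=True,
+        click(0.5, 0.5) clicks the centre of the screen; otherwise
+        click(100, 200) clicks at that absolute pixel position.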
+ + Args: + x: X coordinate (absolute or normalized based on configuration) + y: Y coordinate (absolute or normalized based on configuration) + button: Mouse button ('left', 'right', 'middle') + """ + if self.use_normalized_coordinates: + # Convert normalized coordinates to absolute + x_abs = int(x * self.screen_width) + y_abs = int(y * self.screen_height) + logger.info(f"Clicking at normalized ({x}, {y}) -> absolute ({x_abs}, {y_abs}) with {button} button") + x, y = x_abs, y_abs + else: + logger.info(f"Clicking at ({x}, {y}) with {button} button") + + # Map button string to pynput button object + button_obj = getattr(mouse.Button, button) + + # Move to position and click + self.mouse_controller.position = (x, y) + self.mouse_controller.click(button_obj, 1) + + def move_mouse(self, x: Union[int, float], y: Union[int, float]) -> None: + """Move mouse to coordinates without clicking. + + Args: + x: X coordinate (absolute or normalized) + y: Y coordinate (absolute or normalized) + """ + if self.use_normalized_coordinates: + # Convert normalized coordinates to absolute + x_abs = int(x * self.screen_width) + y_abs = int(y * self.screen_height) + logger.info(f"Moving mouse to normalized ({x}, {y}) -> absolute ({x_abs}, {y_abs})") + x, y = x_abs, y_abs + else: + logger.info(f"Moving mouse to ({x}, {y})") + + # Move to position + self.mouse_controller.position = (x, y) + + def drag_mouse( + self, + start_x: Union[int, float], + start_y: Union[int, float], + end_x: Union[int, float], + end_y: Union[int, float], + button: str = "left", + duration: float = 0.5 + ) -> None: + """Drag mouse from start to end coordinates. + + Args: + start_x: Starting X coordinate + start_y: Starting Y coordinate + end_x: Ending X coordinate + end_y: Ending Y coordinate + button: Mouse button to use for dragging + duration: Duration of drag in seconds + """ + if self.use_normalized_coordinates: + # Convert normalized coordinates to absolute + start_x_abs = int(start_x * self.screen_width) + start_y_abs = int(start_y * self.screen_height) + end_x_abs = int(end_x * self.screen_width) + end_y_abs = int(end_y * self.screen_height) + + logger.info( + f"Dragging from normalized ({start_x}, {start_y}) -> " + f"({end_x}, {end_y}) over {duration}s" + ) + + start_x, start_y = start_x_abs, start_y_abs + end_x, end_y = end_x_abs, end_y_abs + else: + logger.info( + f"Dragging from ({start_x}, {start_y}) -> " + f"({end_x}, {end_y}) over {duration}s" + ) + + # Map button string to pynput button object + button_obj = getattr(mouse.Button, button) + + # Move to start position + self.mouse_controller.position = (start_x, start_y) + + # Press button + self.mouse_controller.press(button_obj) + + # Calculate steps for smooth movement + steps = max(int(duration * 60), 10) # Aim for 60 steps per second, minimum 10 steps + sleep_time = duration / steps + + # Perform drag in steps + for i in range(1, steps + 1): + progress = i / steps + current_x = start_x + (end_x - start_x) * progress + current_y = start_y + (end_y - start_y) * progress + self.mouse_controller.position = (current_x, current_y) + time.sleep(sleep_time) + + # Release button at final position + self.mouse_controller.position = (end_x, end_y) + self.mouse_controller.release(button_obj) + + def scroll(self, amount: int, vertical: bool = True) -> None: + """Scroll the screen. 
+
+        Args:
+            amount: Amount to scroll (positive for up/left, negative for down/right)
+            vertical: If True, scroll vertically, otherwise horizontally
+        """
+        # pynput's scroll logic: positive values scroll up, negative scroll down
+        scroll_amount = amount
+
+        if vertical:
+            self.mouse_controller.scroll(0, scroll_amount)
+            direction = "up" if amount > 0 else "down"
+            logger.info(f"Scrolled {direction} by {abs(amount)}")
+        else:
+            self.mouse_controller.scroll(scroll_amount, 0)
+            direction = "left" if amount > 0 else "right"
+            logger.info(f"Scrolled {direction} by {abs(amount)}")
+
+    def scroll_at(
+        self,
+        x: Union[int, float],
+        y: Union[int, float],
+        amount: int,
+        vertical: bool = True
+    ) -> None:
+        """Scroll at specific coordinates.
+
+        Args:
+            x: X coordinate
+            y: Y coordinate
+            amount: Amount to scroll (positive for up/left, negative for
+                down/right, matching scroll())
+            vertical: If True, scroll vertically, otherwise horizontally
+        """
+        # First move to the specified position
+        self.move_mouse(x, y)
+
+        # Then scroll
+        self.scroll(amount, vertical)
+
+    def click_element(
+        self,
+        element_content: str,
+        button: str = "left",
+        partial_match: bool = True
+    ) -> bool:
+        """Click on element with specified content.
+
+        Args:
+            element_content: Text content to find
+            button: Mouse button ('left', 'right', 'middle')
+            partial_match: If True, match substrings
+
+        Returns:
+            bool: True if clicked, False if element not found
+        """
+        # Update visual state first
+        self.update_visual_state()
+
+        # Find element
+        element = self.visual_state.find_element_by_content(
+            element_content, partial_match
+        )
+        if not element:
+            logger.warning(f"Element with content '{element_content}' not found")
+            return False
+
+        # Click at center of element
+        if self.use_normalized_coordinates:
+            self.click(
+                element.normalized_center_x, element.normalized_center_y, button
+            )
+        else:
+            self.click(element.center_x, element.center_y, button)
+        return True
+
+    def type_text(self, text: str) -> None:
+        """Type text using keyboard.
+
+        This method types a string of text as if typed on the keyboard.
+        It's useful for entering text into forms, search fields, or documents.
+
+        Args:
+            text: Text to type
+        """
+        logger.info(f"Typing text: {text}")
+        self.keyboard_controller.type(text)
+
+    def press_key(self, key: str) -> None:
+        """Press a single key.
+
+        This method presses and releases a single key. It handles both regular
+        character keys (like 'a', '5', etc.) and special keys (like 'enter',
+        'tab', 'escape').
+
+        Use this method for individual key presses (e.g., pressing Enter to
+        submit a form or Escape to close a dialog).
+
+        Args:
+            key: Key to press (e.g., 'a', 'enter', 'tab', 'escape')
+
+        Examples:
+            press_key('enter')
+            press_key('tab')
+            press_key('a')
+        """
+        logger.info(f"Pressing key: {key}")
+
+        # Try to map to a special key if needed
+        try:
+            if len(key) == 1:
+                # Regular character key
+                self.keyboard_controller.press(key)
+                self.keyboard_controller.release(key)
+            else:
+                # Special key (like enter, tab, etc.)
+                key_obj = getattr(keyboard.Key, key.lower())
+                self.keyboard_controller.press(key_obj)
+                self.keyboard_controller.release(key_obj)
+        except (AttributeError, KeyError) as e:
+            logger.error(f"Unknown key '{key}': {e}")
+
+    def press_hotkey(self, keys: List[str]) -> None:
+        """Press a hotkey combination (multiple keys pressed simultaneously).
+
+        This method handles keyboard shortcuts like Ctrl+C, Alt+Tab, etc.
+        It presses all keys in the given list simultaneously, then releases them
+        in reverse order.
+
+        Unlike press_key(), which works with a single key, this method allows
+        complex key combinations that must be pressed together.
+
+        Args:
+            keys: List of keys to press simultaneously (e.g., ['ctrl', 'c'])
+
+        Examples:
+            press_hotkey(['ctrl', 'c'])  # Copy
+            press_hotkey(['alt', 'tab'])  # Switch window
+            press_hotkey(['ctrl', 'alt', 'delete'])  # System operation
+        """
+        logger.info(f"Pressing hotkey: {'+'.join(keys)}")
+
+        key_objects = []
+        # Press every key in order (modifiers first in the usual call pattern)
+        for key in keys:
+            try:
+                if len(key) == 1:
+                    key_objects.append(key)
+                else:
+                    key_obj = getattr(keyboard.Key, key.lower())
+                    key_objects.append(key_obj)
+                self.keyboard_controller.press(key_objects[-1])
+            except (AttributeError, KeyError) as e:
+                logger.error(f"Unknown key '{key}' in hotkey: {e}")
+
+        # Then release all keys in reverse order
+        for key_obj in reversed(key_objects):
+            self.keyboard_controller.release(key_obj)
+
+    async def describe_screen_with_claude(self) -> str:
+        """Generate a detailed description of the current screen with Claude.
+
+        Returns:
+            str: Detailed screen description
+        """
+        # Update visual state
+        self.update_visual_state()
+
+        # Create a system prompt for screen description
+        system_prompt = """You are an expert UI analyst.
+Your task is to provide a detailed description of the user interface shown on the screen.
+Focus on:
+1. The overall layout and purpose of the screen
+2. Key interactive elements and their likely functions
+3. Text content and its meaning
+4. Hierarchical organization of the interface
+5. Possible user actions and workflows
+
+Be detailed but concise. Organize your description logically."""
+
+        # Generate a prompt with the structured visual state
+        prompt = f"""
+Please analyze this user interface and provide a detailed description.
+
+Here is the structured data of the UI elements:
+```json
+{self.visual_state.to_mcp_description(self.use_normalized_coordinates)}
+```
+
+Describe the overall screen, main elements, and possible interactions a user might perform.
+"""
+
+        # Get response from Claude
+        response = anthropic.prompt(
+            prompt=prompt,
+            system_prompt=system_prompt,
+            api_key=self.claude_api_key
+        )
+
+        return response
+
+    async def describe_element_with_claude(self, element: ScreenElement) -> str:
+        """Generate a detailed description of a specific UI element with Claude.
+
+        Args:
+            element: The ScreenElement to describe
+
+        Returns:
+            str: Detailed element description
+        """
+        # Create a system prompt for element description
+        system_prompt = """You are an expert UI element analyst.
+Your task is to provide a detailed description of a specific UI element.
+Focus on:
+1. The element's type and function
+2. Its visual appearance and text content
+3. How a user might interact with it
+4. Its likely purpose in the interface
+5.
Any accessibility considerations + +Be detailed but concise.""" + + # Create element details in JSON + element_json = json.dumps({ + "content": element.content, + "type": element.type, + "bounds": { + "x1": element.x1, + "y1": element.y1, + "x2": element.x2, + "y2": element.y2, + "width": element.width, + "height": element.height + }, + "center": { + "x": element.center_x, + "y": element.center_y + }, + "confidence": element.confidence + }, indent=2) + + # Generate a prompt with the element data + prompt = f""" +Please analyze this UI element and provide a detailed description: + +```json +{element_json} +``` + +Describe what this element is, what it does, and how a user might interact with it. +""" + + # Get response from Claude + response = anthropic.prompt( + prompt=prompt, + system_prompt=system_prompt, + api_key=self.claude_api_key + ) + + return response + + def prompt_claude(self, prompt: str, system_prompt: Optional[str] = None) -> str: + """Prompt Claude with the current visual state. + + Args: + prompt: User prompt + system_prompt: Optional system prompt + + Returns: + str: Claude's response + """ + if not self.claude_api_key or self.claude_api_key == "": + logger.warning("Claude API key not set in config or constructor") + + # Update visual state + self.update_visual_state() + + # Create Claude prompt + mcp_description = self.visual_state.to_mcp_description(self.use_normalized_coordinates) + + full_prompt = f""" +Here is a description of the current screen state: +```json +{mcp_description} +``` + +Based on this screen state, {prompt} +""" + + # Default system prompt if not provided + if not system_prompt: + system_prompt = """You are an expert UI assistant that helps users navigate applications. +You have access to a structured description of the current screen through the Model Context Protocol. +Analyze the UI elements and provide clear, concise guidance based on the current screen state.""" + + # Get response from Claude + response = anthropic.prompt( + prompt=full_prompt, + system_prompt=system_prompt, + api_key=self.claude_api_key + ) + + return response + + def execute_natural_language_request(self, request: str) -> str: + """Execute a natural language request by prompting Claude and taking action. + + Args: + request: Natural language request + + Returns: + str: Result description + """ + # Update visual state + self.update_visual_state() + + # Create coordinate format string + coord_format = "normalized (0-1)" if self.use_normalized_coordinates else "absolute (pixels)" + + # Create specialized system prompt for action execution + system_prompt = f"""You are an expert UI automation assistant that helps users control applications. +You have access to a structured description of the current screen through the Model Context Protocol. +Analyze the UI elements and decide what action to take to fulfill the user's request. + +You MUST respond with a JSON object containing the action to perform in the following format: +{{ + "action": "click" | "type" | "press" | "describe", + "params": {{ + // For click action: + "element_content": "text to find", // or + "x": 0.5, // {coord_format} + "y": 0.5, // {coord_format} + "button": "left" | "right" | "middle", + + // For type action: + "text": "text to type", + + // For press action: + "key": "enter" | "tab" | "escape" | etc., + + // For describe action (no additional params needed) + }}, + "reasoning": "Brief explanation of why you chose this action" +}} + +Only return valid JSON. 
Do not include any other text in your response.""" + + # Prompt Claude for action decision + response = self.prompt_claude( + prompt=f"decide what action to perform to fulfill this request: '{request}'", + system_prompt=system_prompt + ) + + # Parse response + try: + action_data = json.loads(response) + action_type = action_data.get("action", "") + params = action_data.get("params", {}) + reasoning = action_data.get("reasoning", "No reasoning provided") + + logger.info(f"Action: {action_type}, Params: {params}, Reasoning: {reasoning}") + + # Execute action + if action_type == "click": + if "element_content" in params: + success = self.click_element( + params["element_content"], + params.get("button", "left"), + True + ) + if success: + return f"Clicked element: {params['element_content']}" + else: + return f"Failed to find element: {params['element_content']}" + elif "x" in params and "y" in params: + self.click( + params["x"], + params["y"], + params.get("button", "left") + ) + return f"Clicked at coordinates ({params['x']}, {params['y']})" + elif action_type == "type": + self.type_text(params.get("text", "")) + return f"Typed text: {params.get('text', '')}" + elif action_type == "press": + self.press_key(params.get("key", "")) + return f"Pressed key: {params.get('key', '')}" + elif action_type == "describe": + # Just return the reasoning as the description + return reasoning + else: + return f"Unknown action type: {action_type}" + except json.JSONDecodeError: + logger.error(f"Failed to parse Claude response as JSON: {response}") + return "Failed to parse action from Claude response" + except Exception as e: + logger.error(f"Error executing action: {e}") + return f"Error executing action: {str(e)}" + + def run_interactive(self): + """Run command-line interface (CLI) mode. + + This provides a simple prompt where users can enter natural language commands. + Each command is processed by taking a screenshot, analyzing it with OmniParser, + and using Claude to determine and execute the appropriate action. + """ + logger.info("Starting OmniMCP CLI mode") + logger.info(f"Coordinate mode: {'normalized (0-1)' if self.use_normalized_coordinates else 'absolute (pixels)'}") + logger.info("Type 'exit' or 'quit' to exit") + + while True: + request = input("\nEnter command: ") + if request.lower() in ("exit", "quit"): + break + + result = self.execute_natural_language_request(request) + print(f"Result: {result}") + + # Give some time for UI to update before next request + time.sleep(1) + + def save_visual_debug(self, output_path: Optional[str] = None, debug_dir: Optional[str] = None) -> str: + """Save visualization of current visual state for debugging. + + Args: + output_path: Path to save the image. If None, generates a timestamped filename. + debug_dir: Directory to save debug files. 
If None, uses ~/omnimcp_debug
+
+        Returns:
+            str: Path to the saved image
+        """
+        # Update visual state
+        self.update_visual_state()
+
+        # Generate timestamped filename if not provided
+        if output_path is None:
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+            # Use provided debug directory or default
+            if debug_dir is None:
+                debug_dir = os.path.join(os.path.expanduser("~"), "omnimcp_debug")
+
+            # Ensure directory exists
+            os.makedirs(debug_dir, exist_ok=True)
+
+            # Create filename with timestamp
+            output_path = os.path.join(debug_dir, f"debug_{timestamp}.png")
+
+        # Create visualization and save
+        vis_img = self.visual_state.visualize()
+        vis_img.save(output_path)
+        logger.info(f"Saved visual debug to {output_path}")
+
+        return output_path
+
+    def run_mcp_server(self):
+        """Run the MCP server for this OmniMCP instance."""
+        from openadapt.mcp.server import create_omnimcp_server
+
+        server = create_omnimcp_server(self)
+        server.run()
+
+    async def run_mcp_server_async(self):
+        """Run the MCP server asynchronously."""
+        from openadapt.mcp.server import create_omnimcp_server
+
+        server = create_omnimcp_server(self)
+        await server.run_async()
+
+
+def main():
+    """Main entry point."""
+    fire.Fire(OmniMCP)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/openadapt/run_omnimcp.py b/openadapt/run_omnimcp.py
new file mode 100644
index 000000000..7cff8d17e
--- /dev/null
+++ b/openadapt/run_omnimcp.py
@@ -0,0 +1,224 @@
+"""Run OmniMCP with the Model Context Protocol.
+
+This script provides a user-friendly interface to run OmniMCP in different modes.
+
+OmniMCP combines OmniParser (for visual UI understanding) with the Model Context
+Protocol (MCP) to enable Claude to control the computer through natural language.
+
+Usage:
+------
+    # Run CLI mode (direct command input)
+    python -m openadapt.run_omnimcp cli
+
+    # Run MCP server (for Claude Desktop)
+    python -m openadapt.run_omnimcp server
+
+    # Run in debug mode to visualize screen elements
+    python -m openadapt.run_omnimcp debug
+
+    # Run with custom OmniParser server URL
+    python -m openadapt.run_omnimcp server --server-url=http://your-server:8000
+
+    # Use normalized coordinates (0-1) instead of absolute pixels
+    python -m openadapt.run_omnimcp cli --use-normalized-coordinates
+
+    # Save debug visualization to specific directory
+    python -m openadapt.run_omnimcp debug --debug-dir=/path/to/debug/folder
+
+Components:
+----------
+1. OmniParser Client (adapters/omniparser.py):
+   - Connects to the OmniParser server running on AWS
+   - Parses screenshots to identify UI elements
+
+2. OmniMCP Core (omnimcp.py):
+   - Manages the visual state of the screen
+   - Provides UI interaction methods (click, type, etc.)
+   - Implements natural language understanding with Claude
+
+3. MCP Server (mcp/server.py):
+   - Implements the Model Context Protocol server
+   - Exposes UI automation tools to Claude
+"""
+
+import datetime
+import os
+
+import fire
+
+from openadapt.omnimcp import OmniMCP
+from openadapt.custom_logger import logger
+
+
+class OmniMCPRunner:
+    """OmniMCP runner with different modes of operation."""
+
+    def cli(
+        self,
+        server_url=None,
+        claude_api_key=None,
+        use_normalized_coordinates=False,
+        debug_dir=None
+    ):
+        """Run OmniMCP in CLI mode.
+
+        In CLI mode, you can enter natural language commands directly in the terminal.
+        OmniMCP will:
+        1. Take a screenshot
+        2. Analyze it with OmniParser to identify UI elements
+        3. Use Claude to decide what action to take based on your command
+        4. Execute the action (click, type, etc.)
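+
+        Example (mirrors the module-level Usage section):
+            python -m openadapt.run_omnimcp cli --use-normalized-coordinates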
+
+        This mode is convenient for testing and doesn't require Claude Desktop.
+
+        Args:
+            server_url: URL of the OmniParser server
+            claude_api_key: Claude API key (if not provided, uses value from config.py)
+            use_normalized_coordinates: Use normalized (0-1) coordinates instead of pixels
+            debug_dir: Directory to save debug visualizations
+        """
+        # Create OmniMCP instance
+        omnimcp = OmniMCP(
+            server_url=server_url,
+            claude_api_key=claude_api_key,  # Will use config.ANTHROPIC_API_KEY if None
+            use_normalized_coordinates=use_normalized_coordinates
+        )
+
+        # Handle debug directory if specified
+        if debug_dir:
+            os.makedirs(debug_dir, exist_ok=True)
+
+            # Take initial screenshot and save debug visualization
+            logger.info(f"Saving debug visualization to {debug_dir}")
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            debug_path = os.path.join(debug_dir, f"initial_state_{timestamp}.png")
+            omnimcp.update_visual_state()
+            omnimcp.save_visual_debug(debug_path)
+
+        logger.info("Starting OmniMCP in CLI mode")
+        logger.info(f"Coordinate mode: {'normalized (0-1)' if use_normalized_coordinates else 'absolute (pixels)'}")
+
+        # Run CLI interaction loop
+        omnimcp.run_interactive()
+
+    def server(
+        self,
+        server_url=None,
+        claude_api_key=None,
+        use_normalized_coordinates=False,
+        debug_dir=None
+    ):
+        """Run OmniMCP as an MCP server.
+
+        In server mode, OmniMCP provides UI automation tools to Claude through the
+        Model Context Protocol. The server exposes tools for:
+        1. Getting the current screen state with UI elements
+        2. Finding UI elements by description
+        3. Clicking on elements or coordinates
+        4. Typing text and pressing keys
+
+        To use with Claude Desktop:
+        1. Configure Claude Desktop to use this server
+        2. Ask Claude to perform UI tasks
+
+        Args:
+            server_url: URL of the OmniParser server
+            claude_api_key: Claude API key (if not provided, uses value from config.py)
+            use_normalized_coordinates: Use normalized (0-1) coordinates instead of pixels
+            debug_dir: Directory to save debug visualizations
+        """
+        # Create OmniMCP instance
+        omnimcp = OmniMCP(
+            server_url=server_url,
+            claude_api_key=claude_api_key,  # Will use config.ANTHROPIC_API_KEY if None
+            use_normalized_coordinates=use_normalized_coordinates
+        )
+
+        # Handle debug directory if specified
+        if debug_dir:
+            os.makedirs(debug_dir, exist_ok=True)
+
+            # Take initial screenshot and save debug visualization
+            logger.info(f"Saving debug visualization to {debug_dir}")
+            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+            debug_path = os.path.join(debug_dir, f"initial_state_{timestamp}.png")
+            omnimcp.update_visual_state()
+            omnimcp.save_visual_debug(debug_path)
+
+        logger.info("Starting OmniMCP Model Context Protocol server")
+        logger.info(f"Coordinate mode: {'normalized (0-1)' if use_normalized_coordinates else 'absolute (pixels)'}")
+
+        # Run MCP server
+        omnimcp.run_mcp_server()
+
+    def debug(
+        self,
+        server_url=None,
+        claude_api_key=None,
+        use_normalized_coordinates=False,
+        debug_dir=None
+    ):
+        """Run OmniMCP in debug mode.
+
+        Debug mode takes a screenshot, analyzes it with OmniParser, and saves
+        a visualization showing the detected UI elements with their descriptions.
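+
+        Example (mirrors the module-level Usage section):
+            python -m openadapt.run_omnimcp debug --debug-dir=/path/to/debug/folder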
+ + This is useful for: + - Understanding what UI elements OmniParser detects + - Debugging issues with element detection + - Fine-tuning OmniParser integration + + Args: + server_url: URL of the OmniParser server + claude_api_key: Claude API key (if not provided, uses value from config.py) + use_normalized_coordinates: Use normalized (0-1) coordinates instead of pixels + debug_dir: Directory to save debug visualizations + """ + # Create OmniMCP instance + omnimcp = OmniMCP( + server_url=server_url, + claude_api_key=claude_api_key, # Will use config.ANTHROPIC_API_KEY if None + use_normalized_coordinates=use_normalized_coordinates + ) + + # Create debug directory if not specified + if not debug_dir: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + debug_dir = os.path.join(os.path.expanduser("~"), "omnimcp_debug", f"debug_{timestamp}") + + os.makedirs(debug_dir, exist_ok=True) + logger.info(f"Saving debug visualization to {debug_dir}") + + # Generate debug filename + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + debug_path = os.path.join(debug_dir, f"screen_state_{timestamp}.png") + + # Update visual state and save debug + logger.info("Taking screenshot and analyzing with OmniParser...") + omnimcp.update_visual_state() + omnimcp.save_visual_debug(debug_path) + logger.info(f"Saved debug visualization to {debug_path}") + + # Print some stats about detected elements + num_elements = len(omnimcp.visual_state.elements) + logger.info(f"Detected {num_elements} UI elements") + + if num_elements > 0: + # Show a few example elements + logger.info("Example elements:") + for i, element in enumerate(omnimcp.visual_state.elements[:5]): + content = element.content[:50] + "..." if len(element.content) > 50 else element.content + logger.info(f" {i+1}. '{content}' at ({element.x1},{element.y1},{element.x2},{element.y2})") + + if num_elements > 5: + logger.info(f" ... and {num_elements - 5} more elements") + + +def main(): + """Main entry point for OmniMCP.""" + fire.Fire(OmniMCPRunner) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/openadapt/strategies/process_graph.py b/openadapt/strategies/process_graph.py new file mode 100644 index 000000000..9af045b30 --- /dev/null +++ b/openadapt/strategies/process_graph.py @@ -0,0 +1,1214 @@ +"""Process graph-based replay strategy using OmniParser and Gemini 2.0. + +This strategy: +1. Uses OmniParser for parsing visual state and Gemini 2.0 for state evaluation +2. Takes natural language task descriptions instead of recording IDs +3. Processes coalesced actions from events.py +4. 
Builds and maintains a process graph G=(V,E) where:
+   - V represents States
+   - E represents Actions
+   - Graph is constructed before replay based on recording + task description
+   - Graph is updated during replay based on observed states
+"""

+import json
+import uuid
+from typing import Dict, List, Literal, Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+from json_repair import repair_json, loads as repair_loads
+
+from openadapt import adapters, common, models, utils, vision
+from openadapt.custom_logger import logger
+from openadapt.db import crud
+from openadapt.strategies.base import BaseReplayStrategy
+from openadapt.providers.omniparser import OmniParserProvider
+
+
+# Pydantic models for structured data
+class RecognitionCriterion(BaseModel):
+    """Criteria for recognizing a state"""
+    type: Literal["window_title", "ui_element_present", "visual_template"]
+    pattern: Optional[str] = None
+    threshold: Optional[float] = None
+    element_descriptor: Optional[str] = None
+
+
+class ActionParameter(BaseModel):
+    """Parameters for an action"""
+    target_element: Optional[str] = None
+    text_input: Optional[str] = None
+    click_type: Optional[Literal["single", "double", "right"]] = None
+    coordinate_type: Optional[Literal["absolute", "relative"]] = None
+
+
+class ActionModel(BaseModel):
+    """Model for an action in the process"""
+    name: str
+    description: str
+    parameters: ActionParameter
+
+
+class Condition(BaseModel):
+    """Condition for a transition"""
+    type: Literal["element_state", "data_value", "previous_action"]
+    description: str
+
+
+class Transition(BaseModel):
+    """Transition between states"""
+    # populate_by_name lets internal code construct transitions by field name
+    # (from_state=..., to_state=...) while the JSON schema keeps the
+    # "from"/"to" aliases expected from the model's response
+    model_config = ConfigDict(populate_by_name=True)
+
+    from_state: str = Field(..., alias="from")
+    to_state: str = Field(..., alias="to")
+    action: ActionModel
+    condition: Optional[Condition] = None
+
+
+class Branch(BaseModel):
+    """Branch in a decision point"""
+    condition: str
+    next_state: str
+
+
+class DecisionPoint(BaseModel):
+    """Decision point in the process"""
+    state: str
+    description: str
+    branches: List[Branch]
+
+
+class Loop(BaseModel):
+    """Loop in the process"""
+    start_state: str
+    end_state: str
+    exit_condition: str
+    description: str
+
+
+class StateModel(BaseModel):
+    """Model for a state in the process"""
+    name: str
+    description: str
+    recognition_criteria: List[RecognitionCriterion]
+
+
+class ProcessAnalysis(BaseModel):
+    """Complete model of a process"""
+    process_name: str
+    description: str
+    states: List[StateModel]
+    transitions: List[Transition]
+    loops: List[Loop]
+    decision_points: List[DecisionPoint]
+
+
+class StateTrajectoryEntry(BaseModel):
+    """Entry in the state trajectory"""
+    state_name: Optional[str] = None
+    action_name: Optional[str] = None
+    timestamp: float
+
+
+class CurrentStateMatch(BaseModel):
+    """Result of matching current state to graph"""
+    matched_state_name: str
+    confidence: float
+    reasoning: str
+
+
+class UIElement(BaseModel):
+    """UI element in the visual state"""
+    type: str
+    text: Optional[str] = None
+    bounds: Dict[str, int]
+    description: str
+    is_interactive: bool
+
+
+class VisualState(BaseModel):
+    """Visual state representation"""
+    window_title: str
+    ui_elements: List[UIElement]
+    screenshot_timestamp: float
+
+
+class AbstractState:
+    """Represents an abstract state in the process graph with recognition logic."""
+
+    def __init__(self, name, description, recognition_criteria):
+        self.id = str(uuid.uuid4())
+        self.name = name
+        self.description = description
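+        # recognition_criteria is a list of plain dicts mirroring
+        # RecognitionCriterion above, e.g. (illustrative values):
+        #   [{"type": "window_title", "pattern": "Untitled - Notepad"},
+        #    {"type": "ui_element_present", "element_descriptor": "Save button"},
+        #    {"type": "visual_template", "threshold": 0.8}]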
self.recognition_criteria = recognition_criteria + self.example_screenshots = [] + + def match_rules(self, current_state, trajectory=None): + """Apply rule-based matching using recognition criteria.""" + for criterion in self.recognition_criteria: + if not self._evaluate_criterion(criterion, current_state): + return False + return True + + def _evaluate_criterion(self, criterion, state): + """Evaluate a single recognition criterion against current state.""" + criterion_type = criterion["type"] + + if criterion_type == "window_title": + if not state.window_event or not state.window_event.title: + return False + return criterion["pattern"] in state.window_event.title + + elif criterion_type == "ui_element_present": + if not state.visual_data: + return False + return any( + criterion["element_descriptor"] in element["description"] + for element in state.visual_data + ) + + elif criterion_type == "visual_template": + # Match against example screenshots + if not self.example_screenshots: + return False + return any( + vision.get_image_similarity(state.screenshot.image, example)[0] > criterion.get("threshold", 0.8) + for example in self.example_screenshots + ) + + return False + + def add_example(self, screenshot): + """Add example screenshot for visual matching.""" + self.example_screenshots.append(screenshot) + + def to_dict(self): + """Convert to dictionary representation.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "recognition_criteria": self.recognition_criteria + } + + +class AbstractAction: + """Represents an abstract action with parameters to be instantiated.""" + + def __init__(self, name, description, parameters): + self.id = str(uuid.uuid4()) + self.name = name + self.description = description + self.parameters = parameters + + def to_dict(self): + """Convert to dictionary representation.""" + return { + "id": self.id, + "name": self.name, + "description": self.description, + "parameters": self.parameters + } + + +class ProcessGraph: + """Enhanced process graph with abstract states and conditional transitions.""" + + def __init__(self): + self.nodes = set() + self.edges = [] + self.conditions = {} # Maps (from_state, action, to_state) to condition logic + self.description = "" + + def add_node(self, node): + """Add a node to the graph.""" + self.nodes.add(node) + + def add_edge(self, from_state, action, to_state): + """Add an edge to the graph.""" + self.add_node(from_state) + self.add_node(action) + self.add_node(to_state) + self.edges.append((from_state, action, to_state)) + + def add_condition(self, from_state, action, to_state, condition): + """Add a condition to an edge.""" + key = (from_state.id, action.id, to_state.id) + self.conditions[key] = condition + + def get_abstract_states(self): + """Get all abstract states in the graph.""" + return [node for node in self.nodes if isinstance(node, AbstractState)] + + def get_state_by_name(self, name): + """Find a state by name.""" + for node in self.nodes: + if isinstance(node, AbstractState) and node.name == name: + return node + return None + + def get_possible_actions(self, state): + """Get possible actions from a state, considering conditions.""" + possible_actions = [] + + for from_state, action, to_state in self.edges: + if from_state.id == state.id: + key = (from_state.id, action.id, to_state.id) + if key in self.conditions: + # For now, we include conditional actions + # In a full implementation, would need to evaluate conditions + possible_actions.append((action, to_state)) + else: 
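+                    # Unconditional transition: always a candidate.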
+ possible_actions.append((action, to_state)) + + return possible_actions + + def set_description(self, description): + """Set the overall description of the process.""" + self.description = description + + def get_description(self): + """Get the overall description of the process.""" + return self.description + + def to_model(self) -> ProcessAnalysis: + """Convert graph to Pydantic model for serialization.""" + states = [ + StateModel( + name=state.name, + description=state.description, + recognition_criteria=[ + RecognitionCriterion(**criterion) + for criterion in state.recognition_criteria + ] + ) + for state in self.get_abstract_states() + ] + + transitions = [] + for from_state, action, to_state in self.edges: + if isinstance(from_state, AbstractState) and isinstance(to_state, AbstractState): + key = (from_state.id, action.id, to_state.id) + transition = Transition( + from_state=from_state.name, + to_state=to_state.name, + action=ActionModel( + name=action.name, + description=action.description, + parameters=ActionParameter(**action.parameters) + ) + ) + if key in self.conditions: + transition.condition = Condition(**self.conditions[key]) + transitions.append(transition) + + # Build loops and decision points using a simple algorithm + loops = self._detect_loops() + decision_points = self._detect_decision_points() + + return ProcessAnalysis( + process_name=self.description.split("\n")[0] if self.description else "Unnamed Process", + description=self.description, + states=states, + transitions=transitions, + loops=loops, + decision_points=decision_points + ) + + def _detect_loops(self) -> List[Loop]: + """Simple loop detection algorithm.""" + loops = [] + # Map state names to IDs for easier lookup + state_id_to_name = {state.id: state.name for state in self.get_abstract_states()} + + # Find cycles in the graph using DFS + visited = set() + path = [] + + def dfs(node_id): + if node_id in path: + # Found a cycle + cycle_start = path.index(node_id) + cycle = path[cycle_start:] + # Only process if the cycle involves states (not just actions) + state_ids = [node_id for node_id in cycle if node_id in state_id_to_name] + if len(state_ids) > 1: + loops.append(Loop( + start_state=state_id_to_name[state_ids[0]], + end_state=state_id_to_name[state_ids[-1]], + exit_condition="Condition to exit loop", + description=f"Loop from {state_id_to_name[state_ids[0]]} to {state_id_to_name[state_ids[-1]]}" + )) + return + + if node_id in visited: + return + + visited.add(node_id) + path.append(node_id) + + # Find all outgoing edges + for from_state, _, to_state in self.edges: + if from_state.id == node_id: + dfs(to_state.id) + + path.pop() + + # Start DFS from each state + for state in self.get_abstract_states(): + dfs(state.id) + + return loops + + def _detect_decision_points(self) -> List[DecisionPoint]: + """Detect states with multiple outgoing transitions.""" + decision_points = [] + state_id_to_name = {state.id: state.name for state in self.get_abstract_states()} + + # Count outgoing edges for each state + outgoing_counts = {} + for from_state, _, to_state in self.edges: + if from_state.id not in outgoing_counts: + outgoing_counts[from_state.id] = [] + outgoing_counts[from_state.id].append(to_state.id) + + # States with multiple outgoing edges are decision points + for state_id, destinations in outgoing_counts.items(): + if state_id in state_id_to_name and len(destinations) > 1: + branches = [] + for dest_id in destinations: + if dest_id in state_id_to_name: + branches.append(Branch( + 
condition=f"Condition to go to {state_id_to_name[dest_id]}", + next_state=state_id_to_name[dest_id] + )) + + if branches: + decision_points.append(DecisionPoint( + state=state_id_to_name[state_id], + description=f"Decision point at {state_id_to_name[state_id]}", + branches=branches + )) + + return decision_points + + def to_json(self): + """Convert graph to JSON string.""" + return self.to_model().model_dump_json(indent=2) + + def update_with_observation(self, observed_state, previous_state, latest_action): + """Update graph with observed state during execution.""" + # Find abstract states that match the observed state + similar_state = None + highest_similarity = 0.0 + + for state in self.get_abstract_states(): + similarity = self._calculate_state_similarity(observed_state, state) + if similarity > highest_similarity: + highest_similarity = similarity + similar_state = state + + # Create a new state if no good match + if highest_similarity < 0.7: + similar_state = self._create_new_state_from_observation(observed_state) + + # If we have a previous state and action, create or update a transition + if previous_state and latest_action: + # Check if transition already exists + transition_exists = False + for from_state, action, to_state in self.edges: + if (from_state.id == previous_state.id and + action.name == latest_action.name and + to_state.id == similar_state.id): + transition_exists = True + break + + if not transition_exists: + # Create a new abstract action from the latest action + action = AbstractAction( + name=latest_action.name, + description=f"Action {latest_action.name}", + parameters=self._extract_action_parameters(latest_action) + ) + + # Add the edge + self.add_edge(previous_state, action, similar_state) + + return similar_state + + def _calculate_state_similarity(self, observed_state, abstract_state): + """Calculate similarity between observed state and abstract state.""" + # Use rule-based matching first + if abstract_state.match_rules(observed_state): + return 0.9 # High confidence if rules match + + # Fall back to visual similarity if we have example screenshots + if abstract_state.example_screenshots and observed_state.screenshot: + visual_similarities = [ + vision.get_image_similarity(observed_state.screenshot.image, example)[0] + for example in abstract_state.example_screenshots + ] + return max(visual_similarities) if visual_similarities else 0.0 + + return 0.0 + + def _create_new_state_from_observation(self, observed_state): + """Create a new abstract state from an observed state.""" + # Generate a name for the state based on window title + name = "State_" + str(len(self.get_abstract_states()) + 1) + if observed_state.window_event and observed_state.window_event.title: + name = f"State_{observed_state.window_event.title[:20]}" + + # Create recognition criteria + criteria = [] + if observed_state.window_event and observed_state.window_event.title: + criteria.append({ + "type": "window_title", + "pattern": observed_state.window_event.title + }) + + if observed_state.visual_data: + # Add criteria based on visible UI elements + for element in observed_state.visual_data[:3]: # Limit to a few key elements + if element.get("description"): + criteria.append({ + "type": "ui_element_present", + "element_descriptor": element["description"] + }) + + state = AbstractState( + name=name, + description=f"State with window title: {observed_state.window_event.title if observed_state.window_event else 'Unknown'}", + recognition_criteria=criteria + ) + + # Add screenshot as example for 
visual matching + if observed_state.screenshot: + state.add_example(observed_state.screenshot.image) + + self.add_node(state) + return state + + def _extract_action_parameters(self, action_event): + """Extract parameters from an action event.""" + parameters = {} + + if action_event.name in common.MOUSE_EVENTS: + parameters["target_element"] = action_event.active_segment_description + if "click" in action_event.name: + if "double" in action_event.name: + parameters["click_type"] = "double" + elif "right" in action_event.name: + parameters["click_type"] = "right" + else: + parameters["click_type"] = "single" + + elif action_event.name in common.KEYBOARD_EVENTS: + if action_event.text: + parameters["text_input"] = action_event.text + + return parameters + + +class State: + """Represents a concrete state during execution.""" + + def __init__(self, screenshot, window_event, browser_event=None, visual_data=None): + self.id = str(uuid.uuid4()) + self.screenshot = screenshot + self.window_event = window_event + self.browser_event = browser_event + self.visual_data = visual_data or [] + + +class ProcessGraphStrategy(BaseReplayStrategy): + """Strategy using process graphs, OmniParser and Gemini 2.0 Flash.""" + + def __init__( + self, + task_description: str, + recording_id: int = None, + ) -> None: + """Initialize with task description rather than recording ID.""" + # Find best matching recording if not provided + if not recording_id: + recording_id = self._find_matching_recording(task_description) + + db_session = crud.get_new_session() + self.recording = crud.get_recording(db_session, recording_id) + super().__init__(self.recording) + + self.task_description = task_description + + # Initialize OmniParser service + self.omniparser_provider = OmniParserProvider() + self._ensure_omniparser_running() + + # Initialize tracking + self.state_action_history = [] # List of (state, action) pairs + self.action_history = [] + self.current_state = None + self.current_abstract_state = None + + # Build graph before replay + self.process_graph = self._build_generalizable_process_graph(task_description) + + def _ensure_omniparser_running(self): + """Ensure OmniParser is running, deploying if necessary.""" + status = self.omniparser_provider.status() + if not status['services']: + logger.info("Deploying OmniParser...") + self.omniparser_provider.deploy() + self.omniparser_provider.stack.create_service() + + def _find_matching_recording(self, task_description: str) -> int: + """Find recording with most similar task description using vector similarity.""" + db_session = crud.get_new_session() + recordings = crud.get_all_recordings(db_session) + best_match = None + highest_similarity = -1 + + for recording in recordings: + if not recording.task_description: + continue + + similarity = self._calculate_text_similarity(task_description, recording.task_description) + if similarity > highest_similarity: + highest_similarity = similarity + best_match = recording.id + + if best_match is None: + # If no good match, use the most recent recording + recordings_sorted = sorted(recordings, key=lambda r: r.timestamp, reverse=True) + if recordings_sorted: + best_match = recordings_sorted[0].id + else: + raise ValueError("No recordings found in the database.") + + return best_match + + def _calculate_text_similarity(self, text1, text2): + """Calculate similarity between two text strings.""" + # Simple word overlap similarity + if not text1 or not text2: + return 0.0 + + words1 = set(text1.lower().split()) + words2 = 
set(text2.lower().split()) + + if not words1 or not words2: + return 0.0 + + intersection = words1.intersection(words2) + union = words1.union(words2) + + return len(intersection) / len(union) + + def _build_generalizable_process_graph(self, task_description): + """Build a generalizable process graph using multi-phase approach with MMMs.""" + # Get coalesced actions + processed_actions = self.recording.processed_action_events + + # Phase 1: Process Understanding - Analyze the entire workflow + process_model = self._analyze_entire_process(processed_actions, task_description) + + # Phase 2: Graph Construction - Build abstract graph from understanding + initial_graph = self._construct_abstract_graph(process_model) + + # Phase 3: Graph Validation - Test and refine by walking through recording + refined_graph = self._validate_and_refine_graph(initial_graph, processed_actions) + + return refined_graph + + def _select_representative_screenshots(self, action_events, max_images=10): + """Select representative screenshots from the action events.""" + if not action_events: + return [] + + # If few actions, use all screenshots + if len(action_events) <= max_images: + return [action.screenshot.image for action in action_events if action.screenshot] + + # Otherwise, select evenly spaced screenshots + step = len(action_events) // max_images + selected_actions = action_events[::step] + + # Add the last action if not included + if action_events[-1] not in selected_actions: + selected_actions.append(action_events[-1]) + + return [action.screenshot.image for action in selected_actions if action.screenshot] + + def _analyze_entire_process(self, actions, task_description): + """Have Gemini analyze the entire recording to understand the process structure.""" + key_screenshots = self._select_representative_screenshots(actions) + + # Generate schema JSON for the prompt + schema_json = ProcessAnalysis.model_json_schema() + + system_prompt = "You are an expert in understanding user interface workflows." + prompt = f""" + Analyze this UI automation sequence and identify: + 1. The high-level steps in the process + 2. Any repetitive patterns or loops + 3. Decision points where the workflow might branch + 4. The semantic meaning of each major state + + Task description: {task_description} + + RESPOND USING THE FOLLOWING JSON SCHEMA: + ```json + {json.dumps(schema_json, indent=2)} + ``` + + Your response must strictly follow this schema and be valid JSON. 
+ """ + + process_analysis_text = self.prompt_gemini(prompt, system_prompt, key_screenshots) + + # Use json_repair for robust parsing + try: + # Direct parsing if possible + process_data = repair_loads(process_analysis_text) + process_model = ProcessAnalysis(**process_data) + return process_model + except Exception as e: + logger.warning(f"Initial JSON parsing failed: {e}") + + # Try to repair potentially broken JSON + try: + repaired_json = repair_json(process_analysis_text, ensure_ascii=False) + process_data = json.loads(repaired_json) + process_model = ProcessAnalysis(**process_data) + return process_model + except Exception as repair_e: + logger.error(f"JSON repair also failed: {repair_e}") + + # Last resort: try direct object return + try: + process_data = repair_json(process_analysis_text, return_objects=True) + process_model = ProcessAnalysis(**process_data) + return process_model + except Exception as final_e: + logger.error(f"All JSON parsing methods failed: {final_e}") + return self._fallback_process_analysis(actions, task_description) + + def _fallback_process_analysis(self, actions, task_description): + """Create a simple process model if all parsing fails.""" + logger.warning("Using fallback process analysis") + + # Create a simple linear process model + states = [] + transitions = [] + + # Create a state for each key action + key_actions = actions[::max(1, len(actions) // 5)] # At most 5 states + + for i, action in enumerate(key_actions): + state_name = f"State_{i+1}" + state_description = f"State after {action.name} action" + + # Create recognition criteria + criteria = [] + if action.window_event and action.window_event.title: + criteria.append({ + "type": "window_title", + "pattern": action.window_event.title + }) + + states.append(StateModel( + name=state_name, + description=state_description, + recognition_criteria=criteria + )) + + # Create transition to next state + if i < len(key_actions) - 1: + next_action = key_actions[i+1] + transitions.append(Transition( + from_state=state_name, + to_state=f"State_{i+2}", + action=ActionModel( + name=next_action.name, + description=f"{next_action.name} action", + parameters=ActionParameter() + ) + )) + + return ProcessAnalysis( + process_name="Fallback Process", + description=f"Fallback process for task: {task_description}", + states=states, + transitions=transitions, + loops=[], + decision_points=[] + ) + + def _construct_abstract_graph(self, process_model): + """Construct an abstract process graph based on the process understanding.""" + graph = ProcessGraph() + graph.set_description(process_model.description) + + # Create abstract state definitions based on process model + for state_def in process_model.states: + state = AbstractState( + name=state_def.name, + description=state_def.description, + recognition_criteria=[criterion.model_dump() for criterion in state_def.recognition_criteria] + ) + graph.add_node(state) + + # Create transitions with abstract actions + for transition in process_model.transitions: + from_state = graph.get_state_by_name(transition.from_state) + to_state = graph.get_state_by_name(transition.to_state) + + if from_state and to_state: + action = AbstractAction( + name=transition.action.name, + description=transition.action.description, + parameters=transition.action.parameters.model_dump() + ) + graph.add_edge(from_state, action, to_state) + + # Add conditional branches if present + if transition.condition: + graph.add_condition(from_state, action, to_state, transition.condition.model_dump()) + + return 
graph + + def _validate_and_refine_graph(self, graph, actions): + """Test the graph against recorded actions and refine it with Gemini's help.""" + # Simulate walking through the recording using the graph + simulation_results = self._simulate_graph_execution(graph, actions) + + if simulation_results["success"]: + return graph + + # If simulation failed, ask Gemini to refine the graph + system_prompt = "You are an expert in refining process models." + prompt = f""" + The process graph failed to match the recording at these points: + {simulation_results["failures"]} + + Current graph: {graph.to_json()} + + Please refine the graph to better match the recorded process while + maintaining generalizability. Consider: + 1. Adding missing states or transitions + 2. Adjusting state recognition criteria + 3. Modifying action parameters + 4. Adding conditional logic + + RESPOND USING THE SAME JSON SCHEMA AS THE CURRENT GRAPH. + """ + + refinements_text = self.prompt_gemini(prompt, system_prompt, simulation_results["screenshots"]) + + try: + refinements_data = repair_loads(refinements_text) + refined_model = ProcessAnalysis(**refinements_data) + refined_graph = self._construct_abstract_graph(refined_model) + + # Check if refinement improved the simulation + new_failures = len(self._simulate_graph_execution(refined_graph, actions)["failures"]) + old_failures = len(simulation_results["failures"]) + + if new_failures < old_failures: + return refined_graph + return graph + + except Exception as e: + logger.error(f"Failed to parse graph refinements: {e}") + return graph + + def _simulate_graph_execution(self, graph, actions): + """Simulate executing the graph with the recorded actions.""" + failures = [] + screenshots = [] + current_state = None + + for i, action in enumerate(actions): + # If first action, find initial state + if i == 0: + state = State(action.screenshot, action.window_event, action.browser_event) + matched_state = None + highest_similarity = 0.0 + + for abstract_state in graph.get_abstract_states(): + similarity = graph._calculate_state_similarity(state, abstract_state) + if similarity > highest_similarity: + highest_similarity = similarity + matched_state = abstract_state + + if highest_similarity < 0.7: + failures.append(f"Failed to match initial state at action {i}") + screenshots.append(action.screenshot.image) + + current_state = matched_state + continue + + # For subsequent actions, check if the graph has a transition + if current_state: + possible_actions = graph.get_possible_actions(current_state) + + # Check if any action matches the recorded action + action_match = False + for graph_action, next_state in possible_actions: + if graph_action.name == action.name: + action_match = True + current_state = next_state + break + + if not action_match: + failures.append(f"No matching action '{action.name}' from state '{current_state.name}' at action {i}") + screenshots.append(action.screenshot.image) + + return { + "success": len(failures) == 0, + "failures": failures, + "screenshots": screenshots[:5] # Limit to 5 screenshots for prompt size + } + + def get_next_action_event( + self, + screenshot: models.Screenshot, + window_event: models.WindowEvent, + ) -> models.ActionEvent: + """Determine next action using the process graph and runtime adaptation.""" + # Create current state representation + current_state = State( + screenshot=screenshot, + window_event=window_event + ) + + # Parse visual state with OmniParser + visual_data = self._parse_state_with_omniparser(screenshot.image) + 
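+        # Each parsed element is a dict shaped by _parse_state_with_omniparser,
+        # e.g. (illustrative values):
+        #   {"type": "button", "text": "Submit", "description": "Submit button",
+        #    "bounds": {"x": 10, "y": 20, "width": 80, "height": 24},
+        #    "is_interactive": True}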
current_state.visual_data = visual_data + + # Update graph with actual observed state + previous_abstract_state = self.current_abstract_state + latest_action = self.action_history[-1] if self.action_history else None + + self.current_abstract_state = self.process_graph.update_with_observation( + current_state, + previous_abstract_state, + latest_action + ) + + self.current_state = current_state + + # Find possible next actions in graph + possible_actions = self.process_graph.get_possible_actions(self.current_abstract_state) + + if not possible_actions: + # No actions available - either reached end state or unexpected state + if len(self.action_history) > 0: + # We've taken at least one action, so this might be the end + raise StopIteration("No further actions available in the process graph") + else: + # No actions taken yet - generate one with Gemini + next_action = self._generate_action_with_gemini() + self.action_history.append(next_action) + return next_action + + if len(possible_actions) == 1: + # Single clear action to take + action, next_state = possible_actions[0] + next_action = self._instantiate_abstract_action(action, current_state) + else: + # Multiple possible actions - use Gemini to decide + next_action = self._decide_between_actions(possible_actions, current_state) + + self.state_action_history.append((self.current_abstract_state, next_action)) + self.action_history.append(next_action) + + return next_action + + def _parse_state_with_omniparser(self, screenshot_image): + """Use OmniParser to parse the visual state.""" + try: + # Convert PIL Image to bytes + import io + img_byte_arr = io.BytesIO() + screenshot_image.save(img_byte_arr, format='PNG') + img_bytes = img_byte_arr.getvalue() + + # Call OmniParser API + result = self.omniparser_provider.parse_screenshot(img_bytes) + + # Transform the result into our expected format + ui_elements = [] + for element in result.get("elements", []): + ui_elements.append({ + "type": element.get("type", "unknown"), + "text": element.get("text", ""), + "bounds": element.get("bounds", {"x": 0, "y": 0, "width": 0, "height": 0}), + "description": element.get("description", ""), + "is_interactive": element.get("is_interactive", False) + }) + + return ui_elements + + except Exception as e: + logger.error(f"Error parsing state with OmniParser: {e}") + return [] + + def _instantiate_abstract_action(self, abstract_action, current_state): + """Convert abstract action to concrete ActionEvent based on current state.""" + try: + # Use parameters from abstract action if possible + params = abstract_action.parameters + + if abstract_action.name in common.MOUSE_EVENTS: + # Create a mouse action + action_event = models.ActionEvent( + name=abstract_action.name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording + ) + + # If we have a target element, find its coordinates + if params.get("target_element"): + target_element = None + for element in current_state.visual_data: + if params["target_element"] in element.get("description", ""): + target_element = element + break + + if target_element: + bounds = target_element.get("bounds", {}) + # Calculate center of element + center_x = bounds.get("x", 0) + bounds.get("width", 0) / 2 + center_y = bounds.get("y", 0) + bounds.get("height", 0) / 2 + + action_event.mouse_x = center_x + action_event.mouse_y = center_y + action_event.active_segment_description = params["target_element"] + else: + # If target not found, use Gemini to identify coordinates + 
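+                        # (_locate_target_with_gemini prompts Gemini with the
+                        # screenshot and expects JSON like
+                        # {"x": 512, "y": 384, "description": "..."}.)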
action_event = self._locate_target_with_gemini( + params["target_element"], + abstract_action.name, + current_state + ) + else: + # Use Gemini to decide where to click + action_event = self._locate_target_with_gemini( + None, + abstract_action.name, + current_state + ) + + return action_event + + elif abstract_action.name in common.KEYBOARD_EVENTS: + # Create a keyboard action + action_event = models.ActionEvent( + name=abstract_action.name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording + ) + + if params.get("text_input"): + # For "type" action, convert to actual keypresses + action_event = models.ActionEvent.from_dict({ + "name": "type", + "text": params["text_input"] + }) + action_event.screenshot = current_state.screenshot + action_event.window_event = current_state.window_event + action_event.recording = self.recording + + return action_event + + else: + # For other actions, use Gemini + return self._generate_action_with_gemini(abstract_action.name) + + except Exception as e: + logger.error(f"Error instantiating action: {e}") + return self._generate_action_with_gemini() + + def _locate_target_with_gemini(self, target_description, action_name, current_state): + """Use Gemini to locate a target on the screen.""" + system_prompt = "You are an expert in UI automation and element identification." + prompt = f""" + Identify the coordinates to perform a {action_name} action. + + {f'The target is described as: {target_description}' if target_description else 'Find the most appropriate element to interact with based on the current state.'} + + Analyze the screenshot and provide the x,y coordinates where the action should be performed. + Respond with a JSON object containing: + 1. x: the x-coordinate (number) + 2. y: the y-coordinate (number) + 3. description: brief description of what element is at these coordinates + """ + + result_text = self.prompt_gemini(prompt, system_prompt, [current_state.screenshot.image]) + + try: + # Parse the response + coord_data = repair_loads(result_text) + + # Create action event + action_event = models.ActionEvent( + name=action_name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording, + mouse_x=coord_data.get("x", 0), + mouse_y=coord_data.get("y", 0), + active_segment_description=coord_data.get("description", "") + ) + + return action_event + + except Exception as e: + logger.error(f"Error parsing coordinates: {e}") + + # Fallback: use center of screen + window_width = current_state.window_event.width if current_state.window_event else 800 + window_height = current_state.window_event.height if current_state.window_event else 600 + + return models.ActionEvent( + name=action_name, + screenshot=current_state.screenshot, + window_event=current_state.window_event, + recording=self.recording, + mouse_x=window_width / 2, + mouse_y=window_height / 2, + active_segment_description="Center of screen (fallback)" + ) + + def _decide_between_actions(self, possible_actions, current_state): + """Use Gemini to decide between multiple possible actions.""" + system_prompt = "You are an expert in UI automation decision making." 
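+        # Each candidate is serialized for Gemini as a dict like (illustrative):
+        #   {"id": 0, "name": "click", "description": "...", "parameters": {...},
+        #    "next_state": "State_2", "next_state_description": "..."}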
+ + actions_list = [] + for i, (action, next_state) in enumerate(possible_actions): + actions_list.append({ + "id": i, + "name": action.name, + "description": action.description, + "parameters": action.parameters, + "next_state": next_state.name, + "next_state_description": next_state.description + }) + + prompt = f""" + Decide which action to take next based on the current state and task description. + + Task description: {self.task_description} + Current state: {self.current_abstract_state.description if self.current_abstract_state else "Initial state"} + + Possible actions: + {json.dumps(actions_list, indent=2)} + + Respond with a JSON object containing: + 1. chosen_action_id: the ID of the chosen action (number) + 2. reasoning: brief explanation for your choice + """ + + result_text = self.prompt_gemini(prompt, system_prompt, [current_state.screenshot.image]) + + try: + result = repair_loads(result_text) + chosen_id = result.get("chosen_action_id", 0) + chosen_id = min(chosen_id, len(possible_actions) - 1) # Ensure valid index + + action, next_state = possible_actions[chosen_id] + return self._instantiate_abstract_action(action, current_state) + + except Exception as e: + logger.error(f"Error deciding between actions: {e}") + # Default to first action + action, next_state = possible_actions[0] + return self._instantiate_abstract_action(action, current_state) + + def _generate_action_with_gemini(self, suggested_action_name=None): + """Generate action with Gemini if graph doesn't provide one.""" + system_prompt = "You are an expert in UI automation." + + trajectory = [] + for i, (state, action) in enumerate(self.state_action_history[-5:]): + trajectory.append({ + "step": i + 1, + "state": state.name if state else "Unknown", + "action": action.name if action else "None" + }) + + prompt = f""" + Generate the next action to perform based on: + + Task description: {self.task_description} + Recent trajectory: {json.dumps(trajectory, indent=2)} + {f'Suggested action type: {suggested_action_name}' if suggested_action_name else ''} + + Analyze the screenshot and respond with a JSON object for the next ActionEvent: + {{ + "name": "click|move|scroll|type", + "mouse_x": number, + "mouse_y": number, + "text": "text to type (for keyboard actions)", + "active_segment_description": "description of what's being clicked" + }} + + Only include relevant fields based on the action type. 
+ """ + + result_text = self.prompt_gemini( + prompt, + system_prompt, + [self.current_state.screenshot.image] if self.current_state else [] + ) + + try: + action_dict = repair_loads(result_text) + action = models.ActionEvent.from_dict(action_dict) + + # Add missing context + action.screenshot = self.current_state.screenshot if self.current_state else None + action.window_event = self.current_state.window_event if self.current_state else None + action.recording = self.recording + + return action + + except Exception as e: + logger.error(f"Error generating action: {e}") + + # Create a fallback action - simple click in the center + window_width = self.current_state.window_event.width if self.current_state and self.current_state.window_event else 800 + window_height = self.current_state.window_event.height if self.current_state and self.current_state.window_event else 600 + + return models.ActionEvent( + name="click", + screenshot=self.current_state.screenshot if self.current_state else None, + window_event=self.current_state.window_event if self.current_state else None, + recording=self.recording, + mouse_x=window_width / 2, + mouse_y=window_height / 2, + mouse_button_name="left", + active_segment_description="Center of screen (fallback)" + ) + + def prompt_gemini(self, prompt, system_prompt, images): + """Helper method to prompt Gemini with images.""" + from openadapt.drivers import google + return google.prompt( + prompt, + system_prompt=system_prompt, + images=images, + model_name="models/gemini-1.5-pro-latest" + ) + + def __del__(self): + """Clean up OmniParser service when done.""" + try: + self.omniparser_provider.stack.stop_service() + except: + pass \ No newline at end of file