Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ API Reference
fetch_examples
fetch_forbidden_questions_dataset
fetch_harmbench_dataset
fetch_harmbench_multimodal_dataset_async
fetch_librAI_do_not_answer_dataset
fetch_llm_latent_adversarial_training_harmful_dataset
fetch_jbb_behaviors_by_harm_category
Expand Down
2 changes: 2 additions & 0 deletions pyrit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyrit.datasets.dataset_helper import fetch_examples
from pyrit.datasets.forbidden_questions_dataset import fetch_forbidden_questions_dataset
from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset
from pyrit.datasets.harmbench_multimodal_dataset import fetch_harmbench_multimodal_dataset_async
from pyrit.datasets.librAI_do_not_answer_dataset import fetch_librAI_do_not_answer_dataset
from pyrit.datasets.llm_latent_adversarial_training_harmful_dataset import (
fetch_llm_latent_adversarial_training_harmful_dataset,
Expand Down Expand Up @@ -47,6 +48,7 @@
"fetch_examples",
"fetch_forbidden_questions_dataset",
"fetch_harmbench_dataset",
"fetch_harmbench_multimodal_dataset_async",
"fetch_librAI_do_not_answer_dataset",
"fetch_llm_latent_adversarial_training_harmful_dataset",
"fetch_many_shot_jailbreaking_dataset",
Expand Down
200 changes: 200 additions & 0 deletions pyrit/datasets/harmbench_multimodal_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import uuid
from enum import Enum
from pathlib import Path
from typing import List, Literal, Optional

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.datasets.dataset_helper import fetch_examples
from pyrit.models import SeedPromptDataset, data_serializer_factory
from pyrit.models.seed_prompt import SeedPrompt

logger = logging.getLogger(__name__)


class SemanticCategory(Enum):
CYBERCRIME_INTRUSION = "cybercrime_intrusion" # n=54
ILLEGAL = "illegal" # 36
HARMFUL = "harmful" # 9
CHEMICAL_BIOLOGICAL = "chemical_biological" # 4
HARASSMENT_BULLYING = "harassment_bullying" # 4
MISINFORMATION_DISINFORMATION = "misinformation_disinformation" # 3


async def fetch_harmbench_multimodal_dataset_async(
*,
source: str = (
"https://raw.githubusercontent.com/centerforaisafety/HarmBench/c0423b9/data/behavior_datasets/"
"harmbench_behaviors_multimodal_all.csv"
),
source_type: Literal["public_url", "file"] = "public_url",
cache: bool = True,
data_home: Optional[Path] = None,
categories: Optional[List[SemanticCategory]] = None,
) -> SeedPromptDataset:
"""
Fetch HarmBench multimodal examples and create a SeedPromptDataset.

The HarmBench multimodal dataset contains 110 harmful behaviors.
Each example consists of an image ("image_path") and a behavior string referencing the image ("text").
The text and image prompts that belong to the same example are linked using the same ``prompt_group_id``.
You can extract the grouped prompts using the ``group_seed_prompts_by_prompt_group_id`` method.

Note: The first call may be slow as images need to be downloaded from the remote repository.
Subsequent calls will be faster since images are cached locally and won't need to be re-downloaded.

Args:
source (str): The source from which to fetch examples. Defaults to the HarmBench repository.
source_type (Literal["public_url", "file"]): The type of source. Defaults to 'public_url'.
cache (bool): Whether to cache the fetched examples. Defaults to True.
data_home (Optional[Path]): Directory to store cached data. Defaults to None.
categories (Optional[List[SemanticCategory]]): List of semantic categories
to filter examples. If None, all categories are included (default).

Returns:
SeedPromptDataset: A SeedPromptDataset containing the multimodal examples.

Raises:
ValueError: If any of the specified categories are invalid.

Note:
For more information related to the HarmBench project and the original dataset, visit:
https://www.harmbench.org/ \n
Paper: https://arxiv.org/abs/2402.04249 \n
Authors:
Mantas Mazeika & Long Phan & Xuwang Yin & Andy Zou & Zifan Wang & Norman Mu & Elham Sakhaee
& Nathaniel Li & Steven Basart & Bo Li & David Forsyth & Dan Hendrycks
"""
if categories is not None:
valid_categories = {category.value for category in SemanticCategory}
invalid_categories = (
set(cat.value if isinstance(cat, SemanticCategory) else cat for cat in categories) - valid_categories
)
if invalid_categories:
raise ValueError(f"Invalid semantic categories: {', '.join(invalid_categories)}")

required_keys = {"Behavior", "BehaviorID", "FunctionalCategory", "SemanticCategory", "ImageFileName"}
examples = fetch_examples(source, source_type, cache, data_home)
prompts = []
failed_image_count = 0

for example in examples:
missing_keys = required_keys - example.keys()
if missing_keys:
raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}")

if example["FunctionalCategory"] != "multimodal":
continue

semantic_category = example["SemanticCategory"]

if categories is not None:
category_values = {cat.value for cat in categories}
if semantic_category not in category_values:
continue

behavior_text = example["Behavior"]
behavior_id = example["BehaviorID"]
image_filename = example["ImageFileName"]
image_description = example.get("ImageDescription", "")
redacted_description = example.get("RedactedImageDescription", "")

# A unique group ID to link the text and image prompts since they are part of the same example
group_id = uuid.uuid4()

# Note: All images in the HarmBench dataset are stored as .png files, even if the ImageFileName
# field specifies a different extension (.jpg or .jpeg). Hence we always use .png extension here.
image_url = (
"https://raw.githubusercontent.com/centerforaisafety/HarmBench/c0423b9/data/multimodal_behavior_images/"
f"{image_filename.rsplit('.', 1)[0]}.png"
)

try:
# Only include examples where image fetch is successful
local_image_path = await _fetch_and_save_image_async(image_url, behavior_id)

image_prompt = SeedPrompt(
value=local_image_path,
data_type="image_path",
name=f"HarmBench Multimodal Image - {behavior_id}",
dataset_name="HarmBench Multimodal Examples",
harm_categories=[semantic_category],
description=f"An image prompt from the HarmBench multimodal dataset, BehaviorID: {behavior_id}",
source=source,
prompt_group_id=group_id,
sequence=0,
metadata={
"behavior_id": behavior_id,
"image_description": image_description,
"redacted_image_description": redacted_description,
"original_image_url": image_url,
},
)
prompts.append(image_prompt)
except Exception as e:
failed_image_count += 1
logger.warning(f"Failed to fetch image for behavior {behavior_id}: {e}. Skipping this example.")
else:
text_prompt = SeedPrompt(
value=behavior_text,
data_type="text",
name=f"HarmBench Multimodal Text - {behavior_id}",
dataset_name="HarmBench Multimodal Examples",
harm_categories=[semantic_category],
description=(f"A text prompt from the HarmBench multimodal dataset, BehaviorID: {behavior_id}"),
source=source,
prompt_group_id=group_id,
sequence=0,
metadata={
"behavior_id": behavior_id,
},
authors=[
"Mantas Mazeika",
"Long Phan",
"Xuwang Yin",
"Andy Zou",
"Zifan Wang",
"Norman Mu",
"Elham Sakhaee",
"Nathaniel Li",
"Steven Basart",
"Bo Li",
"David Forsyth",
"Dan Hendrycks",
],
groups=[
"University of Illinois Urbana-Champaign",
"Center for AI Safety",
"Carnegie Mellon University",
"UC Berkeley",
"Microsoft",
],
)
prompts.append(text_prompt)

if failed_image_count > 0:
logger.warning(f"Total skipped examples: {failed_image_count} (image fetch failures)")

seed_prompt_dataset = SeedPromptDataset(prompts=prompts)
return seed_prompt_dataset


async def _fetch_and_save_image_async(image_url: str, behavior_id: str) -> str:
filename = f"harmbench_{behavior_id}.png"
serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

# Return existing path if image already exists for this BehaviorID
serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}")
try:
if await serializer._memory.results_storage_io.path_exists(serializer.value):
return serializer.value
except Exception as e:
logger.warning(f"Failed to check whether image for {behavior_id} already exists: {e}")

response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET")
await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

return str(serializer.value)
16 changes: 16 additions & 0 deletions tests/integration/datasets/test_fetch_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
fetch_equitymedqa_dataset_unique_values,
fetch_forbidden_questions_dataset,
fetch_harmbench_dataset,
fetch_harmbench_multimodal_dataset_async,
fetch_jbb_behaviors_by_harm_category,
fetch_jbb_behaviors_by_jbb_category,
fetch_jbb_behaviors_dataset,
Expand Down Expand Up @@ -72,6 +73,21 @@ def test_fetch_datasets(fetch_function, is_seed_prompt_dataset):
assert len(data.prompts) > 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
"fetch_function, number_of_prompts",
[
(fetch_harmbench_multimodal_dataset_async, 110 * 2),
],
)
async def test_fetch_multimodal_datasets(fetch_function, number_of_prompts):
data = await fetch_function()

assert data is not None
assert isinstance(data, SeedPromptDataset)
assert len(data.prompts) == number_of_prompts


@pytest.mark.integration
def test_fetch_jbb_behaviors_by_harm_category():
"""Integration test for filtering by harm category with real data."""
Expand Down
Loading