# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.17.2
# ---

# %% [markdown]
# # Querying by Harm Categories
#
# This notebook demonstrates how to retrieve attack results based on harm category. While harm category information is not duplicated into the `AttackResultEntries` table, PyRIT provides functions that perform the necessary SQL queries to filter `AttackResults` by harm category.

# %% [markdown]
# ## Import Seed Prompt Dataset
#
# First, we import an example dataset whose individual prompts have different harm categories.

# %%
import pathlib

from pyrit.common.initialization import initialize_pyrit
from pyrit.common.path import DATASETS_PATH
from pyrit.memory.central_memory import CentralMemory
from pyrit.models import SeedPromptDataset

initialize_pyrit(memory_db_type="InMemory")

memory = CentralMemory.get_memory_instance()

seed_prompts = SeedPromptDataset.from_yaml_file(pathlib.Path(DATASETS_PATH) / "seed_prompts" / "illegal.prompt")

print(f"Dataset name: {seed_prompts.dataset_name}")
print(f"Number of prompts in dataset: {len(seed_prompts.prompts)}")
print()

await memory.add_seed_prompts_to_memory_async(prompts=seed_prompts.prompts, added_by="bolor")  # type: ignore
for i, prompt in enumerate(seed_prompts.prompts):
    print(f"Prompt {i+1}: {prompt.value}, Harm Categories: {prompt.harm_categories}")

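# %% [markdown]
# As a quick sanity check, you can group the loaded prompts by harm category on the
# client side, using only the `harm_categories` field printed above. This is a minimal
# sketch over the in-memory list; it does not issue any additional memory queries.

# %%
from collections import defaultdict

# Map each harm category to the prompt values that carry it.
prompts_by_category: defaultdict = defaultdict(list)
for prompt in seed_prompts.prompts:
    for category in prompt.harm_categories or []:
        prompts_by_category[category].append(prompt.value)

for category, values in prompts_by_category.items():
    print(f"{category}: {len(values)} prompt(s)")
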
# %% [markdown]
# ## Send to target
#
# We use `PromptSendingAttack` to create our `AttackResults`.

# %%
from pyrit.executor.attack import ConsoleAttackResultPrinter, PromptSendingAttack
from pyrit.prompt_target import OpenAIChatTarget

# Create a real OpenAI target
target = OpenAIChatTarget()

# Create the attack with the OpenAI target
attack = PromptSendingAttack(objective_target=target)

# Load the prompt groups added to memory in the previous step. They came from the
# illegal.prompt file, which configures the dataset name "2025_06_pyrit_illegal_example".
prompt_groups = memory.get_seed_prompt_groups(dataset_name="2025_06_pyrit_illegal_example")
print(f"Found {len(prompt_groups)} prompt groups for dataset")

for i, group in enumerate(prompt_groups):
    prompt_text = group.prompts[0].value

    result = await attack.execute_async(objective=prompt_text, seed_prompt_group=group)  # type: ignore

    print(f"Attack completed - Conversation ID: {result.conversation_id}")
    await ConsoleAttackResultPrinter().print_conversation_async(result=result)  # type: ignore

# %% [markdown]
# ## Query by harm category
# Now you can query your attack results by `targeted_harm_category`!

# %% [markdown]
# ### Single harm category:
#
# Here, we query by a single harm category (e.g., below we query for the harm category `illegal`).

# %%
from pyrit.analytics.analyze_results import analyze_results

all_attack_results = memory.get_attack_results()

# Demonstrating how to query attack results by harm category
print("=== Querying Attack Results by Harm Category ===")
print()

# First, let's see all attack results to understand what we have
print("Overall attack analytics:")
print(f"Total attack results in memory: {len(all_attack_results)}")

overall_analytics = analyze_results(list(all_attack_results))

print(f"  Success rate: {overall_analytics['Attack success rate']}")
print(f"  Successes: {overall_analytics['Successes']}")
print(f"  Failures: {overall_analytics['Failures']}")
print(f"  Undetermined: {overall_analytics['Undetermined']}")
print()

# Example 1: Query for a single harm category
print("1. Query for single harm category 'illegal':")
illegal_attacks = memory.get_attack_results(targeted_harm_categories=["illegal"])
print(f"\tFound {len(illegal_attacks)} attack results with 'illegal' category")

if illegal_attacks:
    for i, attack_result in enumerate(illegal_attacks):
        print(f"Attack {i+1}: {attack_result.objective}")
        print(f"Conversation ID: {attack_result.conversation_id}")
        print(f"Outcome: {attack_result.outcome}")
        print()

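# %% [markdown]
# Because the filtered query returns the same `AttackResult` objects, you can pipe the
# subset straight into `analyze_results` for per-category analytics. The loop below is a
# minimal sketch using only the functions shown above; the category names are examples
# from this dataset, so substitute your own.

# %%
# Per-category success rates: combine the harm-category filter with analyze_results.
for category in ["illegal", "violence"]:
    filtered = memory.get_attack_results(targeted_harm_categories=[category])
    if not filtered:
        print(f"{category}: no attack results")
        continue
    category_analytics = analyze_results(list(filtered))
    print(f"{category}: {len(filtered)} result(s), success rate {category_analytics['Attack success rate']}")
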
# %% [markdown]
# ### Multiple harm categories:

# %%
# Example 2: Query for multiple harm categories
print("2. Query for multiple harm categories 'illegal' and 'violence':")
multiple_groups = memory.get_attack_results(targeted_harm_categories=["illegal", "violence"])

for i, attack_result in enumerate(multiple_groups):
    print(f"Attack {i+1}: {attack_result.objective}...")
    print(f"Conversation ID: {attack_result.conversation_id}")
print()
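
# %% [markdown]
# Whether a multi-category query matches results tagged with *all* of the given categories
# or with *any* of them depends on the query implementation, so it is worth verifying
# empirically. The sketch below compares the combined query against the per-category
# queries, using only calls already shown in this notebook.

# %%
# Compare the combined query against the individual queries to infer the semantics.
illegal_only = memory.get_attack_results(targeted_harm_categories=["illegal"])
violence_only = memory.get_attack_results(targeted_harm_categories=["violence"])
both = memory.get_attack_results(targeted_harm_categories=["illegal", "violence"])

print(f"illegal: {len(illegal_only)}, violence: {len(violence_only)}, combined: {len(both)}")
# If len(both) <= min(len(illegal_only), len(violence_only)), the filter behaves like an
# AND over categories; if it approaches the size of the union, it behaves like an OR.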