Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions detection_rules/atlas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
# or more contributor license agreements. Licensed under the Elastic License
# 2.0; you may not use this file except in compliance with the Elastic License
# 2.0.

"""Mitre ATLAS info."""

from collections import OrderedDict
from pathlib import Path
from typing import Any

import requests
import yaml
from semver import Version

from .utils import cached, clear_caches, get_etc_path

ATLAS_FILE = get_etc_path(["ATLAS.yaml"])

# Maps tactic name to tactic ID (e.g., "Collection" -> "AML.TA0009")
tactics_map: dict[str, str] = {}
technique_lookup: dict[str, dict[str, Any]] = {}
matrix: dict[str, list[str]] = {} # Maps tactic name to list of technique IDs


def get_atlas_file_path() -> Path:
"""Get the path to the ATLAS YAML file."""
if not ATLAS_FILE.exists():
# Try to download it if it doesn't exist
_ = download_atlas_data()
return ATLAS_FILE


def download_atlas_data(save: bool = True) -> dict[str, Any] | None:
"""Download ATLAS data from MITRE."""
url = "https://raw.githubusercontent.com/mitre-atlas/atlas-data/main/dist/ATLAS.yaml"
r = requests.get(url, timeout=30)
r.raise_for_status()
atlas_data = yaml.safe_load(r.text)

if save:
_ = ATLAS_FILE.write_text(r.text)
print(f"Downloaded ATLAS data to {ATLAS_FILE}")

return atlas_data


@cached
def load_atlas_yaml() -> dict[str, Any]:
"""Load ATLAS data from YAML file."""
atlas_file = get_atlas_file_path()
return yaml.safe_load(atlas_file.read_text())


atlas = load_atlas_yaml()

# Extract version
CURRENT_ATLAS_VERSION = atlas.get("version", "unknown")

# Process the ATLAS matrix
if "matrices" in atlas and len(atlas["matrices"]) > 0:
matrix_data = atlas["matrices"][0] # Use the first matrix (usually "ATLAS Matrix")

# Build tactics map
if "tactics" in matrix_data:
for tactic in matrix_data["tactics"]:
tactic_id = tactic["id"]
tactic_name = tactic["name"]
tactics_map[tactic_name] = tactic_id

# Build technique lookup and matrix
if "techniques" in matrix_data:
for technique in matrix_data["techniques"]:
technique_id = technique["id"]
technique_name = technique["name"]
technique_tactics = technique.get("tactics", [])

# Store technique info
technique_lookup[technique_id] = {
"name": technique_name,
"id": technique_id,
"tactics": technique_tactics,
}

# Build matrix: map tactic IDs to technique IDs
for tech_tactic_id in technique_tactics:
# Find tactic name from ID
tech_tactic_name = next((name for name, tid in tactics_map.items() if tid == tech_tactic_id), None)
if tech_tactic_name:
if tech_tactic_name not in matrix:
matrix[tech_tactic_name] = []
if technique_id not in matrix[tech_tactic_name]:
matrix[tech_tactic_name].append(technique_id)

# Sort matrix values
for val in matrix.values():
val.sort(key=lambda tid: technique_lookup.get(tid, {}).get("name", "").lower())

technique_lookup = OrderedDict(sorted(technique_lookup.items()))
techniques = sorted({v["name"] for _, v in technique_lookup.items()})
technique_id_list = [t for t in technique_lookup if "." not in t]
sub_technique_id_list = [t for t in technique_lookup if "." in t]
tactics = list(tactics_map)


def refresh_atlas_data(save: bool = True) -> dict[str, Any] | None:
"""Refresh ATLAS data from MITRE."""
atlas_file = get_atlas_file_path()
current_version_str = CURRENT_ATLAS_VERSION

try:
current_version = Version.parse(current_version_str, optional_minor_and_patch=True)
except (ValueError, TypeError):
# If version parsing fails, download anyway
current_version = Version.parse("0.0.0", optional_minor_and_patch=True)

# Get latest version from GitHub
r = requests.get("https://api.github.com/repos/mitre-atlas/atlas-data/tags", timeout=30)
r.raise_for_status()
releases = r.json()
if not releases:
print("No releases found")
return None

# Find latest version (tags might be like "v5.1.0" or "5.1.0")
latest_release = None
latest_version = current_version
for release in releases:
tag_name = release["name"].lstrip("v")
try:
ver = Version.parse(tag_name, optional_minor_and_patch=True)
if ver > latest_version:
latest_version = ver
latest_release = release
except (ValueError, TypeError):
continue

if latest_release is None:
print(f"No versions newer than the current detected: {current_version_str}")
return None

download = f"https://raw.githubusercontent.com/mitre-atlas/atlas-data/{latest_release['name']}/dist/ATLAS.yaml"
r = requests.get(download, timeout=30)
r.raise_for_status()
atlas_data = yaml.safe_load(r.text)

if save:
_ = atlas_file.write_text(r.text)
print(f"Replaced file: {atlas_file} with version {latest_version}")

# Clear cache to reload
clear_caches()

return atlas_data


def build_threat_map_entry(tactic_name: str, *technique_ids: str) -> dict[str, Any]:
"""Build rule threat map from ATLAS technique IDs."""
url_base = "https://atlas.mitre.org/{type}/{id}/"
tactic_id = tactics_map.get(tactic_name)
if not tactic_id:
raise ValueError(f"Unknown ATLAS tactic: {tactic_name}")

tech_entries: dict[str, Any] = {}

def make_entry(_id: str) -> dict[str, Any]:
tech_info = technique_lookup.get(_id)
if not tech_info:
raise ValueError(f"Unknown ATLAS technique ID: {_id}")
return {
"id": _id,
"name": tech_info["name"],
"reference": url_base.format(type="techniques", id=_id.replace(".", "/")),
}

for tid in technique_ids:
if tid not in technique_lookup:
raise ValueError(f"Unknown ATLAS technique ID: {tid}")

tech_info = technique_lookup[tid]
tech_tactic_ids = tech_info.get("tactics", [])
if tactic_id not in tech_tactic_ids:
raise ValueError(f"ATLAS technique ID: {tid} does not fall under tactic: {tactic_name}")

# Handle sub-techniques (e.g., AML.T0000.000)
if "." in tid and tid.count(".") > 1:
# This is a sub-technique
parts = tid.rsplit(".", 1)
parent_technique = parts[0]
tech_entries.setdefault(parent_technique, make_entry(parent_technique))
tech_entries[parent_technique].setdefault("subtechnique", []).append(make_entry(tid))
else:
tech_entries.setdefault(tid, make_entry(tid))

entry: dict[str, Any] = {
"framework": "MITRE ATLAS",
"tactic": {
"id": tactic_id,
"name": tactic_name,
"reference": url_base.format(type="tactics", id=tactic_id),
},
}

if tech_entries:
entry["technique"] = sorted(tech_entries.values(), key=lambda x: x["id"])

return entry
Loading
Loading