Skip to content

Commit a98894e

Browse files
Temporal agents cookbook (#1970)
1 parent 45dd652 commit a98894e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+7907
-0
lines changed

examples/partners/temporal_agents_with_knowledge_graphs/Appendix.ipynb

Lines changed: 671 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
"""Reusable functions for the cookbook."""
2+
3+
import sqlite3
4+
import networkx as nx
5+
from typing import Any
6+
from datasets import load_dataset
7+
8+
from db_interface import get_all_triplets
9+
10+
11+
def load_db_from_hf(db_path: str = "temporal_graph.db", hf_dataset_name: str = "TomoroAI/temporal_cookbook_db") -> sqlite3.Connection:
12+
"""Load the pre-processed database from HuggingFace."""
13+
conn = sqlite3.connect(db_path)
14+
table_names = [
15+
"transcripts",
16+
"chunks",
17+
"events",
18+
"triplets",
19+
"entities",
20+
]
21+
22+
for table in table_names:
23+
print(f"Loading {table}...")
24+
ds = load_dataset(hf_dataset_name, name=table, split="train")
25+
df = ds.to_pandas()
26+
df.to_sql(table, conn, if_exists="replace", index=False)
27+
28+
conn.commit()
29+
print("✅ All tables written to SQLite.")
30+
31+
return conn
32+
33+
def build_graph(
34+
conn: sqlite3.Connection,
35+
*,
36+
nodes_as_names: bool = False
37+
) -> nx.MultiDiGraph:
38+
"""Build graph using canonical entity IDs and names."""
39+
graph = nx.MultiDiGraph()
40+
41+
# Always load canonical mappings
42+
entity_to_canonical, canonical_names = _load_entity_maps(conn)
43+
event_temporal_map = _load_event_temporal(conn)
44+
45+
for t in get_all_triplets(conn):
46+
if not t["subject_id"]:
47+
continue
48+
49+
event_attrs = event_temporal_map.get(t["event_id"])
50+
_add_triplet_edge(
51+
graph,
52+
t,
53+
entity_to_canonical,
54+
canonical_names,
55+
event_attrs,
56+
nodes_as_names,
57+
)
58+
59+
return graph
60+
61+
def _load_entity_maps(conn: sqlite3.Connection) -> tuple[dict[bytes, bytes], dict[bytes, str]]:
62+
"""
63+
Return mappings for canonical entities:
64+
• entity_to_canonical: maps entity ID → canonical ID (using resolved_id)
65+
• canonical_names: maps canonical ID → canonical name.
66+
"""
67+
cur = conn.cursor()
68+
69+
# Get all entities with their resolved IDs
70+
cur.execute("""
71+
SELECT id, name, resolved_id
72+
FROM entities
73+
""")
74+
75+
entity_to_canonical: dict[bytes, bytes] = {}
76+
canonical_names: dict[bytes, str] = {}
77+
78+
for row in cur.fetchall():
79+
entity_id = row[0]
80+
name = row[1]
81+
resolved_id = row[2]
82+
83+
if resolved_id:
84+
# If entity has a resolved_id, map to that
85+
entity_to_canonical[entity_id] = resolved_id
86+
# Store name of the canonical entity
87+
canonical_names[resolved_id] = name
88+
else:
89+
# If no resolved_id, entity is its own canonical version
90+
entity_to_canonical[entity_id] = entity_id
91+
canonical_names[entity_id] = name
92+
93+
return entity_to_canonical, canonical_names
94+
95+
def _load_event_temporal(conn: sqlite3.Connection) -> dict[bytes, dict[str, Any]]:
96+
"""
97+
Read the `events` table once and build a mapping
98+
event_id (bytes) → dict of temporal / descriptive attributes.
99+
Only the columns that are useful on the graph edges are pulled;
100+
extend this list freely if you need more.
101+
"""
102+
cur = conn.cursor()
103+
cur.execute("""
104+
SELECT id,
105+
statement,
106+
statement_type,
107+
temporal_type,
108+
created_at,
109+
valid_at,
110+
expired_at,
111+
invalid_at,
112+
invalidated_by
113+
FROM events
114+
""")
115+
event_map: dict[bytes, dict[str, Any]] = {}
116+
for (
117+
eid,
118+
statement,
119+
statement_type,
120+
temporal_type,
121+
created_at,
122+
valid_at,
123+
expired_at,
124+
invalid_at,
125+
invalidated_by,
126+
) in cur.fetchall():
127+
event_map[eid] = {
128+
"statement": statement,
129+
"statement_type": statement_type,
130+
"temporal_type": temporal_type,
131+
"created_at": created_at,
132+
"valid_at": valid_at,
133+
"expired_at": expired_at,
134+
"invalid_at": invalid_at,
135+
"invalidated_by": invalidated_by,
136+
}
137+
return event_map
138+
139+
140+
def _add_triplet_edge(
141+
graph: nx.MultiDiGraph, t: dict,
142+
entity_to_canonical: dict[bytes, bytes],
143+
canonical_names: dict[bytes, str],
144+
event_attrs: dict[str, Any] | None = None,
145+
use_names: bool = False,
146+
) -> None:
147+
"""Add one edge using canonical IDs and names."""
148+
subj_id = t["subject_id"]
149+
obj_id = t["object_id"]
150+
151+
if subj_id is None:
152+
return
153+
154+
# Get canonical IDs
155+
canonical_subj = entity_to_canonical.get(subj_id, subj_id)
156+
canonical_obj = entity_to_canonical.get(obj_id, obj_id) if obj_id else None
157+
158+
# Get canonical names
159+
subj_name = canonical_names.get(canonical_subj, t["subject_name"]) if canonical_subj is not None else t["subject_name"]
160+
obj_name = canonical_names.get(canonical_obj, t["object_name"]) if canonical_obj is not None else t["object_name"]
161+
162+
subj_node = subj_name if use_names else canonical_subj
163+
obj_node = obj_name if use_names else canonical_obj
164+
165+
# Add nodes with canonical names
166+
graph.add_node(
167+
subj_node,
168+
canonical_id=canonical_subj,
169+
name=subj_name,
170+
)
171+
172+
# Core edge attributes (triplet-specific)
173+
edge_attrs: dict[str, Any] = {
174+
"predicate": t["predicate"],
175+
"triplet_id": t["id"],
176+
"event_id": t["event_id"],
177+
"value": t["value"],
178+
"canonical_subject_name": subj_name,
179+
"canonical_object_name": obj_name,
180+
}
181+
182+
# Merge in temporal data, if we have it
183+
if event_attrs:
184+
edge_attrs.update(event_attrs)
185+
186+
if canonical_obj is None:
187+
# Handle self-loops for null objects
188+
graph.add_edge(
189+
subj_node, subj_node,
190+
key=t["predicate"],
191+
**edge_attrs,
192+
literal_object=t["object_name"],
193+
)
194+
else:
195+
graph.add_node(
196+
obj_node,
197+
canonical_id=canonical_obj,
198+
name=obj_name,
199+
)
200+
graph.add_edge(
201+
subj_node, obj_node,
202+
key=t["predicate"],
203+
**edge_attrs,
204+
)

0 commit comments

Comments
 (0)