Skip to content

Commit dd45f23

Browse files
committed
feat: new eds.relation_detector_ffn trainable component
1 parent f36c5b7 commit dd45f23

File tree

14 files changed

+336
-141
lines changed

14 files changed

+336
-141
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
- Added support for multiple loggers (`tensorboard`, `wandb`, `comet_ml`, `aim`, `mlflow`, `clearml`, `dvclive`, `csv`, `json`, `rich`) in `edsnlp.train` via the `logger` parameter. Default is [`json` and `rich`] for backward compatibility.
1515
- Added clickable snippets in the documentation for more registered functions
1616
- New trainable `eds.relation_detector_ffn` component to detect relations between entities. These relations are stored in each entity: `head._.rel[relation_label] = [tail1, tail2, ...]`.
17+
- Load "Status" annotator notes as `status` dict attribute
1718

1819
### Changed
1920

edsnlp/data/converters.py

Lines changed: 92 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -240,87 +240,101 @@ def __init__(
240240

241241
def __call__(self, obj, tokenizer=None):
242242
# tok = get_current_tokenizer() if self.tokenizer is None else self.tokenizer
243-
tok = tokenizer or self.tokenizer or get_current_tokenizer()
244-
doc = tok(obj["text"] or "")
245-
doc._.note_id = obj.get("doc_id", obj.get(FILENAME))
246-
247-
entities = {}
248-
spans = []
249-
250-
for dst in (
251-
*(() if self.span_attributes is None else self.span_attributes.values()),
252-
*self.default_attributes,
253-
):
254-
if not Span.has_extension(dst):
255-
Span.set_extension(dst, default=None)
256-
257-
for ent in obj.get("entities") or ():
258-
fragments = (
259-
[
260-
{
261-
"begin": min(f["begin"] for f in ent["fragments"]),
262-
"end": max(f["end"] for f in ent["fragments"]),
263-
}
264-
]
265-
if not self.split_fragments
266-
else ent["fragments"]
267-
)
268-
for fragment in fragments:
269-
span = doc.char_span(
270-
fragment["begin"],
271-
fragment["end"],
272-
label=ent["label"],
273-
alignment_mode="expand",
274-
)
275-
attributes = (
276-
{a["label"]: a["value"] for a in ent["attributes"]}
277-
if isinstance(ent["attributes"], list)
278-
else ent["attributes"]
243+
note_id = obj.get("doc_id", obj.get(FILENAME))
244+
try:
245+
tok = tokenizer or self.tokenizer or get_current_tokenizer()
246+
doc = tok(obj["text"] or "")
247+
doc._.note_id = note_id
248+
249+
entities = {}
250+
spans = []
251+
252+
for dst in (
253+
*(
254+
()
255+
if self.span_attributes is None
256+
else self.span_attributes.values()
257+
),
258+
*self.default_attributes,
259+
):
260+
if not Span.has_extension(dst):
261+
Span.set_extension(dst, default=None)
262+
263+
for ent in obj.get("entities") or ():
264+
fragments = (
265+
[
266+
{
267+
"begin": min(f["begin"] for f in ent["fragments"]),
268+
"end": max(f["end"] for f in ent["fragments"]),
269+
}
270+
]
271+
if not self.split_fragments
272+
else ent["fragments"]
279273
)
280-
if self.notes_as_span_attribute and ent["notes"]:
281-
ent["attributes"][self.notes_as_span_attribute] = "|".join(
282-
note["value"] for note in ent["notes"]
274+
for fragment in fragments:
275+
span = doc.char_span(
276+
fragment["begin"],
277+
fragment["end"],
278+
label=ent["label"],
279+
alignment_mode="expand",
283280
)
284-
for label, value in attributes.items():
285-
new_name = (
286-
self.span_attributes.get(label, None)
287-
if self.span_attributes is not None
288-
else label
281+
attributes = (
282+
{}
283+
if "attributes" not in ent
284+
else {a["label"]: a["value"] for a in ent["attributes"]}
285+
if isinstance(ent["attributes"], list)
286+
else ent["attributes"]
289287
)
290-
if self.span_attributes is None and not Span.has_extension(
291-
new_name
292-
):
293-
Span.set_extension(new_name, default=None)
294-
295-
if new_name:
296-
value = True if value is None else value
297-
if not self.keep_raw_attribute_values:
298-
value = (
299-
True
300-
if value in ("True", "true")
301-
else False
302-
if value in ("False", "false")
303-
else value
304-
)
305-
span._.set(new_name, value)
306-
307-
entities.setdefault(ent["entity_id"], []).append(span)
308-
spans.append(span)
309-
310-
set_spans(doc, spans, span_setter=self.span_setter)
311-
for attr, value in self.default_attributes.items():
312-
for span in spans:
313-
if span._.get(attr) is None:
314-
span._.set(attr, value)
315-
316-
for relation in obj.get("relations", []):
317-
relation_label = relation["relation_label"]
318-
from_entity_id = relation["from_entity_id"]
319-
to_entity_id = relation["to_entity_id"]
320-
321-
for head in entities[from_entity_id]:
322-
for tail in entities[to_entity_id]:
323-
head._.rel.setdefault(relation_label, set()).add(tail)
288+
if self.notes_as_span_attribute and ent["notes"]:
289+
ent["attributes"][self.notes_as_span_attribute] = "|".join(
290+
note["value"] for note in ent["notes"]
291+
)
292+
for label, value in attributes.items():
293+
new_name = (
294+
self.span_attributes.get(label, None)
295+
if self.span_attributes is not None
296+
else label
297+
)
298+
if self.span_attributes is None and not Span.has_extension(
299+
new_name
300+
):
301+
Span.set_extension(new_name, default=None)
302+
303+
if new_name:
304+
value = True if value is None else value
305+
if not self.keep_raw_attribute_values:
306+
value = (
307+
True
308+
if value in ("True", "true")
309+
else False
310+
if value in ("False", "false")
311+
else value
312+
)
313+
span._.set(new_name, value)
314+
315+
entities.setdefault(ent["entity_id"], []).append(span)
316+
spans.append(span)
317+
318+
set_spans(doc, spans, span_setter=self.span_setter)
319+
for attr, value in self.default_attributes.items():
320+
for span in spans:
321+
if span._.get(attr) is None:
322+
span._.set(attr, value)
323+
324+
for relation in obj.get("relations", []):
325+
relation_label = (
326+
relation["relation_label"]
327+
if "relation_label" in relation
328+
else relation["label"]
329+
)
330+
from_entity_id = relation["from_entity_id"]
331+
to_entity_id = relation["to_entity_id"]
332+
333+
for head in entities.get(from_entity_id, ()):
334+
for tail in entities.get(to_entity_id, ()):
335+
head._.rel.setdefault(relation_label, set()).add(tail)
336+
except Exception:
337+
raise ValueError(f"Error when processing {note_id}")
324338

325339
return doc
326340

edsnlp/data/standoff.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
REGEX_ATTRIBUTE = re.compile(r"^([AM]\d+)\t(.+?) ([TE]\d+)(?: (.+))?$")
3333
REGEX_EVENT = re.compile(r"^(E\d+)\t(.+)$")
3434
REGEX_EVENT_PART = re.compile(r"(\S+):([TE]\d+)")
35+
REGEX_STATUS = re.compile(r"^(#\d+)\tStatus ([^\t]+)\t(.*)$")
3536

3637

3738
class BratParsingError(ValueError):
@@ -71,6 +72,7 @@ def parse_standoff_file(
7172
entities = {}
7273
relations = []
7374
events = {}
75+
doc = {}
7476

7577
with fs.open(txt_path, "r", encoding="utf-8") as f:
7678
text = f.read()
@@ -178,6 +180,11 @@ def parse_standoff_file(
178180
"arguments": arguments,
179181
}
180182
elif line.startswith("#"):
183+
match = REGEX_STATUS.match(line)
184+
if match:
185+
comment = match.group(3)
186+
doc["status"] = comment
187+
continue
181188
match = REGEX_NOTE.match(line)
182189
if match is None:
183190
raise BratParsingError(ann_file, line)
@@ -201,6 +208,7 @@ def parse_standoff_file(
201208
"entities": list(entities.values()),
202209
"relations": relations,
203210
"events": list(events.values()),
211+
**doc,
204212
}
205213

206214

@@ -260,19 +268,19 @@ def dump_standoff_file(
260268
)
261269
attribute_idx += 1
262270

263-
# fmt: off
264-
if "relations" in doc:
265-
for i, relation in enumerate(doc["relations"]):
266-
entity_from = entities_ids[relation["from_entity_id"]]
267-
entity_to = entities_ids[relation["to_entity_id"]]
268-
print(
269-
"R{}\t{} Arg1:{} Arg2:{}\t".format(
270-
i + 1, str(relation["label"]), entity_from,
271-
entity_to
272-
),
273-
file=f,
274-
)
275-
# fmt: on
271+
# fmt: off
272+
if "relations" in doc:
273+
for i, relation in enumerate(doc["relations"]):
274+
entity_from = entities_ids[relation["from_entity_id"]]
275+
entity_to = entities_ids[relation["to_entity_id"]]
276+
print(
277+
"R{}\t{} Arg1:{} Arg2:{}\t".format(
278+
i + 1, str(relation["label"]), entity_from,
279+
entity_to
280+
),
281+
file=f,
282+
)
283+
# fmt: on
276284

277285

278286
class StandoffReader(FileBasedReader):

edsnlp/metrics/relations.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def relations_scorer(
5151
head_getter = candidate["head"]
5252
tail_getter = candidate["tail"]
5353
labels = candidate["labels"]
54+
symmetric = candidate.get("symmetric") or False
55+
label_filter = candidate.get("label_filter")
5456
for eg_idx, eg in enumerate(examples):
5557
pred_heads = [
5658
((h.start, h.end, h.label_), h)
@@ -61,9 +63,21 @@ def relations_scorer(
6163
for t in get_spans(eg.predicted, tail_getter)
6264
]
6365
for (h_key, head), (t_key, tail) in product(pred_heads, pred_tails):
66+
if (
67+
label_filter is not None
68+
and head.label_ not in label_filter
69+
or tail.label_ not in label_filter
70+
):
71+
continue
6472
total_pred_count += 1
6573
for label in labels:
66-
if tail in head._.rel.get(label, ()):
74+
if (
75+
tail in head._.rel.get(label, ())
76+
or symmetric
77+
and head in tail._.rel.get(label, ())
78+
):
79+
if symmetric and h_key > t_key:
80+
h_key, t_key = t_key, h_key
6781
annotations[label][0].add((eg_idx, h_key, t_key, label))
6882
annotations[micro_key][0].add((eg_idx, h_key, t_key, label))
6983

@@ -78,7 +92,13 @@ def relations_scorer(
7892
for (h_key, head), (t_key, tail) in product(gold_heads, gold_tails):
7993
total_gold_count += 1
8094
for label in labels:
81-
if tail in head._.rel.get(label, ()):
95+
if (
96+
tail in head._.rel.get(label, ())
97+
or symmetric
98+
and head in tail._.rel.get(label, ())
99+
):
100+
if symmetric and h_key > t_key:
101+
h_key, t_key = t_key, h_key
82102
annotations[label][1].add((eg_idx, h_key, t_key, label))
83103
annotations[micro_key][1].add((eg_idx, h_key, t_key, label))
84104

edsnlp/pipes/base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,6 @@ def qualifiers(self, value): # pragma: no cover
208208

209209

210210
class BaseRelationDetectorComponent(BaseComponent, abc.ABC):
211-
head_getter: SpanGetter
212-
tail_getter: SpanGetter
213-
labels: List[str]
214-
215211
def __init__(
216212
self,
217213
nlp: PipelineProtocol = None,
@@ -226,6 +222,13 @@ def __init__(
226222
"head": validate_span_getter(candidate["head"]),
227223
"tail": validate_span_getter(candidate["tail"]),
228224
"labels": candidate["labels"],
225+
"label_filter": {
226+
head: set(tail_labels)
227+
for head, tail_labels in candidate["label_filter"].items()
228+
}
229+
if candidate.get("label_filter")
230+
else None,
231+
"symmetric": candidate.get("symmetric") or False,
229232
}
230233
for candidate in candidate_getter
231234
]

edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,9 @@ def forward(self, batch: SpanPoolerBatchInput) -> SpanPoolerBatchOutput:
210210
"embeddings": batch["begins"].with_data(span_embeds),
211211
}
212212

213-
embeds = self.embedding(batch["embedding"])["embeddings"]
213+
embeds = self.embedding(batch["embedding"])["embeddings"].refold(
214+
["context", "word"]
215+
)
214216
_, n_words, dim = embeds.shape
215217
device = embeds.device
216218

0 commit comments

Comments (0)