-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextract_entities.py
More file actions
93 lines (77 loc) · 3.36 KB
/
extract_entities.py
File metadata and controls
93 lines (77 loc) · 3.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import spacy
import re
from transformers import BertTokenizer, BertForTokenClassification
import torch
from rapidfuzz import fuzz
# Load the small English spaCy pipeline once, at import time (quick
# prototyping path used by extract_entities_spacy).
nlp = spacy.load("en_core_web_sm")
# Pre-trained BERT tokenizer and token-classification model for NER, also
# loaded once at import time and shared by every extract_entities_bert call.
# Both are loaded from the same "dslim/bert-base-NER" checkpoint.
tokenizer = BertTokenizer.from_pretrained("dslim/bert-base-NER")
model = BertForTokenClassification.from_pretrained("dslim/bert-base-NER")
def extract_entities_spacy(text: str) -> dict:
    """Extract a jurisdiction and serial numbers from ``text`` using spaCy.

    The jurisdiction comes from spaCy named entities labelled ``GPE``
    (geopolitical entity); when several are present the last one in the
    document wins. Serial numbers are found with case-insensitive regular
    expressions applied to the raw text.

    Args:
        text: Raw document text to analyse.

    Returns:
        A dict with two keys: ``"jurisdiction"`` (the entity text, or
        ``None`` when no GPE entity was found) and ``"serial_numbers"``
        (a list of matched strings in the order they occur in ``text``).
    """
    doc = nlp(text)
    entities = {"jurisdiction": None, "serial_numbers": []}

    # Extract jurisdictions (GPE = geopolitical entity); a later GPE
    # overwrites an earlier one.
    for ent in doc.ents:
        if ent.label_ == "GPE":
            entities["jurisdiction"] = ent.text

    # Serial-number patterns. The optional ":" after AppNo/RegNo lets the
    # documented form "AppNo: 987654" match, which the colon-less pattern
    # could not.
    serial_patterns = [
        r"\bTM[-\s]?\d{5,}\b",  # e.g., TM-12345, TM12345
        r"\bAppNo:?[-\s]?\d{6,}\b",  # e.g., AppNo: 987654, AppNo-987654
        r"\bRegNo:?[-\s]?\d{6,}\b",  # e.g., RegNo 1234567, RegNo: 1234567
        r"\b\d{5,}[-\s]?\d{4,}\b",  # General case numbers
    ]
    # Scan once with a single alternation (as extract_entities_bert does):
    # matches cannot overlap, so the generic pattern no longer re-reports the
    # digits of a serial already captured together with its prefix.
    entities["serial_numbers"] = re.findall(
        "|".join(serial_patterns), text, re.IGNORECASE
    )
    return entities
def extract_entities_bert(text: str) -> dict:
    """Extract a jurisdiction and serial numbers from ``text`` using BERT NER.

    The text is tokenized (truncated to at most 512 tokens), classified by the
    module-level token-classification model, and consecutive ``B-LOC`` /
    ``I-LOC`` tokens are assembled into a location that is reported as the
    jurisdiction. When several locations are found, the last one wins.
    Serial numbers are found with case-insensitive regular expressions on the
    full, untruncated ``text``.

    Args:
        text: Raw document text to analyse.

    Returns:
        A dict with two keys: ``"jurisdiction"`` (the assembled location, its
        words joined by single spaces, or ``None`` when no location was
        found) and ``"serial_numbers"`` (a list of matched strings in the
        order they occur in ``text``).
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**inputs)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    predictions = torch.argmax(outputs.logits, dim=2)[0]
    labels = [model.config.id2label[p.item()] for p in predictions]

    entities = {"jurisdiction": None, "serial_numbers": []}
    words = []  # words of the location currently being assembled
    extend_last = False  # True while the latest word-initial token joined `words`
    for token, label in zip(tokens, labels):
        if token.startswith("##"):
            # WordPiece continuation: glue it onto the word it belongs to
            # without a space, and ignore its own label so a sub-word piece
            # can neither restart nor terminate an entity mid-word.
            if extend_last:
                words[-1] += token[2:]
            continue
        extend_last = False
        if label.startswith("B-LOC"):  # Location (potential jurisdiction)
            words = [token]
            extend_last = True
        elif label.startswith("I-LOC") and words:
            words.append(token)
            extend_last = True
        elif words and not label.startswith("I-"):
            entities["jurisdiction"] = " ".join(words)
            words = []
    # Flush a location that was still open when the tokens ran out.
    if words:
        entities["jurisdiction"] = " ".join(words)

    # Use regex for serial numbers (BERT NER not trained for custom serials).
    # The optional ":" after AppNo/RegNo accepts forms like "AppNo: 987654".
    serial_patterns = [
        r"\bTM[-\s]?\d{5,}\b",
        r"\bAppNo:?[-\s]?\d{6,}\b",
        r"\bRegNo:?[-\s]?\d{6,}\b",
        r"\b\d{5,}[-\s]?\d{4,}\b",
    ]
    entities["serial_numbers"] = re.findall("|".join(serial_patterns), text, re.IGNORECASE)
    return entities
def match_serial_numbers(extracted_numbers, database_numbers, threshold=85):
    """Fuzzy-match extracted serial numbers against database records.

    Every extracted number is compared with every database number using
    ``fuzz.ratio``; a pair is kept when its score is strictly greater than
    ``threshold``.

    Args:
        extracted_numbers: Iterable of serial-number strings found in a document.
        database_numbers: Iterable of known serial-number strings.
        threshold: Similarity score a pair must exceed to count as a match.
            Defaults to 85, the previously hard-coded value.

    Returns:
        A list of ``(extracted, database)`` tuples, ordered by extracted
        number first and database number second. Empty when either input
        is empty.
    """
    # Materialise once so a one-shot iterable (e.g. a generator) is compared
    # against every extracted number, not only the first.
    database_numbers = list(database_numbers)
    return [
        (num, db_num)
        for num in extracted_numbers
        for db_num in database_numbers
        if fuzz.ratio(num, db_num) > threshold
    ]
if __name__ == "__main__":
    # Demo run: push one sample sentence through both extractors, print what
    # each finds, then fuzzy-match the spaCy serial numbers against a tiny
    # in-memory "database".
    demo_text = "United States Trademark Certificate TM-12345, RegNo 987654"
    known_serials = ["TM-12345", "RegNo 987654", "AppNo 555555"]

    results = {}
    for heading, extractor in (
        ("spaCy Entities:", extract_entities_spacy),
        ("BERT Entities:", extract_entities_bert),
    ):
        results[heading] = extractor(demo_text)
        print(heading, results[heading])

    # Only the spaCy extraction feeds the fuzzy matcher.
    spacy_serials = results["spaCy Entities:"]["serial_numbers"]
    print("Matched Serial Numbers:", match_serial_numbers(spacy_serials, known_serials))