Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion lib/sycamore/sycamore/tests/unit/transforms/test_merge_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

import sycamore
from sycamore.data import Document, Table
from sycamore.transforms.merge_elements import GreedyTextElementMerger, Merge, GreedySectionMerger
from sycamore.transforms.merge_elements import (
GreedyTextElementMerger,
Merge,
GreedySectionMerger,
HeaderAugmenterMerger,
)

from sycamore.functions.tokenizer import HuggingFaceTokenizer
from sycamore.plan_nodes import Node

Expand Down Expand Up @@ -386,3 +392,129 @@ def test_merge_empty_text_works(self):
merger = GreedySectionMerger(tokenizer, 1200, merge_across_pages=False)
new_doc = merger.merge_elements(self.doc2)
assert new_doc.elements[0].text_representation is not None


class TestHeaderAugmenterMerger:
    # Fixture document: elements chosen to exercise the merger's branches --
    # consecutive Section-headers, Text under a header, a Table (non-mergeable),
    # a Title followed by a Section-header, Text elements spanning a page
    # boundary, and a final empty element with no type or properties.
    doc = Document(
        {
            "doc_id": "doc_id",
            "type": "pdf",
            "text_representation": "text",
            "binary_representation": None,
            "parent_id": None,
            "properties": {"path": "/docs/foo.txt", "title": "bar"},
            "elements": [
                {
                    "type": "Section-header",
                    "text_representation": "section1",
                    "properties": {"filetype": "text/plain", "page_number": 1},
                },
                {
                    "type": "Section-header",
                    "text_representation": "section1.1",
                    "properties": {"filetype": "text/plain", "page_number": 1},
                },
                {
                    "type": "Text",
                    "text_representation": "text1 on page 1",
                    "properties": {"filetype": "text/plain", "page_number": 1},
                },
                {
                    "type": "Table",
                    "text_representation": "table1 on page 2",
                    "properties": {"filetype": "text/plain", "page_number": 2},
                },
                {
                    "type": "Title",
                    "text_representation": "title1 on page 2",
                    "properties": {"filetype": "text/plain", "page_number": 2},
                },
                {
                    "type": "Section-header",
                    "text_representation": "section2 on page 2",
                    "properties": {"filetype": "text/plain", "page_number": 2},
                },
                {
                    "type": "Text",
                    "text_representation": "text2 on page 2",
                    "properties": {"filetype": "text/plain", "page_number": 2},
                },
                {
                    "type": "Text",
                    "text_representation": "text3 on page 2",
                    "properties": {"filetype": "text/plain", "page_number": 2},
                },
                {
                    "type": "Text",
                    "text_representation": "text4 on page 3",
                    "properties": {"filetype": "text/plain", "page_number": 3},
                },
                {},
            ],
        }
    )

    def test_merge_elements(self):
        """Headers are folded into following elements; standalone headers are dropped."""
        tokenizer = HuggingFaceTokenizer("sentence-transformers/all-MiniLM-L6-v2")
        merger = HeaderAugmenterMerger(tokenizer, 1200, merge_across_pages=True)

        new_doc = merger.merge_elements(self.doc)
        assert len(new_doc.elements) == 4
        # Element 0: section1 + section1.1 become the header prepended to text1.
        e = new_doc.elements[0]
        assert e.type == "Text"
        assert e.text_representation == ("section1\nsection1.1\ntext1 on page 1")
        assert e.properties == {
            "filetype": "text/plain",
            "page_number": 1,
        }
        assert e["_header"] == "section1\nsection1.1"

        # Element 1: the table is not merged, but the running header is still
        # prepended to its text representation.
        e = new_doc.elements[1]
        assert e.type == "table"
        assert e.text_representation == ("section1\nsection1.1\ntable1 on page 2")
        assert e.properties == {
            "filetype": "text/plain",
            "page_number": 2,
            "title": None,
            "columns": None,
            "rows": None,
        }
        assert e["_header"] == "section1\nsection1.1"

        # Element 2: title1 + section2 form the new header; text2-text4 merge
        # across the page 2/3 boundary (merge_across_pages=True), with both
        # pages recorded in the "page_numbers" property.
        e = new_doc.elements[2]
        assert e.type == "Text"
        assert e.text_representation == (
            "title1 on page 2\nsection2 on page 2\ntext2 on page 2\ntext3 on page 2\ntext4 on page 3"
        )
        assert e.properties == {
            "filetype": "text/plain",
            "page_number": 2,
            "page_numbers": [2, 3],
        }
        assert e["_header"] == "title1 on page 2\nsection2 on page 2"

    def test_merge_elements_via_execute(self, mocker):
        """Smoke test: the merger runs end-to-end through a Merge node over a ray dataset."""
        node = mocker.Mock(spec=Node)
        input_dataset = ray.data.from_items([{"doc": self.doc.serialize()}])
        execute = mocker.patch.object(node, "execute")
        execute.return_value = input_dataset
        tokenizer = HuggingFaceTokenizer("sentence-transformers/all-MiniLM-L6-v2")
        merger = HeaderAugmenterMerger(tokenizer, 120, merge_across_pages=True)
        merge = Merge(node, merger)
        output_dataset = merge.execute()
        output_dataset.show()

    def test_docset_augmented(self):
        """The merger works via DocSet.merge(), and is rejected by DocSet.map()."""
        ray.shutdown()

        context = sycamore.init()
        tokenizer = HuggingFaceTokenizer("sentence-transformers/all-MiniLM-L6-v2")
        context.read.document([self.doc]).merge(HeaderAugmenterMerger(tokenizer, 120, merge_across_pages=True)).show()

        # Verify that HeaderAugmenterMerger can't be an argument for map.
        # We may want to change this in the future.
        with pytest.raises(ValueError):
            sycamore.init().read.document([self.doc]).map(
                HeaderAugmenterMerger(tokenizer, 120, merge_across_pages=True)
            )
194 changes: 194 additions & 0 deletions lib/sycamore/sycamore/transforms/merge_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,200 @@ def postprocess_element(self, elem: Element) -> Element:
return elem


class HeaderAugmenterMerger(ElementMerger):
    """
    The ``HeaderAugmenterMerger`` groups together different elements in a Document and enhances the text
    representation of the elements by adding the preceding section-header/title.

    - It merges certain elements ("Text", "List-item", "Caption", "Footnote", "Formula", "Page-footer", "Page-header").
    - It merges consecutive ("Section-header", "Title") elements.
    - It adds the preceding section-header/title to the text representation of the elements (including tables/images).
    """

    def __init__(self, tokenizer: Tokenizer, max_tokens: int, merge_across_pages: bool = True):
        """
        Args:
            tokenizer: Tokenizer used to count tokens so merged elements stay within ``max_tokens``.
            max_tokens: Soft limit on the number of tokens in a merged element.
            merge_across_pages: If True (default), elements on different pages may be merged.
        """
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens
        self.merge_across_pages = merge_across_pages

    def preprocess_element(self, element: Element) -> Element:
        """Record the element's token count in ``element.data["token_count"]``.

        Images are measured by their summary text when present (that is the text
        representing them downstream); all other elements are measured by their
        text representation.
        """
        if element.type == "Image" and "summary" in element.properties and "summary" in element.properties["summary"]:
            element.data["token_count"] = len(self.tokenizer.tokenize(element.properties["summary"]["summary"] or ""))
        else:
            element.data["token_count"] = len(self.tokenizer.tokenize(element.text_representation or ""))
        return element

    def postprocess_element(self, element: Element) -> Element:
        """Remove the bookkeeping token count added by ``preprocess_element``."""
        del element.data["token_count"]
        return element

    def merge_elements(self, document: Document) -> Document:
        """Use self.should_merge and self.merge to greedily merge consecutive elements.
        If the next element should be merged into the last 'accumulation' element, merge it.

        Args:
            document (Document): A document with elements to be merged.

        Returns:
            Document: The same document, with its elements merged
        """
        if len(document.elements) < 2:
            return document

        # Seed every header/title element with its own text as the running "_header".
        for element in document.elements:
            if element.type in ["Section-header", "Title"]:
                element.data["_header"] = element.text_representation

        to_merge = [self.preprocess_element(e) for e in document.elements]
        new_elements = [to_merge[0]]
        for element in to_merge[1:]:
            if self.should_merge(new_elements[-1], element):
                new_elements[-1] = self.merge(new_elements[-1], element)
            else:
                new_elements.append(element)
        # Headers/titles have been folded into the text of the elements that follow
        # them, so standalone header elements are dropped from the output.
        document.elements = [
            self.postprocess_element(e) for e in new_elements if e.type not in ["Section-header", "Title"]
        ]
        return document

    def should_merge(self, element1: Element, element2: Element) -> bool:
        """Decide whether ``element2`` should be merged into the accumulation ``element1``.

        NOTE: this method has a deliberate side effect -- whenever it returns False
        for a non-header ``element2`` (token budget exhausted, or a non-text type
        such as a table/image), the current "_header" is copied onto ``element2``
        and prepended to its text so the header still augments it.
        """
        # Deal with empty/partial elements: without page numbers and types there
        # is nothing to decide.
        if (
            DocumentPropertyTypes.PAGE_NUMBER not in element1.properties
            or DocumentPropertyTypes.PAGE_NUMBER not in element2.properties
            or element1.type is None
            or element2.type is None
        ):
            return False

        # Conditionally prevent merging across pages
        if (
            not self.merge_across_pages
            and element1.properties[DocumentPropertyTypes.PAGE_NUMBER]
            != element2.properties[DocumentPropertyTypes.PAGE_NUMBER]
        ):
            return False

        if element1.data["token_count"] + 1 + element2.data["token_count"] > self.max_tokens and element2.type not in [
            "Section-header",
            "Title",
        ]:
            # Token budget exhausted: start a new accumulation, but carry the
            # header forward onto the next element. Use .get() -- element1 may be
            # a plain-text element with no preceding header (avoids a KeyError).
            header = element1.data.get("_header")
            element2.data["_header"] = header
            if header:
                if element2.text_representation:
                    element2.text_representation = header + "\n" + element2.text_representation
                else:
                    element2.text_representation = header
            return False

        # Merge consecutive section headers/titles and save as a section-header element
        if element1.type in ["Section-header", "Title"] and element2.type in ["Section-header", "Title"]:
            return True

        # MERGE adjacent 'text' elements
        text_like = {"Text", "List-item", "Caption", "Footnote", "Formula", "Page-footer", "Page-header", "Section"}
        if element1.type in text_like and element2.type in text_like:
            return True

        # Not mergeable (e.g. images, tables): still prepend the header to the element.
        if element2.type not in ["Section-header", "Title"]:
            header = element1.data.get("_header")  # .get(): element1 may carry no header
            element2.data["_header"] = header
            if element2.text_representation:
                if header:
                    element2.text_representation = header + "\n" + element2.text_representation
            else:
                element2.text_representation = header
        # Explicit bool return (the original could fall off the end, returning None).
        return False

    def merge(self, elt1: Element, elt2: Element) -> Element:
        """Merge two elements; the new element's fields will be set as:
            - type: "Section-header" if both inputs are headers/titles, else "Text"
            - binary_representation: elt1.binary_representation + elt2.binary_representation
            - text_representation: elt1.text_representation + "\\n" + elt2.text_representation
            - bbox: the minimal bbox that contains both elt1's and elt2's bboxes
            - properties: elt1's properties + any of elt2's properties that are not in elt1
              note: if elt1 and elt2 have different values for the same property, we take elt1's value
              note: if any input field is None we take the other element's field without merge logic

        Args:
            elt1 (Element): the first element (number of tokens in it is stored by `preprocess_element`
                as elt1.data["token_count"])
            elt2 (Element): the second element (number of tokens in it is stored by `preprocess_element`
                as elt2.data["token_count"])

        Returns:
            Element: a new merged element from the inputs (and number of tokens in it)
        """
        tok1 = elt1.data["token_count"]
        tok2 = elt2.data["token_count"]
        new_elt = Element()

        both_headers = elt1.type in ["Section-header", "Title"] and elt2.type in ["Section-header", "Title"]
        new_elt.type = "Section-header" if both_headers else "Text"

        # Merge binary representations by concatenation
        if elt1.binary_representation is None or elt2.binary_representation is None:
            new_elt.binary_representation = elt1.binary_representation or elt2.binary_representation
        else:
            new_elt.binary_representation = elt1.binary_representation + elt2.binary_representation

        # Merge text representations by concatenation with a newline
        new_elt_text_representation = "\n".join(filter(None, [elt1.text_representation, elt2.text_representation]))
        new_elt.text_representation = new_elt_text_representation if new_elt_text_representation else None
        if elt1.text_representation is None or elt2.text_representation is None:
            new_elt.data["token_count"] = tok1 + tok2
        else:
            # +1 accounts for the newline separator inserted between the two texts.
            new_elt.data["token_count"] = tok1 + 1 + tok2

        # Merge bbox by taking the coords that make the largest box
        if elt1.bbox is None and elt2.bbox is None:
            pass
        elif elt1.bbox is None or elt2.bbox is None:
            new_elt.bbox = elt1.bbox or elt2.bbox
        else:
            # TO-DO: Make bbox work across pages
            new_elt.bbox = BoundingBox(
                min(elt1.bbox.x1, elt2.bbox.x1),
                min(elt1.bbox.y1, elt2.bbox.y1),
                max(elt1.bbox.x2, elt2.bbox.x2),
                max(elt1.bbox.y2, elt2.bbox.y2),
            )

        # Merge properties by taking the union of the keys (elt1 wins on conflicts).
        properties = new_elt.properties
        for k, v in elt1.properties.items():
            properties[k] = v
            if k == DocumentPropertyTypes.PAGE_NUMBER:
                # sorted() (instead of list(set(...))) makes the order deterministic.
                properties["page_numbers"] = sorted(set(properties.get("page_numbers", []) + [v]))
        for k, v in elt2.properties.items():
            if properties.get(k) is None:
                properties[k] = v
            # if a page number exists, add it to the set of page numbers for this new element
            if k == DocumentPropertyTypes.PAGE_NUMBER:
                properties["page_numbers"] = sorted(set(properties.get("page_numbers", []) + [v]))

        if both_headers:
            # Headers set "_header" in merge_elements, but use .get() for robustness.
            if elt1.data.get("_header") is None or elt2.data.get("_header") is None:
                new_elt.data["_header"] = elt1.data.get("_header") or elt2.data.get("_header")
            else:
                new_elt.data["_header"] = elt1.data["_header"] + "\n" + elt2.data["_header"]
        else:
            # elt1 may be a plain-text element that never saw a preceding header;
            # .get() avoids the KeyError the original raised in that case.
            new_elt.data["_header"] = elt1.data.get("_header")
        new_elt.properties = properties

        return new_elt


class Merge(SingleThreadUser, NonGPUUser, Map):
"""
Merge Elements into fewer large elements
Expand Down
16 changes: 13 additions & 3 deletions lib/sycamore/sycamore/transforms/split_elements.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Optional


import logging
from sycamore.data import Document, Element
from sycamore.functions.tokenizer import Tokenizer
from sycamore.plan_nodes import Node, SingleThreadUser, NonGPUUser
from sycamore.transforms.map import Map
from sycamore.utils.time_trace import timetrace

logger = logging.getLogger(__name__)


class SplitElements(SingleThreadUser, NonGPUUser, Map):
"""
Expand Down Expand Up @@ -34,12 +35,18 @@ def __init__(self, child: Node, tokenizer: Tokenizer, maximum: int, **kwargs):
def split_doc(parent: Document, tokenizer: Tokenizer, max: int) -> Document:
    """Split each element of *parent* and replace its element list with the pieces.

    Before splitting, any "_header" that would consume more than a third of the
    token budget is dropped: such a header would dominate every split piece and
    can also cause a maximum-recursion-depth error in split_one.
    """
    split_pieces = []
    for element in parent.elements:
        header = element.get("_header")
        if header and len(tokenizer.tokenize(header)) / max > 0.33:
            logger.warning(f"Token limit exceeded, dropping _header: {element['_header']}")
            del element["_header"]
        split_pieces.extend(SplitElements.split_one(element, tokenizer, max))
    parent.elements = split_pieces
    return parent

@staticmethod
def split_one(elem: Element, tokenizer: Tokenizer, max: int) -> list[Element]:

txt = elem.text_representation
if not txt:
return [elem]
Expand Down Expand Up @@ -96,7 +103,10 @@ def split_one(elem: Element, tokenizer: Tokenizer, max: int) -> list[Element]:
ment = elem.copy()
elem.text_representation = one
elem.binary_representation = bytes(one, "utf-8")
ment.text_representation = two
if elem.get("_header"):
ment.text_representation = ment["_header"] + "\n" + two
else:
ment.text_representation = two
ment.binary_representation = bytes(two, "utf-8")
aa = SplitElements.split_one(elem, tokenizer, max)
bb = SplitElements.split_one(ment, tokenizer, max)
Expand Down