import copy
import html
import logging
import re
import unicodedata
import xml.etree.ElementTree as ET
from re import Pattern
from typing import Any, Optional, Sequence
logger = logging.getLogger(__name__)
NAMESPACE = "http://www.w3.org/2000/svg"
XHTML_NAMESPACE = "http://www.w3.org/1999/xhtml"
ILLEGAL_XML_RE: Pattern[str] = re.compile(
"[\x00-\x08\x0b-\x1f\x7f-\x9f\ud800-\udfff\ufdd0-\ufddf\ufffe-\uffff]"
)
DEFAULT_NUMBER_DIGITS = 2
[docs]
def safe_utf8(text: str) -> str:
"""Remove illegal and problematic XML characters from text.
This function filters out characters that are illegal or problematic in XML
documents by replacing them with spaces. The filtering is intentionally
conservative, removing more characters than strictly required by the XML 1.0
specification.
Characters filtered (replaced with space):
- C0 controls (0x00-0x08, 0x0B-0x1F): NULL, control characters
Exception: TAB (0x09), LF (0x0A) are preserved
- C1 controls (0x7F-0x9F): DEL and C1 control block including NEL (0x85)
- UTF-16 surrogates (0xD800-0xDFFF): Illegal in UTF-8/XML
- Non-characters (0xFDD0-0xFDDF, 0xFFFE-0xFFFF): Illegal in XML
Characters preserved:
- TAB (0x09): Horizontal spacing
- LF (0x0A): Line breaks
- Normal printable characters (0x20 and above, excluding filtered ranges)
Note on CR (0x0D):
CR (Carriage Return) is filtered despite being technically legal in XML 1.0.
This is an intentional conservative design choice because:
1. XML processors normalize CR to LF per XML 1.0 spec (section 2.11)
2. Filtering CR prevents line-ending confusion attacks
3. TAB and LF are sufficient for all whitespace needs in SVG
4. Historical issues with CR in PSD text conversion (see commit d81e4b5)
Args:
text: Input string that may contain illegal XML characters.
Returns:
Sanitized string with illegal characters replaced by spaces.
Example:
>>> safe_utf8("Hello\\x00World\\x85Test")
'Hello World Test'
>>> safe_utf8("Tab\\there\\nNewline")
'Tab\there\nNewline'
"""
return ILLEGAL_XML_RE.sub(" ", text)
[docs]
def num2str(num: int | float | bool, digit: int = DEFAULT_NUMBER_DIGITS) -> str:
"""Convert a number to a string, using the specified format for floats."""
if isinstance(num, bool):
return "true" if num else "false"
if isinstance(num, int):
return str(num)
if isinstance(num, float):
if num.is_integer():
return str(int(num))
# Format float with specified number of digits, and trim trailing zeros
number = f"{num:.{digit}f}"
return f"{number[0]}{number[1:].rstrip('0').rstrip('.')}"
raise ValueError(f"Unsupported type: {type(num)}")
[docs]
def num2str_with_unit(
num: int | float, unit: str = "px", digit: int = DEFAULT_NUMBER_DIGITS
) -> str:
"""Convert a number to a string with a CSS unit appended.
Args:
num: The numeric value to format
unit: The CSS unit to append (default: "px")
digit: Number of decimal digits for floats (default: 2)
Returns:
Formatted string with unit (e.g., "10px", "1.5em", "100%")
Examples:
>>> num2str_with_unit(10.0, "px")
'10px'
>>> num2str_with_unit(1.5, "em")
'1.5em'
>>> num2str_with_unit(100, "%")
'100%'
"""
return f"{num2str(num, digit)}{unit}"
[docs]
def seq2str(
seq: Sequence[int | float | bool],
sep: str = ",",
digit: int = DEFAULT_NUMBER_DIGITS,
) -> str:
"""Convert a sequence of numbers to a string.
Uses the specified format for floats.
"""
return sep.join(num2str(n, digit) for n in seq)
[docs]
def create_node(
tag: str,
parent: Optional[ET.Element] = None,
class_: str = "",
title: str = "",
text: str = "",
desc: str = "",
xml_space: Optional[str] = None,
**kwargs: Any,
) -> ET.Element:
"""Create an XML node with attributes.
Args:
xml_space: Set xml:space attribute with proper XML namespace.
Use "preserve" to preserve whitespace, "default" for normal behavior.
"""
node = ET.Element(tag)
if class_:
node.set("class", class_)
# Handle xml:space with proper XML namespace
if xml_space is not None:
node.set("{http://www.w3.org/XML/1998/namespace}space", xml_space)
for key, value in kwargs.items():
if value is None:
continue
key = key.rstrip("_") # allow trailing underscore for keywords
key = key.replace("_", "-") # convert underscores to hyphens
set_attribute(node, key, value)
if text:
node.text = safe_utf8(text)
if title:
create_node("title", parent=node, text=safe_utf8(title))
if desc:
create_node("desc", parent=node, text=safe_utf8(desc))
if parent is not None:
parent.append(node)
return node
[docs]
def create_xhtml_node(
tag: str,
parent: Optional[ET.Element] = None,
text: str = "",
xml_space: Optional[str] = None,
**kwargs: Any,
) -> ET.Element:
"""Create an XHTML node with proper namespace.
This helper creates elements in the XHTML namespace for use within
SVG <foreignObject> elements. The namespace is required for proper
rendering in modern browsers (Chrome, Firefox, Safari, Edge).
Note: resvg/resvg-py does not support foreignObject rendering.
Args:
tag: HTML tag name (e.g., 'div', 'p', 'span').
parent: Optional parent element to append this node to.
text: Optional text content.
xml_space: Set xml:space attribute with proper XML namespace.
Use "preserve" to preserve whitespace.
**kwargs: Additional attributes. Underscores in keys are converted
to hyphens (e.g., 'font_size' becomes 'font-size').
Returns:
XHTML element with namespace prefix.
Example:
>>> foreign_obj = create_node("foreignObject", x=0, y=0, width=100, height=50)
>>> div = create_xhtml_node("div", parent=foreign_obj)
>>> p = create_xhtml_node("p", parent=div, text="Hello")
>>> span = create_xhtml_node("span", parent=p, text="World", style="color: red")
"""
# Create element with XHTML namespace
node = ET.Element(f"{{{XHTML_NAMESPACE}}}{tag}")
# Handle xml:space with proper XML namespace
if xml_space is not None:
node.set("{http://www.w3.org/XML/1998/namespace}space", xml_space)
# Set attributes
for key, value in kwargs.items():
if value is None:
continue
key = key.rstrip("_") # allow trailing underscore for keywords
key = key.replace("_", "-") # convert underscores to hyphens
set_attribute(node, key, value)
# Set text content
if text:
node.text = safe_utf8(text)
# Append to parent if provided
if parent is not None:
parent.append(node)
return node
[docs]
def styles_to_string(styles: dict[str, str]) -> str:
"""Convert a dictionary of CSS styles to a CSS string.
Args:
styles: Dictionary of CSS property names to values.
Returns:
CSS string suitable for use in a 'style' attribute.
Example:
>>> styles = {"color": "red", "font-size": "12px"}
>>> styles_to_string(styles)
'color: red; font-size: 12px'
"""
if not styles:
return ""
return "; ".join(f"{key}: {value}" for key, value in styles.items())
def _strip_text_element_whitespace(node: ET.Element) -> None:
"""Strip whitespace-only text and tail from SVG text elements.
SVG preserves whitespace in text elements by default. When pretty-printing
adds indentation, this whitespace becomes significant and can cause
alignment issues (especially with text-anchor="end"). This function ensures
that any whitespace-only text/tail content is removed from text container
elements, leaving only the actual text content in tspan children.
"""
# Check if this is a text or tspan element
tag = node.tag
if isinstance(tag, str):
# Handle namespaced tags
local_name = tag.split("}")[-1] if "}" in tag else tag
is_text_element = local_name in ("text", "tspan", "textPath")
else:
is_text_element = False
if is_text_element:
# If element has children, clear any whitespace-only text
if len(node) > 0:
if node.text and node.text.strip() == "":
node.text = None
# Clear whitespace-only tail for all child elements
for child in node:
if child.tail and child.tail.strip() == "":
child.tail = None
# Recursively clean all child elements
for child in node:
_strip_text_element_whitespace(child)
[docs]
def fromstring(data: str) -> ET.Element:
"""Parse an XML string to an Element."""
return ET.fromstring(data)
def _fix_xhtml_namespace_prefixes(svg_string: str) -> str:
"""Remove namespace prefixes from XHTML elements in foreignObject.
ElementTree serializes XHTML namespaced elements with prefixes like
'html:div', 'html:p', 'html:span'. Browsers require unprefixed elements
with xmlns declaration for proper CSS styling.
Transforms:
<foreignObject xmlns:html="..."><html:div>...</html:div></foreignObject>
Into:
<foreignObject><div xmlns="http://www.w3.org/1999/xhtml">...</div></foreignObject>
Args:
svg_string: Serialized SVG string from ElementTree.
Returns:
SVG string with XHTML namespace prefixes removed.
Note:
- Handles multiple foreignObject elements
- Works with any namespace prefix (html, ns0, ns1, etc.)
- Only modifies XHTML elements (div, p, span)
- Adds xmlns declaration to first XHTML element in each foreignObject
"""
# XHTML namespace URL - escaped for use in regex patterns
xhtml_ns_escaped = re.escape(XHTML_NAMESPACE)
# Pattern to match XHTML namespace declarations with any prefix
ns_decl_pattern = f' xmlns:([a-zA-Z0-9]+)="{xhtml_ns_escaped}"'
# Find all prefixes used for XHTML namespace
matches = re.findall(ns_decl_pattern, svg_string)
if not matches:
# No XHTML namespace found, return unchanged
return svg_string
result = svg_string
# Process each prefix (usually just one, but handle multiple)
for prefix in set(matches): # Use set to avoid duplicates
# Remove the namespace declaration
result = re.sub(
f' xmlns:{re.escape(prefix)}="{xhtml_ns_escaped}"',
"",
result,
)
# Replace prefixed tags with unprefixed versions
# Opening tags: <prefix:tag> or <prefix:tag attr="...">
result = re.sub(
f"<{re.escape(prefix)}:(div|p|span)",
r"<\1",
result,
)
# Closing tags: </prefix:tag>
result = re.sub(
f"</{re.escape(prefix)}:(div|p|span)>",
r"</\1>",
result,
)
# Add xmlns attribute to first XHTML element after each foreignObject
# Use negative lookahead to ensure we stay within foreignObject boundaries
def add_xmlns_to_first_element(match: re.Match[str]) -> str:
"""Add xmlns declaration to first XHTML element."""
# Everything up to and including <foreignObject...> and content before tag
foreign_obj_part = match.group(1)
tag_name = match.group(2) # div, p, or span
# All attributes and spaces before the closing >
existing_attrs = match.group(3)
end_bracket = match.group(4) # The closing >
# Check if xmlns already exists anywhere in the tag (idempotency)
if f'xmlns="{XHTML_NAMESPACE}"' in existing_attrs:
return match.group(0)
xmlns_attr = f'xmlns="{XHTML_NAMESPACE}"'
# Add xmlns as the first attribute, preserving existing attributes
if existing_attrs.strip() == "":
# No attributes: <tag>
new_attrs = f" {xmlns_attr}"
else:
# Has attributes: <tag attr="...">
new_attrs = f" {xmlns_attr}{existing_attrs}"
return f"{foreign_obj_part}<{tag_name}{new_attrs}{end_bracket}"
# Apply xmlns addition to first XHTML element in each foreignObject
# Use negative lookahead (?!</foreignObject) to stay within boundaries
result = re.sub(
r"(<foreignObject[^>]*>(?:(?!</foreignObject).)*?)<(div|p|span)([^>]*)(>)",
add_xmlns_to_first_element,
result,
flags=re.DOTALL,
)
return result
[docs]
def tostring(node: ET.Element, indent: str = " ") -> str:
"""Convert an XML node to a string."""
ET.indent(node, space=indent)
_strip_text_element_whitespace(node)
svg_string = ET.tostring(node, encoding="unicode", xml_declaration=False)
# Fix XHTML namespace prefixes in foreignObject elements
# ElementTree adds prefixes (html:div), but browsers need unprefixed elements
# with xmlns declaration for proper CSS styling
svg_string = _fix_xhtml_namespace_prefixes(svg_string)
return svg_string
[docs]
def parse(file: Any) -> ET.Element:
"""Parse an XML file to an Element."""
tree = ET.parse(file)
if tree is None or tree.getroot() is None:
raise ValueError("Failed to parse XML file.")
return tree.getroot()
[docs]
def write(node: ET.Element, file: Any, indent: str = " ") -> None:
"""Write an XML node to a file."""
tree = ET.ElementTree(node)
ET.indent(tree, space=indent)
_strip_text_element_whitespace(node)
tree.write(file, encoding="unicode", xml_declaration=False)
[docs]
def add_style(node: ET.Element, key: str, value: Any) -> None:
"""Add a CSS property to an XML node."""
if "style" in node.attrib:
node.set("style", f"{node.get('style')}; {key}: {value}")
else:
node.set("style", f"{key}: {value}")
[docs]
def add_class(node: ET.Element, class_name: str) -> None:
"""Add a class to an XML node."""
if "class" in node.attrib:
classes = node.get("class", "").split()
if class_name not in classes:
classes.append(class_name)
node.set("class", " ".join(classes))
else:
node.set("class", class_name)
[docs]
def set_attribute(node: ET.Element, key: str, value: Any) -> None:
"""Add an attribute to an XML node."""
if isinstance(value, (int, float, bool)):
node.set(key, num2str(value))
elif isinstance(value, list) and all(
isinstance(v, (int, float, bool)) for v in value
):
node.set(key, seq2str(value))
else:
node.set(key, str(value))
[docs]
def append_attribute(
node: ET.Element, key: str, value: Any, separator: str = " "
) -> None:
"""Append a value to an existing attribute of an XML node."""
if key in node.attrib:
existing_value = node.get(key, "")
new_value = f"{existing_value}{separator}{value}"
node.set(key, new_value)
else:
set_attribute(node, key, value)
[docs]
def get_uri(node: ET.Element) -> str:
"""Get an uri string for the given node."""
id_ = node.get("id")
if not id_:
raise ValueError(f"Node must have an 'id' attribute to get uri: {node}")
return f"#{id_}"
[docs]
def get_funciri(node: ET.Element) -> str:
"""Get a funciri string for the given node."""
return f"url({get_uri(node)})"
[docs]
def wrap_element(
node: ET.Element, parent: ET.Element, wrapper: ET.Element
) -> ET.Element:
"""Wrap the given existing node in the wrapper element.
Usage::
wrapper = svg_utils.create_node("g")
wrapped_node = svg_utils.wrap_element(node, parent, wrapper)
Args:
node: The XML node to be wrapped.
parent: The parent XML node containing the node to be wrapped.
wrapper: The wrapper XML node.
"""
# NOTE: There is no direct way to find a parent from the node.
if node not in parent:
raise ValueError(f"Node is not a child of the given parent: {node} in {parent}")
parent.remove(node)
wrapper.append(node)
return wrapper
def _unwrap_wrapper_element(
parent: ET.Element,
wrapper: ET.Element,
insert_position: int | None = None,
) -> None:
"""Unwrap a wrapper element by moving its children to parent level.
This helper function is shared by merge_singleton_children and
merge_attribute_less_children to eliminate code duplication.
Args:
parent: Parent element containing the wrapper.
wrapper: Wrapper element to unwrap (will be removed from parent).
insert_position: If None, append grandchildren to end of parent;
else insert at this position.
The function:
- Removes the wrapper from parent
- Moves wrapper's children (grandchildren) to parent level
- Transfers wrapper's tail to the last grandchild (or parent's text)
- Preserves document order
"""
grandchildren = list(wrapper)
# Remove wrapper from parent
parent.remove(wrapper)
# Insert grandchildren at appropriate position
if insert_position is None:
# Append to end (used by merge_singleton_children)
for grandchild in grandchildren:
parent.append(grandchild)
else:
# Insert at specific position (used by merge_attribute_less_children)
for j, grandchild in enumerate(grandchildren):
parent.insert(insert_position + j, grandchild)
# Transfer wrapper's tail to last grandchild
if wrapper.tail:
if len(grandchildren) > 0:
# Add to last grandchild's tail
if insert_position is None:
last_grandchild = parent[-1]
else:
last_grandchild = parent[insert_position + len(grandchildren) - 1]
last_grandchild.tail = (last_grandchild.tail or "") + wrapper.tail
else:
# No grandchildren - handle tail appropriately
if insert_position is not None and insert_position > 0:
# Append to previous sibling's tail
prev = parent[insert_position - 1]
prev.tail = (prev.tail or "") + wrapper.tail
else:
# No previous sibling, append to parent's text
parent.text = (parent.text or "") + wrapper.tail
[docs]
def merge_attribute_less_children(element: ET.Element) -> None:
"""Recursively merge children without attributes into their parent nodes.
This utility removes redundant wrapper elements that have no attributes,
moving their text content directly into the parent element. This helps
produce cleaner, more compact SVG output.
The function preserves document order by properly handling both element.text
and child.tail to ensure text appears in the correct sequence.
Args:
element: The XML element to process recursively.
Example:
Before: <text><tspan x="10"><tspan>Hello</tspan></tspan></text>
After: <text><tspan x="10">Hello</tspan></text>
Before: <text><tspan font-weight="700">Bold</tspan><tspan> text</tspan></text>
After: <text><tspan font-weight="700">Bold</tspan> text</text>
"""
# First, recursively process all children
for child in list(element):
merge_attribute_less_children(child)
# Then merge attribute-less children into parent
# We need to find the previous element that still exists (wasn't removed)
children = list(element)
for i, child in enumerate(children):
if not child.attrib:
# If the child has its own children, unwrap it: move grandchildren up
if len(child) > 0:
# Only unwrap if there's no text content to preserve
has_text = child.text is not None and child.text.strip() != ""
if not has_text:
# Find where to insert grandchildren (at the child's position)
child_index = list(element).index(child)
# Unwrap the wrapper, inserting grandchildren at child's position
_unwrap_wrapper_element(element, child, insert_position=child_index)
continue
# Move child's text content
if child.text:
# Find previous sibling that still exists in the tree
prev_existing = None
for j in range(i - 1, -1, -1):
if children[j] in element:
prev_existing = children[j]
break
if prev_existing is not None:
# Append to previous existing sibling's tail
prev_existing.tail = (prev_existing.tail or "") + child.text
else:
# No previous sibling, append to parent's text
element.text = (element.text or "") + child.text
# Move child's tail
if child.tail:
# Find next sibling that still exists in the tree
next_existing = None
for j in range(i + 1, len(children)):
if children[j] in element:
next_existing = children[j]
break
if next_existing is not None:
# Prepend to next existing sibling's text
next_existing.text = child.tail + (next_existing.text or "")
else:
# No next sibling, find previous sibling or append to parent
prev_existing = None
for j in range(i - 1, -1, -1):
if children[j] in element:
prev_existing = children[j]
break
if prev_existing is not None:
prev_existing.tail = (prev_existing.tail or "") + child.tail
else:
element.text = (element.text or "") + child.tail
# Remove the now-empty child
element.remove(child)
[docs]
def merge_common_child_attributes(
element: ET.Element, excludes: set[str] | None = None
) -> None:
"""Recursively merge common child attributes to their parent node.
This utility hoists attributes that are common to ALL children (with the same value)
to the parent element. This helps produce cleaner, more compact SVG output by
reducing redundant attribute declarations.
Args:
element: The XML element to process recursively.
excludes: Set of attribute names that should not be hoisted to parent.
Common excludes for text elements: {"x", "y", "dx", "dy", "transform"}
Example:
Before: <text><tspan fill="red">A</tspan><tspan fill="red">B</tspan>
</text>
After: <text fill="red"><tspan>A</tspan><tspan>B</tspan></text>
Before (with excludes={"x"}):
<text><tspan x="10" fill="red">A</tspan><tspan x="20"
fill="red">B</tspan></text>
After: <text fill="red"><tspan x="10">A</tspan><tspan x="20">B
</tspan></text>
"""
if excludes is None:
excludes = set()
# First, recursively process all children
for child in list(element):
merge_common_child_attributes(child, excludes)
# Find attributes that all children have in common with the same value
children = list(element)
if not children:
return
# Start with the first child's attributes as candidates
common_attribs: dict[str, str] = {
key: value for key, value in children[0].attrib.items() if key not in excludes
}
# Check if remaining children all have the same values
for child in children[1:]:
keys_to_remove = []
for key, value in common_attribs.items():
if key not in child.attrib or child.attrib[key] != value:
keys_to_remove.append(key)
for key in keys_to_remove:
del common_attribs[key]
# Migrate the common attributes to the parent element
for key, value in common_attribs.items():
element.attrib[key] = value
for child in element:
del child.attrib[key]
[docs]
def merge_consecutive_siblings(element: ET.Element) -> None:
"""Recursively merge consecutive sibling elements with identical attributes.
This utility consolidates sequences of consecutive child elements that have
the same tag and identical attributes by merging their text content. This is
particularly useful for optimizing SVG <tspan> elements.
Args:
element: The XML element to process recursively.
Example:
Before: <text>
<tspan font-size="18" letter-spacing="0.72">す</tspan>
<tspan font-size="18" letter-spacing="0.72">だ</tspan>
<tspan font-size="18" letter-spacing="0.72">け</tspan>
</text>
After: <text>
<tspan font-size="18" letter-spacing="0.72">すだけ
</tspan>
</text>
Before: <text>
<tspan font-size="18" baseline-shift="-0.36"
letter-spacing="0.72">さ</tspan>
<tspan font-size="18" letter-spacing="0.72">す</tspan>
<tspan font-size="18" letter-spacing="0.72">だ</tspan>
</text>
After: <text>
<tspan font-size="18" baseline-shift="-0.36"
letter-spacing="0.72">さ</tspan>
<tspan font-size="18" letter-spacing="0.72">すだ</tspan>
</text>
Note:
- Only merges elements with the same tag name
- Only merges elements with identical attribute sets
- Preserves document order
- Does not merge elements with child elements
- Empty elements (no text or tail) are removed
"""
# First, recursively process all children
for child in list(element):
merge_consecutive_siblings(child)
# Then merge consecutive siblings with identical attributes
children = list(element)
if len(children) < 2:
return
i = 0
while i < len(children):
current = children[i]
# Skip if current element has children (don't merge complex structures)
if len(current) > 0:
i += 1
continue
# Find consecutive siblings with same tag and attributes
merge_group = [current]
j = i + 1
while j < len(children):
next_elem = children[j]
# Stop if next element has children
if len(next_elem) > 0:
break
# Check if tags and attributes match
if next_elem.tag == current.tag and next_elem.attrib == current.attrib:
merge_group.append(next_elem)
j += 1
else:
break
# If we found consecutive siblings to merge (2 or more)
if len(merge_group) >= 2:
# Merge all text content into the first element
merged_text = ""
for elem in merge_group:
if elem.text:
merged_text += elem.text
# Set merged text to first element (or None if empty)
merge_group[0].text = merged_text if merged_text else None
# Preserve the tail of the last element in the group
last_tail = merge_group[-1].tail
# Remove all but the first element
for elem in merge_group[1:]:
element.remove(elem)
# Set the tail to the first element
merge_group[0].tail = last_tail
# If the merged element is empty (no text, no tail), remove it
if not merge_group[0].text and not merge_group[0].tail:
element.remove(merge_group[0])
# Update children list after removals
children = list(element)
else:
i += 1
[docs]
def merge_singleton_children(element: ET.Element) -> None:
"""Recursively merge singleton child nodes into their parent nodes.
This utility removes redundant wrapper elements when a parent has exactly one child
and there are no conflicting attributes. The child's attributes and text content
are moved to the parent element.
Args:
element: The XML element to process recursively.
Example:
Before: <text><tspan>Hello</tspan></text>
After: <text>Hello</text>
Before: <text x="10"><tspan font-weight="700">Bold</tspan></text>
After: <text x="10" font-weight="700">Bold</text>
Not merged (conflicting attributes):
Before: <text x="10"><tspan x="20">Text</tspan></text>
After: <text x="10"><tspan x="20">Text</tspan></text> (unchanged)
Not merged (child has children):
Before: <text><tspan><tspan>A</tspan><tspan>B</tspan></tspan></text>
After: <text><tspan><tspan>A</tspan><tspan>B</tspan></tspan></text>
(unchanged)
"""
# First, recursively process all children
for child in list(element):
merge_singleton_children(child)
# Merge singleton child if present (checking AFTER recursion)
if len(element) == 1:
child = element[0]
# If the child has its own children, we can still optimize by
# "unwrapping" it: move the grandchildren up to be direct children of
# element, removing the wrapper
if len(child) > 0:
# Only unwrap if there are no attribute conflicts and no text
# content to preserve
has_conflict = (
len(set(element.attrib.keys()) & set(child.attrib.keys())) > 0
)
has_text = child.text is not None and child.text.strip() != ""
if not has_conflict and not has_text:
# Move child's attributes to parent
for key, value in child.attrib.items():
element.attrib[key] = value
# Unwrap the wrapper, appending grandchildren to end
_unwrap_wrapper_element(element, child, insert_position=None)
return
# Check for attribute conflicts
if len(set(element.attrib.keys()) & set(child.attrib.keys())) > 0:
return # Conflicting attributes, do not merge
# Move child's text content to parent
# Note: child.text comes before any children in document order
if child.text:
element.text = (element.text or "") + child.text
# Move child's tail to parent's text
# Note: When we remove the child, its tail (which comes after </child>)
# should become part of the parent's text content
if child.tail:
element.text = (element.text or "") + child.tail
# Move all child's attributes to parent
for key, value in child.attrib.items():
element.attrib[key] = value
# Remove the now-merged child
element.remove(child)
[docs]
def consolidate_defs(svg: ET.Element) -> None:
"""Consolidate all <defs> and definition elements into a global <defs>.
This optimization improves SVG structure by:
1. Creating a single global <defs> element at the beginning of the SVG
2. Moving all definition elements (filters, gradients, patterns, etc.) into it
3. Removing now-empty inline <defs> elements
Definition elements that are moved:
- <defs> (contents merged into global defs)
- <filter>
- <linearGradient>
- <radialGradient>
- <pattern>
- <clipPath>
- <marker>
- <symbol>
Args:
svg: The root SVG element to optimize (modified in-place).
Note:
This function preserves all id references and element ordering within defs.
**Mask Element Exclusion:**
This function does NOT move <mask> elements for the following technical reasons:
1. **Coordinate Systems**: Masks use ``maskContentUnits="userSpaceOnUse"``
by default, meaning mask content coordinates are evaluated at reference
time (where ``mask="url(#id)"`` is applied), not at definition time.
Moving masks to a different DOM position could affect coordinate system
evaluation.
2. **Property Inheritance**: Masks inherit CSS properties (fill, stroke,
opacity, etc.) from their DOM ancestors. Moving a mask to <defs> changes
its inheritance chain, which can alter rendering for masks with styled
content.
3. **Transform Handling**: psd2svg uses special transform workarounds for
masked elements that rely on specific mask positioning in the DOM tree.
4. **Renderer Compatibility**: Known compatibility issues exist with some
SVG renderers regarding mask positioning and interactions with other
elements (e.g., clipPath-mask interactions).
**Nested Definitions Are Consolidated:**
While <mask> elements themselves are not moved, any nested <defs> and
definition elements (filters, gradients, etc.) WITHIN masks ARE moved
to the global defs. This provides optimization benefits while preserving
mask rendering fidelity. See ``test_consolidate_complex_real_world_example``
in tests/test_optimize.py for demonstration.
Example:
Before:
<svg>
<rect fill="url(#g1)"/>
<linearGradient id="g1">...</linearGradient>
<defs><filter id="f1">...</filter></defs>
</svg>
After:
<svg>
<defs>
<linearGradient id="g1">...</linearGradient>
<filter id="f1">...</filter>
</defs>
<rect fill="url(#g1)"/>
</svg>
"""
# Definition element tags to consolidate (without namespace)
definition_tags = {
"filter",
"linearGradient",
"radialGradient",
"pattern",
"clipPath",
"marker",
"symbol",
}
# Find or create global <defs> element at the beginning
global_defs = None
for child in svg:
tag = child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
if local_name == "defs":
global_defs = child
break
if global_defs is None:
# Create new global defs as first child
global_defs = ET.Element("defs")
svg.insert(0, global_defs)
else:
# Ensure global_defs is the first child
if svg[0] is not global_defs:
svg.remove(global_defs)
svg.insert(0, global_defs)
# Collect all definition elements from the entire SVG tree
definitions_to_move: list[tuple[ET.Element, ET.Element]] = [] # (element, parent)
def collect_definitions(element: ET.Element) -> None:
"""Recursively collect definition elements to move."""
for child in list(element):
tag = child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
# If it's a <defs> element (not the global one), collect its
# children
if local_name == "defs" and child is not global_defs:
for def_child in list(child):
definitions_to_move.append((def_child, child))
# Continue recursing into defs (in case of nested
# structures)
collect_definitions(child)
# If it's a definition element (not inside global defs), collect it
elif local_name in definition_tags and element is not global_defs:
definitions_to_move.append((child, element))
# Still recurse into the element (e.g., filters can contain
# other elements)
collect_definitions(child)
else:
# Recurse into other elements
collect_definitions(child)
# Collect definitions directly from SVG root level
for child in list(svg):
if child is global_defs:
continue
tag = child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
# If it's a <defs> element, collect its children
if local_name == "defs":
for def_child in list(child):
definitions_to_move.append((def_child, child))
# Also recurse into it
collect_definitions(child)
# If it's a definition element, collect it
elif local_name in definition_tags:
definitions_to_move.append((child, svg))
# Also recurse into it
collect_definitions(child)
else:
# For non-definition elements, recurse to find nested definitions
collect_definitions(child)
# Move all collected definitions to global defs
for element, parent in definitions_to_move:
# Skip if element has already been moved (e.g., parent was removed)
if element not in parent:
continue
parent.remove(element)
global_defs.append(element)
# Remove now-empty <defs> elements
def remove_empty_defs(element: ET.Element) -> None:
"""Recursively remove empty <defs> elements."""
for child in list(element):
tag = child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
if local_name == "defs" and child is not global_defs and len(child) == 0:
element.remove(child)
else:
remove_empty_defs(child)
remove_empty_defs(svg)
# Remove global defs if it's empty (no definitions were found)
if len(global_defs) == 0:
svg.remove(global_defs)
def _normalize_element_for_comparison(element: ET.Element) -> str:
"""Create normalized string for structural comparison.
This function creates a canonical string representation of an element
for content-based comparison, excluding identity attributes like 'id'.
Normalizations applied:
- Remove id attribute (identity, not equality)
- Sort attributes alphabetically (consistent ordering)
- Recursively sort child attributes
- Strip whitespace from text and tail
- Serialize without whitespace
Args:
element: The element to normalize.
Returns:
Canonical string representation suitable for deduplication.
"""
# Deep copy to avoid modifying original
temp = copy.deepcopy(element)
# Remove id attribute - it represents identity, not content equality
temp.attrib.pop("id", None)
# Sort all attributes and normalize whitespace recursively
# (XML doesn't guarantee attribute order, and whitespace can vary)
for elem in temp.iter():
sorted_attrib = dict(sorted(elem.attrib.items()))
elem.attrib.clear()
elem.attrib.update(sorted_attrib)
# Normalize text and tail whitespace
if elem.text:
elem.text = elem.text.strip() or None
if elem.tail:
elem.tail = elem.tail.strip() or None
# Serialize to string for comparison
return ET.tostring(temp, encoding="unicode", method="xml")
def _update_url_references(
svg: ET.Element,
id_mapping: dict[str, str],
) -> None:
"""Update all url(#id) references using id mapping.
Searches the SVG tree for elements with URL reference attributes
(fill, stroke, filter, clip-path, mask, marker-*) and updates
any url(#id) references according to the provided mapping.
Note:
The converter only creates URL references in direct attributes,
never in style attributes. Style attributes only contain
font-family properties.
Args:
svg: The root SVG element to search.
id_mapping: Mapping from old element IDs to canonical IDs.
"""
# Attributes that can contain url() references
url_attrs = {
"fill",
"stroke",
"filter",
"clip-path",
"mask",
"marker-start",
"marker-mid",
"marker-end",
}
# Pattern to match url(#id)
url_pattern = re.compile(r"url\(#([^)]+)\)")
updated_count = 0
# Only iterate over elements that actually have URL reference attributes
# This is more efficient than checking every element in the tree
for attr in url_attrs:
# Use XPath to find elements with this attribute
for element in svg.findall(f".//*[@{attr}]"):
value = element.get(attr)
if value and "url(#" in value:
def replace_url(match: re.Match[str]) -> str:
nonlocal updated_count
old_id = match.group(1)
if old_id in id_mapping:
new_id = id_mapping[old_id]
if new_id != old_id:
updated_count += 1
return f"url(#{new_id})"
return match.group(0)
new_value = url_pattern.sub(replace_url, value)
if new_value != value:
element.set(attr, new_value)
if updated_count > 0:
logger.debug(f"Updated {updated_count} url(#id) references")
[docs]
def deduplicate_definitions(svg: ET.Element) -> None:
"""Deduplicate identical definition elements in <defs>.
Identifies structurally identical definition elements (filter, gradients,
patterns, clipPath, marker, symbol) and merges duplicates by keeping the
first occurrence and updating all url(#id) references throughout the SVG tree.
This function should be called AFTER consolidate_defs() to ensure all
definition elements are in a single global <defs>.
Priority order (based on PSD per-layer structure):
1. filter, linearGradient, radialGradient, pattern (most common duplicates)
2. clipPath, marker, symbol (less common but still beneficial)
Args:
svg: The root SVG element (modified in-place).
Note:
Elements are considered identical if they have:
- Same tag
- Same attributes (except 'id')
- Same child structure and content
Example:
Before:
<defs>
<linearGradient id="g1"><stop offset="0%"/></linearGradient>
<linearGradient id="g2"><stop offset="0%"/></linearGradient>
</defs>
<rect fill="url(#g1)"/>
<circle fill="url(#g2)"/>
After:
<defs>
<linearGradient id="g1"><stop offset="0%"/></linearGradient>
</defs>
<rect fill="url(#g1)"/>
<circle fill="url(#g1)"/>
"""
# Phase 1: Find global <defs> and extract definition elements
global_defs = None
for child in svg:
tag = child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
if local_name == "defs":
global_defs = child
break
if global_defs is None:
return
# All definition types from consolidate_defs
# Prioritize high-duplication types (filters, gradients, patterns)
definition_tags = {
"filter",
"linearGradient",
"radialGradient",
"pattern",
"clipPath",
"marker",
"symbol",
}
definitions = []
for child in global_defs:
tag = child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
if local_name in definition_tags:
definitions.append(child)
if not definitions:
return
# Phase 2: Build canonical mapping using normalized serialization
content_to_id: dict[str, str] = {} # normalized_content -> canonical_id
id_to_canonical: dict[str, str] = {} # element_id -> canonical_id
duplicates_to_remove: list[ET.Element] = []
for element in definitions:
element_id = element.get("id")
if not element_id:
logger.warning(f"Definition element missing id: {element.tag}")
continue
# Normalize element for comparison (remove id, sort attributes, serialize)
normalized = _normalize_element_for_comparison(element)
if normalized in content_to_id:
# Duplicate found - map to canonical (first occurrence)
canonical_id = content_to_id[normalized]
id_to_canonical[element_id] = canonical_id
duplicates_to_remove.append(element)
else:
# First occurrence - this becomes the canonical element
content_to_id[normalized] = element_id
id_to_canonical[element_id] = element_id # Maps to itself
# Phase 3: Update all url(#id) references
if id_to_canonical:
_update_url_references(svg, id_to_canonical)
# Phase 4: Remove duplicates from defs
for element in duplicates_to_remove:
global_defs.remove(element)
if duplicates_to_remove:
logger.debug(f"Deduplicated {len(duplicates_to_remove)} definition elements")
def _is_unwrappable_group(group: ET.Element) -> bool:
"""Check if a <g> element can be safely unwrapped.
Returns True if group has no attributes (or only empty class), and no <title>
child elements. Returns False if group has ANY meaningful attributes or contains
<title> children that could conflict if moved up.
Meaningful attributes that prevent unwrapping:
- id (may be referenced by <use> or JavaScript)
- opacity (affects rendering)
- style (may contain mix-blend-mode, isolation, etc.)
- filter, clip-path, mask (affect rendering)
- transform (affects coordinate space)
- Any other unknown attributes (conservative approach)
Args:
group: The <g> element to check.
Returns:
True if safe to unwrap, False otherwise.
"""
# Check attributes first
if group.attrib:
# Check each attribute
for attr_name, attr_value in group.attrib.items():
# Strip namespace from attribute name
local_attr = attr_name.split("}")[-1] if "}" in attr_name else attr_name
# Empty class is safe (debugging attribute, usually disabled)
if local_attr == "class":
if not attr_value or attr_value.strip() == "":
continue
return False # Non-empty class
# Any other attribute prevents unwrapping
return False
# Check for <title> children that would conflict if moved up
for child in group:
tag = child.tag
local_tag = tag.split("}")[-1] if "}" in tag else tag
if local_tag == "title":
return False # Preserve groups with <title> children
return True
def _unwrap_groups_recursive(element: ET.Element) -> None:
"""Recursively unwrap attribute-less <g> elements (helper for unwrap_groups).
Processes the tree bottom-up: children first, then parent. This allows
nested unwrappable groups to be eliminated in a single traversal.
Args:
element: Element to process recursively.
"""
# Step 1: Recursively process all children first (bottom-up)
for child in list(element):
_unwrap_groups_recursive(child)
# Step 2: After recursion, check for unwrappable <g> children
children = list(element)
for i, child in enumerate(children):
# Get local tag name (strip namespace)
tag = child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
# Only process <g> elements
if local_name != "g":
continue
# Check if this group can be unwrapped
if not _is_unwrappable_group(child):
continue
# Empty group after recursion? Remove it entirely
if len(child) == 0:
element.remove(child)
continue
# Unwrap: move grandchildren up to parent level
# Find current position of child in parent
child_index = list(element).index(child)
# Reuse existing helper (handles tail, order, etc.)
_unwrap_wrapper_element(element, child, insert_position=child_index)
[docs]
def unwrap_groups(svg: ET.Element) -> None:
"""Unwrap <g> elements that have no meaningful attributes.
This optimization removes redundant <g> wrapper elements that don't affect
rendering, moving their children directly to the parent level. This reduces
SVG nesting depth and file size.
Groups are unwrapped if they have NO attributes (or only empty class), AND
no <title> child elements. Groups with rendering attributes (opacity, style,
filter, mask, clip-path, transform) or identity attributes (id) are preserved.
Empty groups (no children) are removed entirely.
Args:
svg: The root SVG element to optimize (modified in-place).
Example:
Before:
<svg>
<g>
<rect x="0" y="0" width="100" height="100"/>
</g>
</svg>
After:
<svg>
<rect x="0" y="0" width="100" height="100"/>
</svg>
Preserved (has opacity):
<svg>
<g opacity="0.5">
<rect x="0" y="0" width="100" height="100"/>
</g>
</svg>
Note:
This function is automatically called when optimize=True in save()
and tostring(). It processes the tree recursively, unwrapping all
eligible groups in a single pass.
"""
_unwrap_groups_recursive(svg)
[docs]
def find_elements_with_font_family(
svg: ET.Element, font_family: str, include_inherited: bool = True
) -> list[ET.Element]:
"""Find all text/tspan and XHTML text elements that use the given font family.
This function searches for text, tspan, and XHTML text elements (p, span from
foreignObject) that have the specified font-family applied, either directly via
attributes or through CSS inheritance from parent elements.
Args:
svg: SVG element tree to search.
font_family: Font family name to search for (case-insensitive).
include_inherited: If True, include elements that inherit the font from
parents.
If False, only include elements with direct font-family
declarations. Default is True for backward compatibility.
Returns:
List of text/tspan/p/span elements that use the specified font family.
Note:
- Searches both font-family attributes and style attributes
- Supports CSS inheritance (walks up parent chain) when
include_inherited=True
- Case-insensitive font family matching
- Only returns text and tspan elements (not their parents)
Example:
>>> svg = svg_utils.fromstring(
... '<svg><text font-family="Arial">Hi</text></svg>'
... )
>>> elements = find_elements_with_font_family(svg, "Arial")
>>> len(elements)
1
>>> elements = find_elements_with_font_family(
... svg, "Arial", include_inherited=False
... )
>>> len(elements)
1
"""
matching_elements: list[ET.Element] = []
font_family_lower = font_family.lower()
# Build parent map for inheritance lookup
parent_map = {c: p for p in svg.iter() for c in p}
for element in svg.iter():
# Get local tag name (strip namespace if present)
tag = element.tag
if "}" in tag:
tag = tag.split("}", 1)[1]
# Only process text/tspan and XHTML text elements (p, span from foreignObject)
if tag not in ("text", "tspan", "p", "span"):
continue
# Check if element uses target font (with or without inheritance)
if _element_uses_font_family(
element, parent_map, font_family_lower, include_inherited
):
matching_elements.append(element)
return matching_elements
def _element_uses_font_family(
element: ET.Element,
parent_map: dict[ET.Element, ET.Element],
target_font: str,
include_inherited: bool = True,
) -> bool:
"""Check if element uses the target font (direct or inherited).
Helper function for find_elements_with_font_family that walks up the
parent
chain to check for font-family declarations.
Args:
element: Element to check.
parent_map: Dictionary mapping children to parents.
target_font: Target font family name (lowercase).
include_inherited: If True, walk up parent chain. If False, check only
this element.
Returns:
True if element uses the target font, False otherwise.
Note:
Ignores CSS 'font' shorthand property, only checks 'font-family'.
"""
current = element
while True:
# Check direct font-family attribute
elem_font_family = current.get("font-family")
if elem_font_family:
clean_family = elem_font_family.strip("'\"").split(",")[0].strip("'\"")
return clean_family.lower() == target_font
# Check style attribute for font-family
style = current.get("style")
if style and "font-family:" in style:
match = re.search(r"font-family:\s*([^;]+)", style)
if match:
font_family_value = match.group(1).strip()
families = [
f.strip().strip("'\"") for f in font_family_value.split(",")
]
if families:
return families[0].lower() == target_font
# If not checking inheritance, stop here
if not include_inherited:
return False
# Walk up to parent
if current not in parent_map:
break
current = parent_map[current]
return False
[docs]
def replace_font_family(
element: ET.Element, old_font_family: str, new_font_family: str
) -> None:
"""Replace a font family with another in the given element.
This function modifies an SVG text/tspan element by replacing occurrences
of old_font_family with new_font_family in font-family specifications.
It handles both font-family attributes and style attributes.
Args:
element: Element to update (typically text or tspan element).
old_font_family: Font family name to replace.
new_font_family: Font family name to use as replacement.
Example:
Before: <text font-family="ArialMT">Hello</text>
After: <text font-family="Arial">Hello</text>
Before: <text font-family="ArialMT, Helvetica">Hello</text>
After: <text font-family="Arial, Helvetica">Hello</text>
Before: <text style="font-family: ArialMT">Hello</text>
After: <text style="font-family: 'Arial'">Hello</text>
Note:
- Replaces first occurrence of old_font_family in comma-separated list
- If old_font_family not found, does nothing
- For font-family attributes: no quotes (attribute delimiter is sufficient)
- For style attributes: quotes font names for CSS compliance
- Updates font-family attribute with higher priority than style
"""
# Check font-family attribute first (higher priority)
existing_font_family = element.get("font-family")
if existing_font_family:
# Parse existing font families
families = [f.strip().strip("'\"") for f in existing_font_family.split(",")]
# Replace old font with new font
if old_font_family in families:
families = [
new_font_family if f == old_font_family else f for f in families
]
# No quotes for SVG attributes (attribute delimiter is sufficient)
element.set("font-family", ", ".join(families))
return # Attribute has priority, don't check style
# Check style attribute for font-family
style = element.get("style")
if style and "font-family:" in style:
def replace_in_style(match: re.Match[str]) -> str:
font_family_value = match.group(1).strip()
# Parse existing font families
families = [f.strip().strip("'\"") for f in font_family_value.split(",")]
# Replace old font with new font
if old_font_family in families:
families = [
new_font_family if f == old_font_family else f for f in families
]
# Quote all families for CSS compliance
return "font-family: " + ", ".join(f"'{f}'" for f in families)
# Replace font-family in style attribute
updated_style = re.sub(r"font-family:\s*([^;]+)", replace_in_style, style)
if updated_style != style:
element.set("style", updated_style)
[docs]
def add_font_family(element: ET.Element, font_family: str) -> None:
"""Add font family to font-family specification in the given element.
This function modifies an SVG text/tspan element by appending a font family
to its existing font-family specification if not already present. It handles both
font-family attributes and font-family declarations within style attributes.
Behavior:
- If element has font-family attribute: append font to it
- Else if element has font-family in style: append font to style
- Else: create new font-family attribute with the font
- If both attribute and style exist: append to attribute (higher priority)
Args:
element: Element to update (typically text or tspan element).
font_family: Font family name to add.
Example:
Before: <text font-family="Arial">Hello</text>
After: <text font-family="Arial, Helvetica">Hello</text>
Before: <text style="font-family: Arial">Hello</text>
After: <text style="font-family: 'Arial', 'Helvetica'">Hello</text>
Before: <text>Hello</text>
After: <text font-family="Helvetica">Hello</text>
Note:
- Updates font-family attribute with higher priority than style
- For font-family attributes: no quotes (attribute delimiter is sufficient)
- For style attributes: quotes font names for CSS compliance
- Idempotent: does not add font if already present in the chain
- Ignores CSS 'font' shorthand property, only handles 'font-family'
"""
# Check font-family attribute first (higher priority)
existing_font_family = element.get("font-family")
if existing_font_family:
# Parse existing font families
families = [f.strip().strip("'\"") for f in existing_font_family.split(",")]
# Only add if not already present
if font_family not in families:
families.append(font_family)
# No quotes for SVG attributes (attribute delimiter is sufficient)
element.set("font-family", ", ".join(families))
return # Attribute has priority, don't check style
# Check style attribute for font-family
style = element.get("style")
if style and "font-family:" in style:
def replace_font_family(match: re.Match[str]) -> str:
font_family_value = match.group(1).strip()
# Parse existing font families
families = [f.strip().strip("'\"") for f in font_family_value.split(",")]
# Only add if not already present
if font_family not in families:
families.append(font_family)
# Quote all families for CSS compliance
return "font-family: " + ", ".join(f"'{f}'" for f in families)
# Replace font-family in style attribute
updated_style = re.sub(r"font-family:\s*([^;]+)", replace_font_family, style)
if updated_style != style:
element.set("style", updated_style)
return
# No font-family found, create new attribute
element.set("font-family", font_family)
[docs]
def insert_or_update_style_element(svg: ET.Element, css_content: str) -> None:
"""Insert or update a <style> element in the SVG root.
Args:
svg: SVG root element to modify.
css_content: CSS content to insert or append.
Note:
- If a <style> element exists as first child, appends to it
- Otherwise creates a new <style> element as first child
- Idempotent: skips if CSS content already present
"""
# Find existing <style> element (check first child only)
style_element = None
if len(svg) > 0:
first_child = svg[0]
# Check if it's a style element (handle namespaced tags)
tag = first_child.tag
local_name = tag.split("}")[-1] if "}" in tag else tag
if local_name == "style":
style_element = first_child
if style_element is not None:
# Update existing style element
existing_text = style_element.text or ""
# Check if our fonts are already embedded (idempotent check)
if css_content in existing_text:
logger.debug("CSS content already present, skipping")
return
# Append to existing styles
style_element.text = existing_text + "\n" + css_content
else:
# Create new style element as first child
style_element = ET.Element("style")
style_element.text = css_content
svg.insert(0, style_element)