Source code for hed.schema.schema_io.xml2schema

"""
This module is used to create a HedSchema object from an XML file or tree.
"""

from defusedxml import ElementTree
import xml

import hed.schema.hed_schema_constants
from hed.errors.exceptions import HedFileError, HedExceptions
from hed.schema.hed_schema_constants import HedSectionKey, HedKey
from hed.schema import schema_validation_util
from hed.schema.schema_io import xml_constants
from .base2schema import SchemaLoader
from functools import partial


class SchemaLoaderXML(SchemaLoader):
    """ Loads XML schemas from filenames or strings.

    Expected usage is SchemaLoaderXML.load(filename)

    SchemaLoaderXML(filename) will load just the header_attributes
    """

    def __init__(self, filename, schema_as_string=None):
        """ Set up the loader through the base class and reset the XML state.

        Parameters:
            filename (str or None): Passed through to SchemaLoader.
            schema_as_string (str or None): Passed through to SchemaLoader.
        """
        super().__init__(filename, schema_as_string)
        # Filled in by _parse_data once the XML tree is available.
        self._root_element = None
        self._parent_map = {}

    def _open_file(self):
        """ Parse the XML schema and return its root element.

        The schema is read from self.filename when that is set, otherwise
        from self.schema_as_string.

        Returns:
            Element: The root element of the parsed HED XML schema.

        Raises:
            HedFileError: If the XML cannot be parsed.
        """
        try:
            if self.filename:
                hed_xml_tree = ElementTree.parse(self.filename)
                root = hed_xml_tree.getroot()
            else:
                root = ElementTree.fromstring(self.schema_as_string)
        except xml.etree.ElementTree.ParseError as e:
            # Report the source that was actually parsed: the filename when
            # one was given, otherwise the schema string.
            source = self.filename if self.filename else self.schema_as_string
            raise HedFileError(HedExceptions.CANNOT_PARSE_XML, e.msg, source) from e

        return root

    def _get_header_attributes(self, root_element):
        """ Get the schema attributes from the XML root node.

        Parameters:
            root_element (Element): The root element of the XML schema.

        Returns:
            dict: The root attributes, reformatted by _reformat_xsd_attrib.
        """
        return self._reformat_xsd_attrib(root_element.attrib)

    def _parse_data(self):
        """ Populate self._schema from the already loaded XML tree.

        Reads the root element from self.input_data, records a child->parent
        map, sets the prologue and epilogue, then parses each section.
        """
        self._root_element = self.input_data
        self._parent_map = {c: p for p in self._root_element.iter() for c in p}

        # Sections are parsed in dict insertion order.
        parse_order = {
            HedSectionKey.Properties: partial(self._populate_section, HedSectionKey.Properties),
            HedSectionKey.Attributes: partial(self._populate_section, HedSectionKey.Attributes),
            HedSectionKey.UnitModifiers: partial(self._populate_section, HedSectionKey.UnitModifiers),
            HedSectionKey.UnitClasses: self._populate_unit_class_dictionaries,
            HedSectionKey.ValueClasses: partial(self._populate_section, HedSectionKey.ValueClasses),
            HedSectionKey.Tags: self._populate_tag_dictionaries,
        }
        self._schema.prologue = self._read_prologue()
        self._schema.epilogue = self._read_epilogue()
        self._parse_sections(self._root_element, parse_order)

    def _parse_sections(self, root_element, parse_order):
        """ Locate each section element and hand it to its parse function.

        Parameters:
            root_element (Element): The element searched for the sections.
            parse_order (dict): Maps a section key to the callable that parses
                that section; iterated in insertion order.

        Raises:
            HedFileError: If a section element cannot be found.
        """
        for section_key, parse_func in parse_order.items():
            section_name = xml_constants.SECTION_ELEMENTS[section_key]
            section_elements = self._get_elements_by_name(section_name, root_element)
            if not section_elements:
                raise HedFileError(HedExceptions.INVALID_HED_FORMAT,
                                   "Attempting to load an outdated or invalid XML schema", self.filename)
            # Only the first matching element is used.
            parse_func(section_elements[0])

    def _populate_section(self, key_class, section):
        """ Parse every definition element of one section into the schema.

        Parameters:
            key_class (HedSectionKey): The section being populated.
            section (Element): The XML element holding the section.
        """
        self._schema._initialize_attributes(key_class)
        def_element_name = xml_constants.ELEMENT_NAMES[key_class]
        attribute_elements = self._get_elements_by_name(def_element_name, section)
        for element in attribute_elements:
            new_entry = self._parse_node(element, key_class)
            self._add_to_dict(new_entry, key_class)

    def _read_prologue(self):
        """ Return the text of the prologue element.

        Returns:
            str or None: The element's text if exactly one prologue element is
                found (None when that element has no text), otherwise "".
        """
        prologue_elements = self._get_elements_by_name(xml_constants.PROLOGUE_ELEMENT)
        if len(prologue_elements) == 1:
            return prologue_elements[0].text
        return ""

    def _read_epilogue(self):
        """ Return the text of the epilogue element.

        Returns:
            str or None: The element's text if exactly one epilogue element is
                found (None when that element has no text), otherwise "".
        """
        epilogue_elements = self._get_elements_by_name(xml_constants.EPILOGUE_ELEMENT)
        if len(epilogue_elements) == 1:
            return epilogue_elements[0].text
        return ""

    def _add_tags_recursive(self, new_tags, parent_tags):
        """ Add the given tag elements and all their "node" children to the schema.

        Parameters:
            new_tags (list): The tag elements to add at this level.
            parent_tags (list): The names of the ancestors of these tags,
                outermost first.
        """
        for tag_element in new_tags:
            current_tag = self._get_element_tag_value(tag_element)
            parents_and_child = parent_tags + [current_tag]
            full_tag = "/".join(parents_and_child)

            tag_entry = self._parse_node(tag_element, HedSectionKey.Tags, full_tag)

            rooted_entry = schema_validation_util.find_rooted_entry(tag_entry, self._schema, self._loading_merged)
            if rooted_entry:
                # Re-parse the node with its name prefixed by the rooted entry's name.
                loading_from_chain = rooted_entry.name + "/" + tag_entry.short_tag_name
                loading_from_chain_short = tag_entry.short_tag_name

                full_tag = full_tag.replace(loading_from_chain_short, loading_from_chain)
                tag_entry = self._parse_node(tag_element, HedSectionKey.Tags, full_tag)
                parents_and_child = full_tag.split("/")

            self._add_to_dict(tag_entry, HedSectionKey.Tags)
            child_tags = tag_element.findall("node")
            self._add_tags_recursive(child_tags, parents_and_child)

    def _populate_tag_dictionaries(self, tag_section):
        """ Populate the schema's tag section from the XML tag section.

        Parameters:
            tag_section (Element): The XML element whose direct "node" children
                are the top-level tags.
        """
        self._schema._initialize_attributes(HedSectionKey.Tags)
        root_tags = tag_section.findall("node")
        self._add_tags_recursive(root_tags, [])

    def _populate_unit_class_dictionaries(self, unit_section):
        """ Populate the schema's unit class and unit sections.

        Each unit class element is added to the unit class section; each of its
        unit elements is added to the units section and attached to the class.

        Parameters:
            unit_section (Element): The XML element holding the unit classes.
        """
        self._schema._initialize_attributes(HedSectionKey.UnitClasses)
        self._schema._initialize_attributes(HedSectionKey.Units)
        def_element_name = xml_constants.ELEMENT_NAMES[HedSectionKey.UnitClasses]
        unit_class_elements = self._get_elements_by_name(def_element_name, unit_section)

        for unit_class_element in unit_class_elements:
            unit_class_entry = self._parse_node(unit_class_element, HedSectionKey.UnitClasses)
            # Continue with whatever entry _add_to_dict hands back.
            unit_class_entry = self._add_to_dict(unit_class_entry, HedSectionKey.UnitClasses)
            element_units = self._get_elements_by_name(xml_constants.UNIT_CLASS_UNIT_ELEMENT, unit_class_element)

            for element in element_units:
                # _parse_node looks up (and validates) the unit's name itself.
                unit_class_unit_entry = self._parse_node(element, HedSectionKey.Units)
                self._add_to_dict(unit_class_unit_entry, HedSectionKey.Units)
                unit_class_entry.add_unit(unit_class_unit_entry)

    def _reformat_xsd_attrib(self, attrib_dict):
        """ Return a copy of the header attributes with the XSD key expanded.

        The xml_constants.NO_NAMESPACE_XSD_KEY attribute is replaced by two
        entries (NS_ATTRIB and NO_LOC_ATTRIB); all other attributes are copied.

        Parameters:
            attrib_dict (dict): The attributes of the XML root element.

        Returns:
            dict: The reformatted attributes.
        """
        final_attrib = {}
        for attrib_name, attrib_value in attrib_dict.items():
            if attrib_name == xml_constants.NO_NAMESPACE_XSD_KEY:
                final_attrib[hed.schema.hed_schema_constants.NS_ATTRIB] = xml_constants.XSI_SOURCE
                final_attrib[hed.schema.hed_schema_constants.NO_LOC_ATTRIB] = attrib_value
            else:
                final_attrib[attrib_name] = attrib_value

        return final_attrib

    def _parse_node(self, node_element, key_class, element_name=None):
        """ Create a schema entry, with description and attributes, from an XML node.

        Parameters:
            node_element (Element): The XML element describing the entry.
            key_class (HedSectionKey): The section the entry belongs to.
            element_name (str or None): The entry name to use. If not given, the
                name is read from the node's name element.

        Returns:
            The entry created by self._schema._create_tag_entry.

        Raises:
            HedFileError: If the name or description element is present but empty.
        """
        if element_name:
            node_name = element_name
        else:
            node_name = self._get_element_tag_value(node_element)
        attribute_desc = self._get_element_tag_value(node_element, xml_constants.DESCRIPTION_ELEMENT)

        tag_entry = self._schema._create_tag_entry(node_name, key_class)

        if attribute_desc:
            tag_entry.description = attribute_desc

        for attribute_element in node_element:
            # Skip children that are not attribute/property elements for this section.
            if attribute_element.tag != xml_constants.ATTRIBUTE_PROPERTY_ELEMENTS[key_class]:
                continue
            attribute_name = self._get_element_tag_value(attribute_element)
            attribute_value_elements = self._get_elements_by_name("value", attribute_element)
            # NOTE(review): a <value> element without text has text None, which
            # would make this join raise TypeError - confirm that cannot occur.
            attribute_value = ",".join(element.text for element in attribute_value_elements)
            # Todo: do we need to validate this here?
            if not attribute_value:
                # An attribute with no value is stored as a boolean flag.
                attribute_value = True
            tag_entry._set_attribute_value(attribute_name, attribute_value)

        return tag_entry

    def _get_element_tag_value(self, element, tag_name=xml_constants.NAME_ELEMENT):
        """ Get the text of a named child of the element.

        Parameters:
            element (Element): An element in the HED XML file.
            tag_name (str): The name of the child element to read.
                The default is xml_constants.NAME_ELEMENT.

        Returns:
            str: The text of the child element.

        Raises:
            HedFileError: If the child exists but has no text, unless tag_name is "units".

        Notes:
            If the element doesn't have the child then an empty string is returned.
            For "units" an existing child without text yields None.
        """
        element = element.find(tag_name)
        if element is not None:
            if element.text is None and tag_name != "units":
                raise HedFileError(HedExceptions.HED_SCHEMA_NODE_NAME_INVALID,
                                   f"A Schema node is empty for tag of element name: '{tag_name}'.",
                                   self._schema.filename)
            return element.text
        return ""

    def _get_elements_by_name(self, element_name='node', parent_element=None):
        """ Get the elements that have a specific element name.

        Parameters:
            element_name (str): The name of the element. The default is 'node'.
            parent_element (Element or None): The element to search below.

        Returns:
            list: A list containing elements that have a specific element name.

        Notes:
            The search uses './/<element_name>', so matches at any depth below
            the parent are returned, not only direct children. If no parent is
            specified the root element is searched.
        """
        search_root = self._root_element if parent_element is None else parent_element
        return search_root.findall(f'.//{element_name}')

    def _add_to_dict(self, entry, key_class):
        """ Add an entry to the schema section given by key_class.

        Parameters:
            entry: The schema entry to add.
            key_class (HedSectionKey): The section to add the entry to.

        Returns:
            The value returned by self._schema._add_tag_to_dict.

        Raises:
            HedFileError: If the entry has the InLibrary attribute while an
                unmerged schema is being loaded.
        """
        if entry.has_attribute(HedKey.InLibrary) and not self._loading_merged:
            raise HedFileError(HedExceptions.IN_LIBRARY_IN_UNMERGED,
                               "Library tag in unmerged schema has InLibrary attribute",
                               self._schema.filename)
        return self._schema._add_tag_to_dict(entry.name, entry, key_class)