diff --git a/src/intel/genxml/gen_sort_tags.py b/src/intel/genxml/gen_sort_tags.py index 2e74c3225bc..dcdba333fe6 100755 --- a/src/intel/genxml/gen_sort_tags.py +++ b/src/intel/genxml/gen_sort_tags.py @@ -24,18 +24,14 @@ def main() -> None: if not args.quiet: print('Processing {}... '.format(filename), end='', flush=True) - xml = et.parse(filename) - original = copy.deepcopy(xml) if args.validate else xml - intel_genxml.sort_xml(xml) + genxml = intel_genxml.GenXml(filename) if args.validate: - for old, new in zip(original.getroot(), xml.getroot()): - assert intel_genxml.node_validator(old, new), f'{filename} is invalid, run gen_sort_tags.py and commit that' + assert genxml.is_equivalent_xml(genxml.sorted_copy()), \ + f'{filename} is invalid, run gen_sort_tags.py and commit that' else: - tmp = filename.with_suffix(f'{filename.suffix}.tmp') - et.indent(xml, space=' ') - xml.write(tmp, encoding="utf-8", xml_declaration=True) - tmp.replace(filename) + genxml.sort() + genxml.write_file() if not args.quiet: print('done.') diff --git a/src/intel/genxml/intel_genxml.py b/src/intel/genxml/intel_genxml.py index cd07828beef..492bf69d631 100755 --- a/src/intel/genxml/intel_genxml.py +++ b/src/intel/genxml/intel_genxml.py @@ -168,7 +168,8 @@ def sort_xml(xml: et.ElementTree) -> None: class GenXml(object): def __init__(self, filename): - self.et = et.parse(filename) + self.filename = pathlib.Path(filename) + self.et = et.parse(self.filename) def filter_engines(self, engines): changed = False @@ -187,3 +188,21 @@ class GenXml(object): items.append(item) if changed: self.et.getroot()[:] = items + + def sort(self): + sort_xml(self.et) + + def sorted_copy(self): + clone = copy.deepcopy(self) + clone.sort() + return clone + + def is_equivalent_xml(self, other): + return all(node_validator(old, new) + for old, new in zip(self.et.getroot(), other.et.getroot())) + + def write_file(self): + tmp = self.filename.with_suffix(f'{self.filename.suffix}.tmp') + et.indent(self.et, space=' ') + self.et.write(tmp, encoding="utf-8", xml_declaration=True) + tmp.replace(self.filename)