diff --git a/pyneuroml/io.py b/pyneuroml/io.py index 295afd26..1dd55194 100644 --- a/pyneuroml/io.py +++ b/pyneuroml/io.py @@ -18,6 +18,7 @@ import lems.model.model as lems_model import neuroml.loaders as loaders import neuroml.writers as writers +from lxml import etree from neuroml import NeuroMLDocument from pyneuroml.errors import ARGUMENT_ERR, FILE_NOT_FOUND_ERR, NMLFileTypeError @@ -27,15 +28,6 @@ logger.setLevel(logging.INFO) -# extension: standard -pynml_file_type_dict = { - "xml": "LEMS", - "nml": "NeuroML", - "sedml": "SED-ML", - "sbml": "SBML", -} - - def read_neuroml2_file( nml2_file_name: str, include_includes: bool = False, @@ -256,7 +248,7 @@ def confirm_neuroml_file(filename: str, sys_error: bool = False) -> None: ) try: - confirm_file_type(filename, ["nml"]) + confirm_file_type(filename, ["nml"], "neuroml") except NMLFileTypeError as e: if filename.startswith("LEMS_"): logger.warning(error_string) @@ -276,9 +268,6 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None: :param sys_error: toggle whether function should exit or raise exception :type sys_error: bool """ - # print('Checking file: %s'%filename) - # Some conditions to check if a LEMS file was entered - # TODO: Ideally we'd like to check the root node: checking file extensions is brittle error_string = textwrap.dedent( """ ************************************************************************************* @@ -288,7 +277,7 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None: """ ) try: - confirm_file_type(filename, ["xml"]) + confirm_file_type(filename, ["xml"], "lems") except NMLFileTypeError as e: if filename.endswith("nml"): logger.warning(error_string) @@ -302,15 +291,22 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None: def confirm_file_type( filename: str, file_exts: typing.List[str], + root_tag: typing.Optional[str] = None, error_str: typing.Optional[str] = None, sys_error: bool = False, ) -> None: - """Confirm that a file exists and has the necessary extension + """Confirm that a file exists and is of the provided type. + + First we rely on file extensions to test for type, since this is the + simplest way. If this test fails, we read the full file and test the root + tag if one has been provided. :param filename: filename to confirm :type filename: str :param file_exts: list of valid file extensions, without the leading dot :type file_exts: list of strings + :param root_tag: root tag for file, used if extensions do not match + :type root_tag: str :param error_str: an optional error string to print along with the thrown exception :type error_str: string (optional) @@ -320,11 +316,32 @@ def confirm_file_type( """ confirm_file_exists(filename) filename_ext = filename.split(".")[-1] - file_types = [f"{x} ({pynml_file_type_dict[x]})" for x in file_exts] - if filename_ext not in file_exts: - error_string = ( - f"Expected file extension(s): {', '.join(file_types)}; got {filename_ext}" - ) + + matched = False + + if filename_ext in file_exts: + matched = True + + got_root_tag = None + + if matched is False: + if root_tag is not None: + with open(filename) as i_file: + xml_tree = etree.parse(i_file) + tree_root = xml_tree.getroot() + got_root_tag = tree_root.tag + + if got_root_tag.lower() == root_tag.lower(): + matched = True + + if matched is False: + error_string = f"Expected file extension does not match: {', '.join(file_exts)}; got {filename_ext}." + + if root_tag is not None: + error_string += ( + f" Expected root tag does not match: {root_tag}; got {got_root_tag}" + ) + if error_str is not None: error_string += "\n" + error_str if sys_error is True: diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 00000000..3167eb8b --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +""" +Tests for io methods + +File: tests/test_io.py + +Copyright 2024 NeuroML contributors +""" + +import logging +import os + +from pyneuroml.errors import NMLFileTypeError +from pyneuroml.io import confirm_file_type + +from . import BaseTestCase + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class TestIo(BaseTestCase): + """Test io module""" + + def test_confirm_file_type(self): + """Test confirm_file_type method.""" + + # LEMS file with xml extension + test_lems_file = "a_test_lems_file.xml" + with open(test_lems_file, "w") as f: + print(" ", file=f) + confirm_file_type(test_lems_file, ["xml"]) + + # lems file with non xml extension but provided tag + test_lems_file2 = "a_test_lems_file.lems" + with open(test_lems_file2, "w") as f: + print(" ", file=f) + + confirm_file_type(test_lems_file2, ["xml"], "lems") + + # lems file with non xml and bad tag: should fail + with self.assertRaises(NMLFileTypeError): + test_lems_file3 = "a_bad_test_lems_file.lems" + with open(test_lems_file3, "w") as f: + print(" ", file=f) + + confirm_file_type(test_lems_file3, ["xml"], "lems") + + os.unlink(test_lems_file) + os.unlink(test_lems_file2) + os.unlink(test_lems_file3)