diff --git a/pyneuroml/io.py b/pyneuroml/io.py
index 295afd26..1dd55194 100644
--- a/pyneuroml/io.py
+++ b/pyneuroml/io.py
@@ -18,6 +18,7 @@
import lems.model.model as lems_model
import neuroml.loaders as loaders
import neuroml.writers as writers
+from lxml import etree
from neuroml import NeuroMLDocument
from pyneuroml.errors import ARGUMENT_ERR, FILE_NOT_FOUND_ERR, NMLFileTypeError
@@ -27,15 +28,6 @@
logger.setLevel(logging.INFO)
-# extension: standard
-pynml_file_type_dict = {
- "xml": "LEMS",
- "nml": "NeuroML",
- "sedml": "SED-ML",
- "sbml": "SBML",
-}
-
-
def read_neuroml2_file(
nml2_file_name: str,
include_includes: bool = False,
@@ -256,7 +248,7 @@ def confirm_neuroml_file(filename: str, sys_error: bool = False) -> None:
)
try:
- confirm_file_type(filename, ["nml"])
+ confirm_file_type(filename, ["nml"], "neuroml")
except NMLFileTypeError as e:
if filename.startswith("LEMS_"):
logger.warning(error_string)
@@ -276,9 +268,6 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None:
:param sys_error: toggle whether function should exit or raise exception
:type sys_error: bool
"""
- # print('Checking file: %s'%filename)
- # Some conditions to check if a LEMS file was entered
- # TODO: Ideally we'd like to check the root node: checking file extensions is brittle
error_string = textwrap.dedent(
"""
*************************************************************************************
@@ -288,7 +277,7 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None:
"""
)
try:
- confirm_file_type(filename, ["xml"])
+ confirm_file_type(filename, ["xml"], "lems")
except NMLFileTypeError as e:
if filename.endswith("nml"):
logger.warning(error_string)
@@ -302,15 +291,22 @@ def confirm_lems_file(filename: str, sys_error: bool = False) -> None:
def confirm_file_type(
filename: str,
file_exts: typing.List[str],
+ root_tag: typing.Optional[str] = None,
error_str: typing.Optional[str] = None,
sys_error: bool = False,
) -> None:
- """Confirm that a file exists and has the necessary extension
+ """Confirm that a file exists and is of the provided type.
+
+ First we rely on file extensions to test for type, since this is the
+ simplest way. If this test fails, we read the full file and test the root
+ tag if one has been provided.
:param filename: filename to confirm
:type filename: str
:param file_exts: list of valid file extensions, without the leading dot
:type file_exts: list of strings
+ :param root_tag: root tag for file, used if extensions do not match
+ :type root_tag: str
:param error_str: an optional error string to print along with the thrown
exception
:type error_str: string (optional)
@@ -320,11 +316,32 @@ def confirm_file_type(
"""
confirm_file_exists(filename)
filename_ext = filename.split(".")[-1]
- file_types = [f"{x} ({pynml_file_type_dict[x]})" for x in file_exts]
- if filename_ext not in file_exts:
- error_string = (
- f"Expected file extension(s): {', '.join(file_types)}; got {filename_ext}"
- )
+
+ matched = False
+
+ if filename_ext in file_exts:
+ matched = True
+
+ got_root_tag = None
+
+ if matched is False:
+ if root_tag is not None:
+ with open(filename) as i_file:
+ xml_tree = etree.parse(i_file)
+ tree_root = xml_tree.getroot()
+ got_root_tag = tree_root.tag
+
+ if got_root_tag.lower() == root_tag.lower():
+ matched = True
+
+ if matched is False:
+ error_string = f"Expected file extension does not match: {', '.join(file_exts)}; got {filename_ext}."
+
+ if root_tag is not None:
+ error_string += (
+ f" Expected root tag does not match: {root_tag}; got {got_root_tag}"
+ )
+
if error_str is not None:
error_string += "\n" + error_str
if sys_error is True:
diff --git a/tests/test_io.py b/tests/test_io.py
new file mode 100644
index 00000000..3167eb8b
--- /dev/null
+++ b/tests/test_io.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+"""
+Tests for io methods
+
+File: tests/test_io.py
+
+Copyright 2024 NeuroML contributors
+"""
+
+import logging
+import os
+
+from pyneuroml.errors import NMLFileTypeError
+from pyneuroml.io import confirm_file_type
+
+from . import BaseTestCase
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class TestIo(BaseTestCase):
+ """Test io module"""
+
+ def test_confirm_file_type(self):
+ """Test confirm_file_type method."""
+
+ # LEMS file with xml extension
+ test_lems_file = "a_test_lems_file.xml"
+ with open(test_lems_file, "w") as f:
+ print(" ", file=f)
+ confirm_file_type(test_lems_file, ["xml"])
+
+ # lems file with non xml extension but provided tag
+ test_lems_file2 = "a_test_lems_file.lems"
+ with open(test_lems_file2, "w") as f:
+ print(" ", file=f)
+
+ confirm_file_type(test_lems_file2, ["xml"], "lems")
+
+ # lems file with non xml and bad tag: should fail
+ with self.assertRaises(NMLFileTypeError):
+ test_lems_file3 = "a_bad_test_lems_file.lems"
+ with open(test_lems_file3, "w") as f:
+ print(" ", file=f)
+
+ confirm_file_type(test_lems_file3, ["xml"], "lems")
+
+ os.unlink(test_lems_file)
+ os.unlink(test_lems_file2)
+ os.unlink(test_lems_file3)