diff --git a/Cargo.lock b/Cargo.lock index 3ae9744..7656517 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -16,7 +16,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "gtdb_tree" -version = "0.1.8" +version = "0.1.9" dependencies = [ "memchr", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index 3121644..0fd3919 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gtdb_tree" -version = "0.1.8" +version = "0.1.9" edition = "2021" description = "A library for parsing Newick format files, especially GTDB tree files." homepage = "https://github.com/eric9n/gtdb_tree" diff --git a/README.md b/README.md index b4f64af..5e329ac 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Add this crate to your `Cargo.toml`: ```toml [dependencies] -gtdb_tree = "0.1.0" +gtdb_tree = "0.1.9" ``` ## Usage @@ -48,3 +48,27 @@ result = gtdb_tree.parse_tree("((A:0.1,B:0.2):0.3,C:0.4);") print(result) ``` +## Advanced Usage +### Custom Node Parser +You can provide a custom parser function to handle special node formats: + +```python +import gtdb_tree + +def custom_parser(node_str): + # Custom parsing logic + name, length = node_str.split(':') + return name, 100.0, float(length) # name, bootstrap, length + +result = gtdb_tree.parse_tree("((A:0.1,B:0.2):0.3,C:0.4);", custom_parser=custom_parser) +print(result) +``` + +## Working with Node Objects +## Each Node object in the result has the following attributes: + +* id: Unique identifier for the node +* name: Name of the node +* bootstrap: Bootstrap value (if available) +* length: Branch length +* parent: ID of the parent node diff --git a/pyproject.toml b/pyproject.toml index 2184bd4..397db3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ features = ["python"] [project] name = "gtdb_tree" -version = "0.1.8" +version = "0.1.9" description = "A Python package for parsing GTDB trees using Rust" readme = "README.md" authors = [{ name = "dagou", email = "eric9n@gmail.com" }] diff --git a/setup.py b/setup.py index 1eedf6f..697dd65 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="gtdb_tree", - version="0.1.8", + version="0.1.9", rust_extensions=[RustExtension("gtdb_tree.gtdb_tree", binding=Binding.PyO3)], packages=["gtdb_tree"], # rust extensions are not zip safe, just like C-extensions. diff --git a/src/node.rs b/src/node.rs index 0b34d35..8e3599c 100644 --- a/src/node.rs +++ b/src/node.rs @@ -22,6 +22,7 @@ pub enum ParseError { UnexpectedEndOfInput, #[allow(dead_code)] InvalidFormat(String), + PythonError(String), } impl std::fmt::Display for ParseError { @@ -29,6 +30,7 @@ impl std::fmt::Display for ParseError { match self { ParseError::UnexpectedEndOfInput => write!(f, "Unexpected end of input"), ParseError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg), + ParseError::PythonError(msg) => write!(f, "Python error: {}", msg), } } } diff --git a/src/python.rs b/src/python.rs index cf0f368..d4b82d2 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1,5 +1,15 @@ use crate::node::Node as RustNode; -use crate::tree; +use crate::node::ParseError; +use crate::tree::{self, NodeParser}; +use std::convert::From; +use std::sync::Arc; + +// 添加一个从 PyErr 到 ParseError 的转换实现 +impl From for ParseError { + fn from(err: PyErr) -> Self { + ParseError::PythonError(err.to_string()) + } +} #[cfg(feature = "python")] use pyo3::prelude::*; @@ -39,10 +49,94 @@ impl Node { } } +// #[cfg(feature = "python")] +// #[pyfunction] +// pub fn parse_tree(newick_str: &str) -> PyResult> { +// tree::parse_tree(newick_str) +// .map(|rust_nodes| rust_nodes.into_iter().map(|rn| Node { node: rn }).collect()) +// .map_err(|e| PyErr::new::(e.to_string())) +// } + #[cfg(feature = "python")] #[pyfunction] -pub fn parse_tree(newick_str: &str) -> PyResult> { - tree::parse_tree(newick_str) +#[pyo3(signature = (newick_str, custom_parser = None))] +#[pyo3(text_signature = "(newick_str, custom_parser=None)")] +/// Parse a Newick format string into a list of Node objects. +/// +/// This function takes a Newick format string and optionally a custom parser function, +/// and returns a list of Node objects representing the phylogenetic tree. +/// +/// Parameters: +/// ----------- +/// newick_str : str +/// The Newick format string representing the phylogenetic tree. +/// custom_parser : callable, optional +/// A custom parsing function for node information. If not provided, the default parser will be used. +/// The custom parser should have the following signature: +/// +/// def custom_parser(node_str: str) -> Tuple[str, float, float]: +/// ''' +/// Parse a node string and return name, bootstrap, and length. +/// +/// Parameters: +/// ----------- +/// node_str : str +/// The node string to parse. +/// +/// Returns: +/// -------- +/// Tuple[str, float, float] +/// A tuple containing (name, bootstrap, length) for the node. +/// ''' +/// # Your custom parsing logic here +/// return name, bootstrap, length +/// +/// Returns: +/// -------- +/// List[Node] +/// A list of Node objects representing the parsed phylogenetic tree. +/// +/// Raises: +/// ------- +/// ValueError +/// If the Newick string is invalid or parsing fails. +/// +/// Example: +/// -------- +/// >>> newick_str = "(A:0.1,B:0.2,(C:0.3,D:0.4)70:0.5);" +/// >>> nodes = parse_tree(newick_str) +/// >>> +/// >>> # Using a custom parser +/// >>> def my_parser(node_str): +/// ... parts = node_str.split(':') +/// ... name = parts[0] +/// ... length = float(parts[1]) if len(parts) > 1 else 0.0 +/// ... return name, 100.0, length # Always set bootstrap to 100.0 +/// >>> +/// >>> nodes_custom = parse_tree(newick_str, custom_parser=my_parser) +pub fn parse_tree( + _py: Python, + newick_str: &str, + custom_parser: Option, +) -> PyResult> { + let parser = match custom_parser { + Some(py_func) => { + let py_func = Arc::new(py_func); + NodeParser::Custom(Box::new( + move |node_str: &str| -> Result<(String, f64, f64), ParseError> { + Python::with_gil(|py| { + let result = py_func.call1(py, (node_str,))?; + let (name, bootstrap, length): (String, f64, f64) = result.extract(py)?; + Ok((name, bootstrap, length)) + }) + .map_err(|e: PyErr| ParseError::PythonError(e.to_string())) + }, + )) + } + None => NodeParser::Default, + }; + + tree::parse_tree(newick_str, parser) .map(|rust_nodes| rust_nodes.into_iter().map(|rn| Node { node: rn }).collect()) .map_err(|e| PyErr::new::(e.to_string())) } diff --git a/src/tree.rs b/src/tree.rs index 2991c18..f4c0228 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,6 +1,26 @@ use crate::node::{Node, ParseError}; use memchr::memchr2; +// 修改 NodeParser 枚举以使用 trait 对象 +pub enum NodeParser { + Default, + Custom(Box Result<(String, f64, f64), ParseError> + Send>), +} + +impl Default for NodeParser { + fn default() -> Self { + NodeParser::Default + } +} + +/// Parse the label of a node from a Newick tree string. +/// +/// This function takes a byte slice representing a node in a Newick tree string, +/// and returns the name and length of the node as a tuple. +/// +/// # Arguments +/// +/// * `label` - A string slice representing the node in a Newick tree string. fn parse_label(label: &str) -> Result<(String, f64), ParseError> { let label = label.trim_end_matches(";").trim_matches('\'').to_string(); @@ -31,27 +51,41 @@ fn parse_label(label: &str) -> Result<(String, f64), ParseError> { /// /// # Arguments /// -/// * `node_bytes` - A byte slice representing the node in a Newick tree string. +/// * `node_str` - A string slice representing the node in a Newick tree string. /// /// # Returns /// -/// Returns a `Result` containing a tuple of the name and length on success, +/// Returns a `Result` containing a tuple of the name, bootstrap, and length on success, /// or an `Err(ParseError)` on failure. /// /// # Example /// /// ``` -/// use gtdb_tree::tree::parse_node; +/// use gtdb_tree::tree::parse_node_default; /// -/// let node_bytes = b"A:0.1"; -/// let (name, bootstrap, length) = parse_node(node_bytes).unwrap(); +/// let node_str = "A:0.1"; +/// let (name, bootstrap, length) = parse_node_default(node_str).unwrap(); /// assert_eq!(name, "A"); /// assert_eq!(bootstrap, 0.0); /// assert_eq!(length, 0.1); /// ``` -pub fn parse_node(node_bytes: &[u8]) -> Result<(String, f64, f64), ParseError> { - let node_str = std::str::from_utf8(node_bytes).expect("UTF-8 sequence"); - // gtdb +pub fn parse_node_default(node_str: &str) -> Result<(String, f64, f64), ParseError> { + // 处理 "AD:0.03347[21.0]" 格式 + if let Some((name_length, bootstrap_str)) = node_str.rsplit_once('[') { + if let Some((name, length_str)) = name_length.rsplit_once(':') { + let bootstrap = bootstrap_str + .trim_end_matches(']') + .parse::() + .map_err(|_| { + ParseError::InvalidFormat(format!("Invalid bootstrap value: {}", bootstrap_str)) + })?; + let length = length_str.parse::().map_err(|_| { + ParseError::InvalidFormat(format!("Invalid length value: {}", length_str)) + })?; + return Ok((name.to_string(), bootstrap, length)); + } + } + // Check if node_str contains single quotes and ensure they are together if node_str.matches('\'').count() % 2 != 0 { return Err(ParseError::InvalidFormat(format!( @@ -102,12 +136,13 @@ pub fn parse_node(node_bytes: &[u8]) -> Result<(String, f64, f64), ParseError> { /// /// ``` /// use gtdb_tree::tree::parse_tree; +/// use gtdb_tree::tree::NodeParser; /// /// let newick_str = "((A:0.1,B:0.2):0.3,C:0.4);"; -/// let nodes = parse_tree(newick_str).unwrap(); +/// let nodes = parse_tree(newick_str, NodeParser::default()).unwrap(); /// assert_eq!(nodes.len(), 5); /// ``` -pub fn parse_tree(newick_str: &str) -> Result, ParseError> { +pub fn parse_tree(newick_str: &str, parser: NodeParser) -> Result, ParseError> { let mut nodes: Vec = Vec::new(); let mut pos = 0; @@ -132,7 +167,16 @@ pub fn parse_tree(newick_str: &str) -> Result, ParseError> { let end_pos = memchr2(b',', b')', &bytes[pos..]).unwrap_or(bytes_len - pos); let node_end_pos = pos + end_pos; let node_bytes = &bytes[pos..node_end_pos]; - let (name, bootstrap, length) = parse_node(node_bytes)?; + + let mut node_str = std::str::from_utf8(node_bytes).expect("UTF-8 sequence"); + if node_end_pos == bytes_len { + node_str = node_str.trim_end_matches(';'); + } + let (name, bootstrap, length) = match &parser { + NodeParser::Default => parse_node_default(node_str)?, + NodeParser::Custom(func) => func(node_str)?, + }; + let node_id = if &bytes[pos - 1] == &b')' { stack.pop().unwrap_or(0) } else { @@ -161,8 +205,9 @@ mod tests { use super::*; #[test] - fn test_parse_tree() { + fn test_parse_tree() -> Result<(), ParseError> { let test_cases = vec![ + "(A:0.1,B:0.2,(C:0.3,D:0.4)AD:0.03347[21.0]);", "((A:0.1,B:0.2)'56:F;H;':0.3,C:0.4);", "(,,(,));", // no nodes are named "(A,B,(C,D));", // leaf nodes are named @@ -175,15 +220,15 @@ mod tests { ]; for newick_str in test_cases { - match parse_tree(newick_str) { - Ok(nodes) => println!( - "Parsed nodes for '{}': {:?}, len: {}", - newick_str, - nodes, - nodes.len() - ), - Err(e) => println!("Error parsing '{}': {:?}", newick_str, e), - } + let nodes = parse_tree(newick_str, NodeParser::default())?; + println!( + "Parsed nodes for '{}': {:?}, len: {}", + newick_str, + nodes, + nodes.len() + ) } + + Ok(()) } }