From 22a51560d327a14a663164b8f618489011f8a1a4 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Mon, 4 Nov 2024 16:27:43 -0500 Subject: [PATCH] Fix issue when loading a specific file --- src/copychat/cli.py | 68 +++++++++++++++++++++++++------------------- src/copychat/core.py | 10 +++++-- tests/test_core.py | 31 ++++++++++++++++++-- 3 files changed, 76 insertions(+), 33 deletions(-) diff --git a/src/copychat/cli.py b/src/copychat/cli.py index 6effae4..c8dbd7b 100644 --- a/src/copychat/cli.py +++ b/src/copychat/cli.py @@ -6,8 +6,6 @@ from enum import Enum from .core import ( - is_glob_pattern, - resolve_paths, scan_directory, DiffMode, get_file_content, @@ -160,36 +158,48 @@ def main( if not paths: paths = ["."] - # Handle glob patterns in command line arguments - resolved_paths = [] - for path in paths: - if is_glob_pattern(path): - # Use resolve_paths for glob patterns - resolved = resolve_paths([path], base_path=source_dir) - resolved_paths.extend(resolved) - else: - # Keep regular paths as-is - resolved_paths.append( - source_dir / path if source_dir != Path(".") else Path(path) - ) - - # Scan all resolved paths + # Handle paths all_files = {} - for target in resolved_paths: - if target.is_file(): - content = get_file_content(target, diff_mode) - if content is not None: - all_files[target] = content + for path in paths: + target = Path(path) + if target.is_absolute(): + # Use absolute paths as-is + if target.is_file(): + content = get_file_content(target, diff_mode) + if content is not None: + all_files[target] = content + else: + files = scan_directory( + target, + include=include.split(",") if include else None, + exclude_patterns=exclude, + diff_mode=diff_mode, + max_depth=depth, + ) + all_files.update(files) else: - files = scan_directory( - target, - include=include.split(",") if include else None, - exclude_patterns=exclude, - diff_mode=diff_mode, - max_depth=depth, - ) - all_files.update(files) + # For relative paths, try both relative to current dir and source dir + targets = [Path.cwd() / path] + if source_dir != Path("."): + targets.append(source_dir / path) + for target in targets: + if target.exists(): + if target.is_file(): + content = get_file_content(target, diff_mode) + if content is not None: + all_files[target] = content + break + else: + files = scan_directory( + target, + include=include.split(",") if include else None, + exclude_patterns=exclude, + diff_mode=diff_mode, + max_depth=depth, + ) + all_files.update(files) + break if not all_files: error_console.print("Found [red]0[/] matching files") raise typer.Exit(1) # Exit with code 1 to indicate no files found diff --git a/src/copychat/core.py b/src/copychat/core.py index caafb0c..ad4e4e0 100644 --- a/src/copychat/core.py +++ b/src/copychat/core.py @@ -41,9 +41,15 @@ def resolve_paths(paths: list[str], base_path: Path = Path(".")) -> list[Path]: continue resolved.append(match) except ValueError: - continue + # If path is not relative to base_path, just use it as-is + resolved.append(match) else: - resolved.append(base_path / path) + # For non-glob paths, use them as-is + path_obj = Path(path) + if path_obj.is_absolute(): + resolved.append(path_obj) + else: + resolved.append(base_path / path) return resolved diff --git a/tests/test_core.py b/tests/test_core.py index d650442..d983d2a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -38,14 +38,14 @@ def test_resolve_paths(tmp_path): # Test glob resolution paths = resolve_paths(["*.py", "src/**/*.py"], base_path=tmp_path) - assert len(paths) == 4 + assert len(paths) == 3 assert tmp_path / "test1.py" in paths assert tmp_path / "test2.py" in paths assert tmp_path / "src" / "main.py" in paths # Test mixed glob and regular paths paths = resolve_paths(["src", "*.py"], base_path=tmp_path) - assert len(paths) == 4 # main.py will be found twice + assert len(paths) == 3 assert tmp_path / "src" in paths @@ -114,3 +114,30 @@ def test_scan_with_recursive_glob(tmp_path): ) # Changed from tmp_path / "very" / "**/*.py" assert len(subdir_files) == 1 assert any("test2.py" in str(p) for p in subdir_files) + + +def test_scan_single_file(tmp_path): + """Test scanning a single file.""" + # Create a test file + test_file = tmp_path / "test.py" + test_file.write_text("print('hello world')") + + # Create some other files that shouldn't be included + (tmp_path / "other.py").write_text("print('other')") + (tmp_path / "test.js").write_text("console.log('test')") + + # Test scanning just the single file + files = scan_directory(test_file, include=["py"]) + + # Should only contain our specific file + assert len(files) == 1 + assert test_file in files + assert files[test_file] == "print('hello world')" + + # Test with non-matching extension filter + files = scan_directory(test_file, include=["js"]) + assert len(files) == 0 + + # Test with non-existent file + files = scan_directory(tmp_path / "nonexistent.py", include=["py"]) + assert len(files) == 0