Skip to content

Commit

Permalink
Merge pull request #10 from jlowin/fix-file
Browse files Browse the repository at this point in the history
Fix issue when loading a specific file
  • Loading branch information
jlowin authored Nov 4, 2024
2 parents f4231ea + 22a5156 commit 907ba8e
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 33 deletions.
68 changes: 39 additions & 29 deletions src/copychat/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from importlib.metadata import version as get_version

from .core import (
is_glob_pattern,
resolve_paths,
scan_directory,
DiffMode,
get_file_content,
Expand Down Expand Up @@ -171,36 +169,48 @@ def main(
if not paths:
paths = ["."]

# Handle glob patterns in command line arguments
resolved_paths = []
for path in paths:
if is_glob_pattern(path):
# Use resolve_paths for glob patterns
resolved = resolve_paths([path], base_path=source_dir)
resolved_paths.extend(resolved)
else:
# Keep regular paths as-is
resolved_paths.append(
source_dir / path if source_dir != Path(".") else Path(path)
)

# Scan all resolved paths
# Handle paths
all_files = {}
for target in resolved_paths:
if target.is_file():
content = get_file_content(target, diff_mode)
if content is not None:
all_files[target] = content
for path in paths:
target = Path(path)
if target.is_absolute():
# Use absolute paths as-is
if target.is_file():
content = get_file_content(target, diff_mode)
if content is not None:
all_files[target] = content
else:
files = scan_directory(
target,
include=include.split(",") if include else None,
exclude_patterns=exclude,
diff_mode=diff_mode,
max_depth=depth,
)
all_files.update(files)
else:
files = scan_directory(
target,
include=include.split(",") if include else None,
exclude_patterns=exclude,
diff_mode=diff_mode,
max_depth=depth,
)
all_files.update(files)
# For relative paths, try both relative to current dir and source dir
targets = [Path.cwd() / path]
if source_dir != Path("."):
targets.append(source_dir / path)

for target in targets:
if target.exists():
if target.is_file():
content = get_file_content(target, diff_mode)
if content is not None:
all_files[target] = content
break
else:
files = scan_directory(
target,
include=include.split(",") if include else None,
exclude_patterns=exclude,
diff_mode=diff_mode,
max_depth=depth,
)
all_files.update(files)
break
if not all_files:
error_console.print("Found [red]0[/] matching files")
raise typer.Exit(1) # Exit with code 1 to indicate no files found
Expand Down
10 changes: 8 additions & 2 deletions src/copychat/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ def resolve_paths(paths: list[str], base_path: Path = Path(".")) -> list[Path]:
continue
resolved.append(match)
except ValueError:
continue
# If path is not relative to base_path, just use it as-is
resolved.append(match)
else:
resolved.append(base_path / path)
# For non-glob paths, use them as-is
path_obj = Path(path)
if path_obj.is_absolute():
resolved.append(path_obj)
else:
resolved.append(base_path / path)
return resolved


Expand Down
31 changes: 29 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ def test_resolve_paths(tmp_path):

# Test glob resolution
paths = resolve_paths(["*.py", "src/**/*.py"], base_path=tmp_path)
assert len(paths) == 4
assert len(paths) == 3
assert tmp_path / "test1.py" in paths
assert tmp_path / "test2.py" in paths
assert tmp_path / "src" / "main.py" in paths

# Test mixed glob and regular paths
paths = resolve_paths(["src", "*.py"], base_path=tmp_path)
assert len(paths) == 4 # main.py will be found twice
assert len(paths) == 3
assert tmp_path / "src" in paths


Expand Down Expand Up @@ -114,3 +114,30 @@ def test_scan_with_recursive_glob(tmp_path):
) # Changed from tmp_path / "very" / "**/*.py"
assert len(subdir_files) == 1
assert any("test2.py" in str(p) for p in subdir_files)


def test_scan_single_file(tmp_path):
"""Test scanning a single file."""
# Create a test file
test_file = tmp_path / "test.py"
test_file.write_text("print('hello world')")

# Create some other files that shouldn't be included
(tmp_path / "other.py").write_text("print('other')")
(tmp_path / "test.js").write_text("console.log('test')")

# Test scanning just the single file
files = scan_directory(test_file, include=["py"])

# Should only contain our specific file
assert len(files) == 1
assert test_file in files
assert files[test_file] == "print('hello world')"

# Test with non-matching extension filter
files = scan_directory(test_file, include=["js"])
assert len(files) == 0

# Test with non-existent file
files = scan_directory(tmp_path / "nonexistent.py", include=["py"])
assert len(files) == 0

0 comments on commit 907ba8e

Please sign in to comment.