Skip to content

Commit

Permalink
Make the functions work with cloudpathlib also
Browse files Browse the repository at this point in the history
  • Loading branch information
remi-braun committed Dec 13, 2024
1 parent 8b95aac commit 3919037
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 49 deletions.
13 changes: 9 additions & 4 deletions CI/SCRIPTS/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,18 @@ class Polarization(ListEnum):

def get_s3_ci_path():
"""Get S3 CI path"""
# unistra.define_s3_client()

from sertit.unistra import UNISTRA_S3_ENDPOINT

return AnyPath(
"s3://sertit-sertit-utils-ci", endpoint_url=f"https://{UNISTRA_S3_ENDPOINT}"
)
try:
ci_path = AnyPath(
"s3://sertit-sertit-utils-ci", endpoint_url=f"https://{UNISTRA_S3_ENDPOINT}"
)
except TypeError:
unistra.define_s3_client()
ci_path = AnyPath("s3://sertit-sertit-utils-ci")

return ci_path


def get_proj_path():
Expand Down
12 changes: 10 additions & 2 deletions CI/SCRIPTS/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
from typing import Union

import numpy as np
from cloudpathlib import CloudPath
from upath import UPath

from sertit import AnyPath
from sertit.types import AnyPathType, is_iterable, make_iterable

try:
from upath import UPath
except ImportError:
UPath = None

try:
from cloudpathlib import CloudPath
except ImportError:
CloudPath = None


def test_types():
"""Test some type aliases"""
Expand Down
5 changes: 4 additions & 1 deletion CI/SCRIPTS/test_unistra.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def test_unistra_s3():
assert with_s3() == 1

# Test get_geodatastore with s3
assert str(get_geodatastore()) == "s3://sertit-geodatastore/"
try:
assert str(get_geodatastore()) == "s3://sertit-geodatastore/"
except AssertionError:
assert str(get_geodatastore()) == "s3://sertit-geodatastore"

# Test get_geodatastore without s3
with tempenv.TemporaryEnvironment({s3.USE_S3_STORAGE: "0"}):
Expand Down
4 changes: 2 additions & 2 deletions sertit/archives.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ def archive(
tmp_dir.cleanup()

try:
arch = AnyPath(archive_fn, folder_path.storage_options)
except Exception:
arch = AnyPath(archive_fn, storage_options=folder_path.storage_options)
except AttributeError:
arch = AnyPath(archive_fn)

return arch
Expand Down
7 changes: 5 additions & 2 deletions sertit/ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,14 @@ def assert_dir_equal(path_1: AnyPathStrType, path_2: AnyPathStrType) -> None:
assert path_1.is_dir(), f"{path_1} is not a directory!"
assert path_2.is_dir(), f"{path_2} is not a directory!"

with tempfile.TemporaryDirectory() as tmpdir:
with (
tempfile.TemporaryDirectory() as tmpdir,
tempfile.TemporaryDirectory() as tmpdir2,
):
if path.is_cloud_path(path_1):
path_1 = s3.download(path_1, tmpdir)
if path.is_cloud_path(path_2):
path_2 = s3.download(path_2, tmpdir)
path_2 = s3.download(path_2, tmpdir2)

dcmp = filecmp.dircmp(path_1, path_2)
try:
Expand Down
22 changes: 16 additions & 6 deletions sertit/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
"""Tools for paths"""

import contextlib
import errno
import logging
import os
Expand Down Expand Up @@ -420,7 +421,7 @@ def is_cloud_path(path: AnyPathStrType):
"gs",
"gcs",
]
except ImportError:
except AttributeError:
try:
from cloudpathlib import CloudPath

Expand All @@ -431,17 +432,26 @@ def is_cloud_path(path: AnyPathStrType):

def is_path(path: Any) -> bool:
"""
Determine whether the path corresponds to a file stored on the cloud or not.
Determine whether the path is really a path or not: either str, Path, UPath or CloudPath
Args:
path (AnyPathStrType): File path
Returns:
bool: True if the file is store on the cloud.
bool: True if the file is a path
"""
from pathlib import Path

from cloudpathlib import CloudPath
from upath import UPath
is_path = isinstance(path, (str, Path))

with contextlib.suppress(ImportError):
from upath import UPath

is_path = is_path or isinstance(path, UPath)

with contextlib.suppress(ImportError):
from cloudpathlib import CloudPath

is_path = is_path or isinstance(path, CloudPath)

return isinstance(path, (str, Path, CloudPath, UPath))
return is_path
58 changes: 32 additions & 26 deletions sertit/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
S3 tools
"""

import contextlib
import logging
import os
from contextlib import contextmanager
Expand Down Expand Up @@ -279,34 +280,39 @@ def download(src, dst):
# By default, use the src path
downloaded_path = src

if path.is_path(src):
from cloudpathlib import CloudPath
from upath import UPath

# Universal pathlib
if isinstance(src, UPath):
import shutil

dst = AnyPath(dst)
if dst.is_dir() and src.name != dst.name:
downloaded_path = dst / src.name
else:
downloaded_path = dst

if src.is_file():
with src.open("rb") as f0, downloaded_path.open("wb") as f1:
shutil.copyfileobj(f0, f1)
else:
for f in src.glob("**"):
dst_file = downloaded_path / f.name
if f.is_file():
dst_file.parent.mkdir(parents=True, exist_ok=True)
with f.open("rb") as f0, dst_file.open("wb") as f1:
shutil.copyfileobj(f0, f1)
# Universal pathlib
if path.is_cloud_path(src):
import shutil

with contextlib.suppress(ImportError):
from upath import UPath

if isinstance(src, UPath):
dst = AnyPath(dst)
if dst.is_dir() and src.name != dst.name:
downloaded_path = dst / src.name
else:
downloaded_path = dst

if src.is_file():
with src.open("rb") as f0, downloaded_path.open("wb") as f1:
shutil.copyfileobj(f0, f1)
else:
downloaded_path.parent.mkdir(parents=True, exist_ok=True)

for f in src.glob("**"):
dst_file = downloaded_path / f.name
if f.is_file():
dst_file.parent.mkdir(parents=True, exist_ok=True)
with f.open("rb") as f0, dst_file.open("wb") as f1:
shutil.copyfileobj(f0, f1)

# cloudpathlib
elif isinstance(src, CloudPath):
downloaded_path = src.fspath if dst is None else src.download_to(dst)
with contextlib.suppress(ImportError):
from cloudpathlib import CloudPath

if isinstance(src, CloudPath):
downloaded_path = src.fspath if dst is None else src.download_to(dst)

return downloaded_path

Expand Down
12 changes: 10 additions & 2 deletions sertit/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@
import geopandas as gpd
import numpy as np
import xarray as xr
from cloudpathlib import CloudPath
from rasterio.io import DatasetReader, DatasetWriter
from shapely import MultiPolygon, Polygon
from upath import UPath

try:
from upath import UPath
except ImportError:
UPath = None

try:
from cloudpathlib import CloudPath
except ImportError:
CloudPath = None

AnyPathType = Union[CloudPath, Path, UPath]
"""Any Path Type (derived from Pathlib, Universal Pathlib and CloudpathLib)"""
Expand Down
6 changes: 4 additions & 2 deletions sertit/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,10 @@ def read(
split_vect = str(vector_path).split("!")
archive_regex = ".*{}".format(split_vect[1].replace(".", r"\."))
try:
vector_path = AnyPath(split_vect[0], **vector_path.storage_options)
except Exception:
vector_path = AnyPath(
split_vect[0], storage_options=vector_path.storage_options
)
except AttributeError:
# Cloudpathlib
vector_path = AnyPath(split_vect[0])

Expand Down
4 changes: 2 additions & 2 deletions sertit/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def read_archive(
archive_base_path = archive_base_path[5:]

# For UPath
with contextlib.suppress(Exception):
with contextlib.suppress(AttributeError):
archive_base_path = AnyPath(
archive_base_path, **archive_path.storage_options
archive_base_path, storage_options=archive_path.storage_options
)
else:
archive_base_path = archive_path
Expand Down

0 comments on commit 3919037

Please sign in to comment.