diff --git a/CHANGES.rst b/CHANGES.rst index 83cd8633de..e2c387e131 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,11 @@ linelists.cdms - Add whole catalog retrieval, improve error messaging for unparseable lines, improve metadata catalog, and improve lookuptable behavior [#3173,#2901] +heasarc +^^^^^^^ + +- Fix Heasarc.download_data for Sciserver. [#3183] + jplspec ^^^^^^^ @@ -69,7 +74,6 @@ Infrastructure, Utility and Other Changes and Additions ------------------------------------------------------- - 0.4.8 (2025-01-16) ================== diff --git a/astroquery/heasarc/core.py b/astroquery/heasarc/core.py index a7229f4c1c..04986f6fb1 100644 --- a/astroquery/heasarc/core.py +++ b/astroquery/heasarc/core.py @@ -692,16 +692,17 @@ def _copy_sciserver(self, links, location='.'): Users should be using `~self.download_data` instead """ - if not (os.path.exists('/FTP/') and os.environ['HOME'].split('/')[-1] == 'idies'): + if not os.path.exists('/FTP/'): raise FileNotFoundError( 'No data archive found. This should be run on Sciserver ' 'with the data drive mounted.' ) - if not os.path.exists(location): - os.mkdir(location) + # make sure the output folder exits + os.makedirs(location, exist_ok=True) for link in links['sciserver']: + link = str(link) log.info(f'Copying to {link} from the data drive ...') if not os.path.exists(link): raise ValueError( @@ -711,7 +712,8 @@ def _copy_sciserver(self, links, location='.'): 'Heasarc Help desk' ) if os.path.isdir(link): - shutil.copytree(link, location) + download_dir = os.path.basename(link.strip('/')) + shutil.copytree(link, f'{location}/{download_dir}') else: shutil.copy(link, location) diff --git a/astroquery/heasarc/tests/test_heasarc.py b/astroquery/heasarc/tests/test_heasarc.py index 47cd7ddc14..1c63da8c8d 100644 --- a/astroquery/heasarc/tests/test_heasarc.py +++ b/astroquery/heasarc/tests/test_heasarc.py @@ -1,8 +1,8 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst import os -import shutil import pytest +import tempfile from unittest.mock import patch, PropertyMock from astropy.coordinates import SkyCoord from astropy.table import Table @@ -304,6 +304,24 @@ def test_download_data__missingcolumn(host): Heasarc.download_data(Table({"id": [1]}), host=host) +def test_download_data__sciserver(): + with tempfile.TemporaryDirectory() as tmpdir: + datadir = f'{tmpdir}/data' + downloaddir = f'{tmpdir}/download' + os.makedirs(datadir, exist_ok=True) + with open(f'{datadir}/file.txt', 'w') as fp: + fp.write('data') + # include both a file and a directory + tab = Table({'sciserver': [f'{tmpdir}/data/file.txt', f'{tmpdir}/data']}) + # The patch is to avoid the test that we are on sciserver + with patch('os.path.exists') as exists: + exists.return_value = True + Heasarc.download_data(tab, host="sciserver", location=downloaddir) + assert os.path.exists(f'{downloaddir}/file.txt') + assert os.path.exists(f'{downloaddir}/data') + assert os.path.exists(f'{downloaddir}/data/file.txt') + + def test_download_data__outside_sciserver(): with pytest.raises( FileNotFoundError, @@ -350,10 +368,10 @@ def test_s3_mock_basic(s3_mock): def test_s3_mock_file(s3_mock): links = Table({"aws": [f"s3://{s3_bucket}/{s3_key1}"]}) Heasarc.enable_cloud(profile=False) - Heasarc.download_data(links, host="aws", location=".") - file = s3_key1.split("/")[-1] - assert os.path.exists(file) - os.remove(file) + with tempfile.TemporaryDirectory() as tmpdir: + Heasarc.download_data(links, host="aws", location=tmpdir) + file = s3_key1.split("/")[-1] + assert os.path.exists(f'{tmpdir}/{file}') @pytest.mark.filterwarnings("ignore::DeprecationWarning") @@ -361,9 +379,9 @@ def test_s3_mock_file(s3_mock): def test_s3_mock_directory(s3_mock): links = Table({"aws": [f"s3://{s3_bucket}/{s3_dir}"]}) Heasarc.enable_cloud(profile=False) - Heasarc.download_data(links, host="aws", location=".") - assert os.path.exists("location") - assert os.path.exists("location/file1.txt") - assert os.path.exists("location/sub/file2.txt") - assert os.path.exists("location/sub/sub2/file3.txt") - shutil.rmtree("location") + with tempfile.TemporaryDirectory() as tmpdir: + Heasarc.download_data(links, host="aws", location=tmpdir) + assert os.path.exists(f"{tmpdir}/location") + assert os.path.exists(f"{tmpdir}/location/file1.txt") + assert os.path.exists(f"{tmpdir}/location/sub/file2.txt") + assert os.path.exists(f"{tmpdir}/location/sub/sub2/file3.txt")