Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/rift/Controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def make_parser():
# sync
subprs = subparsers.add_parser('sync', help='Synchronize remote repositories')
subprs.add_argument('-o', '--output', help='Synchronization output directory')
subprs.add_argument('-m', '--max-size', type=int,
help='Max size authorized for the download of each file, in bytes')
subprs.add_argument('repositories', metavar='REPOSITORY', nargs='*',
help='repositories to synchronize (default: all)')

Expand Down Expand Up @@ -977,8 +979,10 @@ def action_sync(args, config):
# configuration parameter as default value.
if args.output:
output = args.output
max_size = args.max_size
else:
output = config.get('sync_output')
max_size = None
if output is None:
raise RiftError(
"Synchronization output directory must be defined with "
Expand Down Expand Up @@ -1019,7 +1023,8 @@ def action_sync(args, config):
repo.get('url')
)
sync['source'] = repo.get('url')
synchronizer = RepoSyncFactory.get(config, name, output, sync, arch)
synchronizer = RepoSyncFactory.get(config, name, output, sync,
max_size, arch)
if synchronizer.source in synchronized_sources:
logging.debug(
"Skipping already synchronized source %s",
Expand Down
25 changes: 13 additions & 12 deletions lib/rift/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

class RepoSyncBase:
"""Common parent to all RepoSync* classes."""
def __init__(self, config, name, output, sync, arch=None):
def __init__(self, config, name, output, sync, max_size=None, arch=None):
self.config = config
self.name = name
subdir = sync.get('subdir', '').lstrip('/')
Expand All @@ -72,6 +72,7 @@ def __init__(self, config, name, output, sync, arch=None):
)
self._logfh = None # Initialized in _log_open()
self.patterns = SyncPatterns(sync['include'], sync['exclude'])
self.max_size = max_size

@property
def base_url(self):
Expand Down Expand Up @@ -120,8 +121,8 @@ def _ensure_repo_dir(self):

class RepoSyncLftp(RepoSyncBase):
"""Synchronize remote repositories with LFTP."""
def __init__(self, config, name, output, sync, arch=None):
super().__init__(config, name, output, sync, arch)
def __init__(self, config, name, output, sync, max_size=None, arch=None):
super().__init__(config, name, output, sync, max_size, arch)
self.include_arg = ' '.join(
[f"--include={pattern}" for pattern in self.patterns.include]
)
Expand Down Expand Up @@ -163,8 +164,8 @@ class RepoSyncIndexed(RepoSyncBase):
declared in index.
"""

def __init__(self, config, name, output, sync, arch=None):
super().__init__(config, name, output, sync, arch)
def __init__(self, config, name, output, sync, max_size=None, arch=None):
super().__init__(config, name, output, sync, max_size, arch)
self.indexed_files = []

def _relpath_matches(self, relpath):
Expand Down Expand Up @@ -227,8 +228,8 @@ class RepoSyncEpel(RepoSyncIndexed):

PUB_ROOT = "/pub/epel"

def __init__(self, config, name, output, sync, arch=None):
super().__init__(config, name, output, sync, arch)
def __init__(self, config, name, output, sync, max_size=None, arch=None):
super().__init__(config, name, output, sync, max_size, arch)
self.pub_url = f"{self.base_url}{self.PUB_ROOT}"

def _process_line(self, line):
Expand Down Expand Up @@ -293,7 +294,7 @@ def _process_line(self, line):
url_file = f"{self.base_url}{abspath}"
self.log_write(f"download {url_file}")
logging.info("Downloading file %s", url_file)
download_file(url_file, output_file)
download_file(url_file, output_file, self.max_size)

def _run(self):
"""Run EPEL repository synchronization."""
Expand All @@ -303,7 +304,7 @@ def _run(self):
) as tmp_file:
filelist_url = f"{self.pub_url}/fullfiletimelist-epel"
logging.debug("Downloading EPEL files index %s", filelist_url)
download_file(filelist_url, tmp_file.name)
download_file(filelist_url, tmp_file.name, self.max_size)

# Open synchronization log file
logging.debug("Creating synchronization log file %s", self.logfile)
Expand Down Expand Up @@ -351,7 +352,7 @@ def _process_package(self, package):
url = package.remote_location()
self.log_write(f"download {url}")
logging.info("Downloading file '%s' to '%s'", url, output_directory)
download_file(url, output_file)
download_file(url, output_file, self.max_size)

def _run(self):
"""Run DNF repository synchronization."""
Expand Down Expand Up @@ -432,9 +433,9 @@ def check_valid_method(method):
)

@staticmethod
def get(config, name, output, sync, arch=None):
def get(config, name, output, sync, max_size=None, arch=None):
"""Return the concrete RepoSync* class corresponding to the method."""
RepoSyncFactory.check_valid_method(sync['method'])
return RepoSyncFactory.METHODS[sync['method']](
config, name, output, sync, arch
config, name, output, sync, max_size, arch
)
30 changes: 23 additions & 7 deletions lib/rift/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
Set of utilities used in multiple Rift modules.
"""

import logging
import os
import urllib

from datetime import datetime, timezone

from rift import RiftError
Expand All @@ -51,21 +53,35 @@ def banner(title):
"""
print(f"** {title} **")

def download_file(url, output, max_size=None):
    """
    Download file pointed by url and save it in output path.

    When max_size is not None, the remote resource is opened first to read
    its declared Content-Length. If that size is larger than max_size (in
    bytes), a warning is logged and the download is skipped. A missing or
    malformed Content-Length cannot be compared and does not block the
    download.

    HTTP and URL errors are not raised: they are logged as warnings and the
    file is skipped. The function always returns None.
    """
    # Imported explicitly: a bare `import urllib` does not guarantee that the
    # request and error submodules are loaded.
    import urllib.error
    import urllib.request

    try:
        if max_size is not None:
            with urllib.request.urlopen(url) as opened_url:
                declared_size = opened_url.info().get("Content-Length")
            try:
                size = int(declared_size)
            except (TypeError, ValueError):
                # Header absent or not a number: size is unknown.
                size = None
            if size is not None and size > max_size:
                logging.warning(
                    "'%s' has a size of '%s' bytes, larger than max size "
                    "'%d', skipping download",
                    url, declared_size, max_size
                )
                return

        urllib.request.urlretrieve(url, output)
    except urllib.error.HTTPError as error:
        # HTTPError is a subclass of URLError, it must be handled first.
        logging.warning(
            "Got HTTP error '%s' while downloading '%s', skipping it",
            str(error), url
        )
    except urllib.error.URLError as error:
        logging.warning(
            "Got URL error '%s' while downloading '%s', skipping it",
            str(error), url
        )

def last_modified(url):
"""
Expand Down
22 changes: 13 additions & 9 deletions tests/TestUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
It contains several helper methods or classes like temporary file management.
"""

from collections import OrderedDict
from collections import namedtuple
from contextlib import contextmanager
import io
import jinja2
import logging
import tempfile
import unittest
import os
import pathlib as pl
import shutil
import tarfile
import tempfile
import time
import io
from collections import OrderedDict

import shutil
import jinja2
import unittest
import yaml
from collections import namedtuple
from contextlib import contextmanager

from rift.Config import Config, Staff, Modules
from rift.Mock import Mock
Expand Down Expand Up @@ -179,6 +179,10 @@ def assert_except(self, exc_cls, exc_str, callable_obj, *args, **kwargs):
else:
self.fail("%s not raised" % exc_cls.__name__)

def assert_file_exists(self, path):
    """
    Fail the current test unless path points to an existing regular file.

    The path is resolved first, so a symbolic link to a regular file is
    accepted.
    """
    resolved = pl.Path(path).resolve()
    if resolved.is_file():
        return
    if resolved.exists():
        # Tell a directory (or other non-regular file) apart from a path
        # that is really missing, to ease debugging of the failed test.
        self.fail("'%s' exists but is not a regular file" % str(path))
    self.fail("File '%s' does not exist" % str(path))

class RiftProjectTestCase(RiftTestCase):
"""
RiftTestCase that setup a dummy project tree filled with minimal
Expand Down
45 changes: 44 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from io import StringIO
from unittest.mock import patch, Mock
import os

from rift import RiftError
from rift.utils import message, banner, last_modified
from rift.utils import message, banner, download_file, last_modified
from .TestUtils import RiftTestCase


Expand All @@ -22,6 +23,48 @@ def test_banner(self, mock_stdout):
banner("bar")
self.assertEqual(mock_stdout.getvalue(), "** bar **\n")

def test_download_file(self):
    """download_file() saves a file whose size is within max_size."""
    import pathlib
    import tempfile

    # Use a local file:// URL inside a temporary directory, so the test
    # neither needs network access to a remote host nor leaves a file
    # behind in /tmp when the assertion fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        source = pathlib.Path(tmpdir) / "source"
        source.write_bytes(b"rift" * 10)
        output = os.path.join(tmpdir, "blob")
        download_file(source.as_uri(), output, 40000)
        self.assert_file_exists(output)

@patch('urllib.request.urlopen')
def test_download_file_too_large(self, mock_urlopen):
mock_url = Mock()
mock_url.info.return_value = {
"Content-Length": "50"
}
mock_urlopen.return_value.__enter__.return_value = mock_url
with self.assertLogs(level='WARNING') as log:
download_file("https://test", "/tmp/blob", 20)

self.assertIn(
"WARNING:root:'https://test' has a size of '50' bytes, larger than max size "
"'20', skipping download",
log.output
)

def test_download_file_http_error(self):
with self.assertLogs(level='WARNING') as log:
download_file("https://localhost", "/tmp/blob")

self.assertIn(
"WARNING:root:Got URL error '<urlopen error [Errno 111] Connection refused>' while downloading 'https://localhost', skipping it",
log.output
)

def test_download_file_url_error(self):
with self.assertLogs(level='WARNING') as log:
download_file("blob:localhost", "/tmp/blob")

self.assertIn(
"WARNING:root:Got URL error '<urlopen error unknown url type: blob>' while downloading 'blob:localhost', skipping it",
log.output
)

@patch('urllib.request.urlopen')
def test_last_modified(self, mock_urlopen):
mock_response = Mock()
Expand Down
Loading