diff options
author | Baudouin Raoult <baudouin.raoult@ecmwf.int> | 2018-05-31 17:11:29 +0100 |
---|---|---|
committer | Baudouin Raoult <baudouin.raoult@ecmwf.int> | 2018-05-31 17:11:29 +0100 |
commit | 661c22484cf9c6f67f300605c90e97dfe0e1001f (patch) | |
tree | 6bd4c330e7ffbe38378cdc863eb7136091214c84 /cdsapi | |
parent | 0e95eb43a873f461e1e55fc490e581ec8c3d3275 (diff) |
split download
Diffstat (limited to 'cdsapi')
-rw-r--r-- | cdsapi/api.py | 173 |
1 files changed, 122 insertions, 51 deletions
diff --git a/cdsapi/api.py b/cdsapi/api.py index adc1b6d..c942612 100644 --- a/cdsapi/api.py +++ b/cdsapi/api.py @@ -36,6 +36,117 @@ def read_config(path): return config +class Result(object): + + def __init__(self, client, reply): + + self.reply = reply + + self._url = client.url + + self.session = client.session + self.robust = client.robust + self.verify = client.verify + self.cleanup = client.delete + + self.debug = client.debug + self.info = client.info + self.warning = client.warning + self.error = client.error + + self._deleted = False + + def _download(self, url, size, target): + + if target is None: + target = url.split('/')[-1] + + self.info("Downloading %s to %s (%s)", url, target, bytes_to_string(size)) + start = time.time() + + with self.robust(requests.get)(url, stream=True, verify=self.verify) as r: + r.raise_for_status() + + total = 0 + with open(target, 'wb') as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + total += len(chunk) + + assert total == size + + elapsed = time.time() - start + if elapsed: + self.info("Download rate %s/s", bytes_to_string(size / elapsed)) + + return target + + def download(self, target=None): + return self._download(self.location, + self.content_length, + target) + + @property + def content_length(self): + return int(self.reply['content_length']) + + @property + def location(self): + return self.reply['location'] + + @property + def content_type(self): + return self.reply['content_type'] + + def __repr__(self): + return "Result(content_length=%s,content_type=%s,location=%s)" % (self.content_length, + self.content_type, + self.location) + + def __enter__(self): + + r = self.robust(requests.get)(self.location, stream=True, verify=self.verify) + r.raise_for_status() + + def check(self): + self.debug("HEAD %s", self.reply['location']) + metadata = self.robust(self.session.head)(self.reply['location'], + verify=self.verify) + metadata.raise_for_status() + self.debug(metadata.headers) + return metadata + + def delete(self): + + if self._deleted: + return + + if 'request_id' in self.reply: + rid = self.reply['request_id'] + + task_url = '%s/tasks/%s' % (self._url, rid) + self.debug("DELETE %s", task_url) + + delete = self.session.delete(task_url, verify=self.verify) + self.debug("DELETE returns %s %s", delete.status_code, delete.reason) + + try: + delete.raise_for_status() + except Exception: + self.warning("DELETE %s returns %s %s", + task_url, delete.status_code, delete.reason) + + self._deleted = True + + def __del__(self): + try: + if self.cleanup: + self.delete() + except Exception as e: + print(e) + + class Client(object): logger = logging.getLogger('cdsapi') @@ -48,7 +159,7 @@ class Client(object): verify=None, timeout=None, full_stack=False, - delete=False, + delete=True, retry_max=500, sleep_max=120, info_callback=None, @@ -102,6 +213,9 @@ class Client(object): self.info_callback = info_callback self.error_callback = error_callback + self.session = requests.Session() + self.session.auth = tuple(self.key.split(':', 2)) + self.debug("CDSAPI %s", dict(url=self.url, key=self.key, quiet=self.quiet, @@ -114,34 +228,14 @@ class Client(object): )) def retrieve(self, name, request, target=None): - self._api('%s/resources/%s' % (self.url, name), request, target) + result = self._api('%s/resources/%s' % (self.url, name), request) + if target is not None: + result.download(target) + return result - def _download(self, url, size, local_filename=None): + def _api(self, url, request): - if local_filename is None: - local_filename = url.split('/')[-1] - - self.info("Downloading %s to %s (%s)", url, local_filename, bytes_to_string(size)) - start = time.time() - r = self.robust(requests.get)(url, stream=True, verify=self.verify) - total = 0 - with open(local_filename, 'wb') as f: - for chunk in r.iter_content(chunk_size=1024): - if chunk: - f.write(chunk) - total += len(chunk) - - assert total == size - - elapsed = time.time() - start - if elapsed: - self.info("Download rate %s/s", bytes_to_string(size / elapsed)) - return local_filename - - def _api(self, url, request, target): - - session = requests.Session() - session.auth = tuple(self.key.split(':', 2)) + session = self.session self.info("Sending request to %s", url) self.debug("POST %s %s", url, json.dumps(request)) @@ -187,31 +281,8 @@ class Client(object): self.last_state = reply['state'] if reply['state'] == 'completed': - - if target: - self._download(reply['location'], int(reply['content_length']), target) - else: - self.debug("HEAD %s", reply['location']) - metadata = self.robust(session.head)(reply['location'], verify=self.verify) - metadata.raise_for_status() - self.debug(metadata.headers) - - if 'request_id' in reply: - rid = reply['request_id'] - - if self.delete: - task_url = '%s/tasks/%s' % (self.url, rid) - self.debug("DELETE %s", task_url) - delete = session.delete(task_url, verify=self.verify) - self.debug("DELETE returns %s %s", delete.status_code, delete.reason) - try: - delete.raise_for_status() - except Exception: - self.warning("DELETE %s returns %s %s", - task_url, delete.status_code, delete.reason) - self.debug("Done") - return + return Result(self, reply) if reply['state'] in ('queued', 'running'): rid = reply['request_id'] |