summary refs log tree commit diff
path: root/cdsapi/api.py
blob: b840cee6060679abecefac7b9aa6422ced88129b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import requests
import json
import time
import datetime


def bytes_to_string(n):
    """Format a byte count ``n`` as a short string using binary (1024) units.

    The value is divided by 1024 until it drops below 1024 or the largest
    known unit (``Pi``) is reached, then rounded to one decimal place,
    e.g. ``1536 -> '1.5Ki'`` and ``500 -> '500'``.
    """
    units = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi']
    i = 0
    # Stop at the last unit so very large values cannot index past the table.
    while n >= 1024 and i < len(units) - 1:
        n /= 1024.0
        i += 1
    return "%g%s" % (int(n * 10 + 0.5) / 10.0, units[i])


class Client(object):
    """HTTP client that submits a request, polls its task and fetches the result.

    A request is POSTed as JSON with HTTP basic auth. The returned task is
    polled with exponential back-off while it is ``queued`` or ``running``.
    Once it is ``completed``, the result is either downloaded to a local file
    or its HEAD response is printed.
    """

    def __init__(self, end_point, user_id, api_key, verbose=False, verify=True, timeout=None, full_stack=False):
        """Store connection settings.

        :param end_point: base URL; requests go to ``<end_point>/resources/<name>``
            and tasks are polled at ``<end_point>/tasks/<request_id>``.
        :param user_id: user part of the HTTP basic-auth credentials.
        :param api_key: password part of the HTTP basic-auth credentials.
        :param verbose: stored but not consulted by any method of this class.
        :param verify: passed as ``verify=`` to every ``requests`` call.
        :param timeout: seconds to keep polling a queued/running task before
            raising ``Exception("TIMEOUT")``; ``None`` or 0 polls forever.
        :param full_stack: on task failure, print the whole server traceback
            instead of stopping at its first blank line.
        """
        self.end_point = end_point
        self.user_id = user_id
        self.api_key = api_key
        self.verbose = verbose
        self.verify = verify
        self.timeout = timeout
        # Upper bound, in seconds, for the back-off between two polls.
        self.sleep_max = 120
        self.full_stack = full_stack

    def get_resource(self, name, request, target=None):
        """Request resource ``name`` with the JSON-serialisable ``request``.

        :param target: local file name for the result; when omitted the
            result is not downloaded and its HEAD response is printed instead.
        """
        self._api("%s/resources/%s" % (self.end_point, name), request, target)

    def _download(self, url, local_filename=None, size=None):
        """Stream ``url`` into ``local_filename`` and return the file name.

        :param local_filename: destination path; defaults to the last path
            segment of ``url``.
        :param size: expected number of bytes, if known.
        :raises requests.HTTPError: if the download request fails.
        :raises Exception: if ``size`` is given and a different number of
            bytes was written.
        """
        if local_filename is None:
            local_filename = url.split('/')[-1]

        r = requests.get(url, stream=True, verify=self.verify)
        try:
            # Fail before creating/truncating the local file with an error body.
            r.raise_for_status()
            total = 0
            with open(local_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
                        total += len(chunk)
        finally:
            # With stream=True the connection is only released on close().
            r.close()

        if size is not None and total != size:
            raise Exception("Download failed: downloaded %s byte(s) out of %s" % (total, size))

        return local_filename

    @staticmethod
    def _error_message(reply):
        """Return the error text carried by a JSON error ``reply``, or ``None``.

        Terms listed under ``reply['context']['required_terms']`` (each with
        ``title`` and ``url``) are appended as instructions to accept them.
        """
        if not isinstance(reply, dict) or 'message' not in reply:
            return None
        error = reply['message']
        context = reply.get('context')
        if isinstance(context, dict) and 'required_terms' in context:
            parts = [error]
            for t in context['required_terms']:
                parts.append("To access this resource, you first need to accept the terms of '%s' at %s" % (t['title'], t['url']))
            error = '. '.join(parts)
        return error

    def _submit(self, session, url, request):
        """POST ``request`` to ``url`` and return the decoded JSON reply.

        :raises requests.HTTPError: for an HTTP error whose body is not JSON
            or carries no ``message``.
        :raises Exception: for an HTTP error whose JSON body has a ``message``
            (plus any required-terms hints), or for a successful non-JSON body.
        """
        self._trace("POST %s %s" % (url, json.dumps(request)))
        result = session.post(url, json=request, verify=self.verify)

        try:
            reply = result.json()
        except ValueError:
            # Non-JSON body: report HTTP failures as such, otherwise surface
            # the raw body text.
            result.raise_for_status()
            raise Exception(result.text)

        # Status is checked only after parsing so the server's explanation
        # can replace the generic HTTPError.
        try:
            result.raise_for_status()
        except requests.HTTPError:
            error = self._error_message(reply)
            if error is None:
                raise
            raise Exception(error)
        return reply

    def _report_failure(self, error):
        """Print a failed task's message, reason and (part of) its traceback.

        Missing ``context``/``traceback`` entries are tolerated.
        """
        print("Message: %s" % (error.get("message"),))
        print("Reason:  %s" % (error.get("reason"),))
        context = error.get('context') or {}
        for n in (context.get('traceback') or '').split('\n'):
            if n.strip() == '' and not self.full_stack:
                break
            print("  %s" % (n,))

    def _api(self, url, request, target):
        """Submit ``request`` to ``url``, wait for the task, then fetch the result.

        :param target: local file name for the result, or a falsy value to
            only print the result's HEAD response.
        :raises Exception: on submission errors, polling timeout
            (``"TIMEOUT"``), task failure (the server's reason) or an unknown
            task state.
        """
        with requests.Session() as session:
            session.auth = (str(self.user_id), str(self.api_key))
            reply = self._submit(session, url, request)

            sleep = 1
            start = time.time()

            while True:
                self._trace(reply)
                state = reply['state']

                if state == 'completed':
                    if target:
                        size = reply.get('content_length')
                        self._download(reply['location'], target,
                                       int(size) if size is not None else None)
                    else:
                        metadata = session.head(reply['location'], verify=self.verify)
                        metadata.raise_for_status()
                        print(metadata)

                    self._trace("Done")
                    return

                if state in ('queued', 'running'):
                    rid = reply['request_id']

                    if self.timeout and (time.time() - start > self.timeout):
                        raise Exception("TIMEOUT")

                    time.sleep(sleep)
                    self._trace("Request ID is %s, sleep %s" % (rid, sleep))
                    # Exponential back-off, capped at sleep_max seconds.
                    sleep = min(sleep * 1.5, self.sleep_max)

                    result = session.get("%s/tasks/%s" % (self.end_point, rid), verify=self.verify)
                    result.raise_for_status()
                    reply = result.json()
                    continue

                if state == 'failed':
                    error = reply.get('error') or {}
                    self._report_failure(error)
                    raise Exception(error.get("reason"))

                raise Exception("Unknown API state [%s]" % (state,))

    def _trace(self, what):
        """Print ``what`` prefixed with ``CDS-API`` and a local timestamp.

        Dicts and lists are rendered as indented JSON with sorted keys.
        NOTE(review): this prints unconditionally; ``self.verbose`` is never
        consulted — confirm whether tracing should be gated on it.
        """
        if isinstance(what, (dict, list)):
            what = json.dumps(what, indent=4, sort_keys=True)

        ts = "{:%Y-%m-%d %H:%M:%S}".format(datetime.datetime.now())
        print('CDS-API %s %s' % (ts, what))