-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path_base.py
229 lines (174 loc) · 7.83 KB
/
_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import functools
import inspect
import logging
import os
import re
import warnings
from importlib.metadata import PackageNotFoundError, version
from typing import Any, Optional

import httpx
from rich.logging import RichHandler
from rich.pretty import pprint
try:
    # Version of the installed `datalab-api` distribution; also sent in the User-Agent header.
    __version__ = version("datalab-api")
except PackageNotFoundError:
    # No installed distribution metadata (e.g. importing straight from a source tree):
    # keep the module importable rather than failing at import time.
    __version__ = "unknown"

__all__ = ("__version__", "BaseDatalabClient")
def pretty_displayer(method):
    """Decorate ``method`` so that callers may pass ``display=True`` to pretty-print its result.

    The ``display`` keyword is consumed by the wrapper and never forwarded to ``method``.
    When displaying a dict that contains a ``"blocks_obj"`` mapping, every block holding
    ``"bokeh_plot_data"`` is first rendered with ``bokeh_from_json``; the whole result is
    then printed with ``rich.pretty.pprint``. The result is always returned unchanged.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        should_display = kwargs.pop("display", False)
        output = method(self, *args, **kwargs)
        if not should_display:
            return output

        if isinstance(output, dict) and "blocks_obj" in output:
            plottable_blocks = (
                blk for blk in output["blocks_obj"].values() if "bokeh_plot_data" in blk
            )
            for blk in plottable_blocks:
                bokeh_from_json(blk)

        pprint(output, max_length=None, max_string=100, max_depth=3)
        return output

    return wrapper
class AutoPrettyPrint(type):
    """Metaclass that wraps methods with ``pretty_displayer`` so each accepts ``display=True``.

    Only plain functions defined in the class body whose names do not start with a double
    underscore are wrapped. ``staticmethod``/``classmethod`` objects, properties, nested
    classes and other attributes are left untouched: the wrapper always passes ``self`` as the
    first argument, which would break their calling convention (``staticmethod`` objects are
    themselves callable on Python >= 3.10, so a plain ``callable()`` test would catch them).
    """

    def __new__(cls, name, bases, dct):
        # Build a fresh namespace rather than reassigning entries of `dct` while iterating it.
        namespace = {
            attr: (
                pretty_displayer(value)
                if inspect.isfunction(value) and not attr.startswith("__")
                else value
            )
            for attr, value in dct.items()
        }
        return super().__new__(cls, name, bases, namespace)
def bokeh_from_json(block_data):
    """Load a serialised Bokeh document into the current document and show its first root.

    Parameters:
        block_data: Either a block mapping that holds the plot payload under
            ``"bokeh_plot_data"``, or that payload itself. The payload's ``"doc"`` entry is
            passed to ``curdoc().replace_with_json``.
    """
    # Imported inside the function, so bokeh is only imported when this is actually called.
    from bokeh.io import curdoc
    from bokeh.plotting import show

    has_wrapper = "bokeh_plot_data" in block_data
    payload = block_data["bokeh_plot_data"] if has_wrapper else block_data

    curdoc().replace_with_json(payload["doc"])
    first_root = curdoc().roots[0]
    show(first_root)
def _parse_version_tuple(version_string: str) -> Optional[tuple[int, ...]]:
    """Parse the leading dotted-integer part of a version string into a tuple of ints.

    An optional leading ``v`` and any trailing pre-release/dev suffix are ignored, and the
    result is zero-padded to at least three components, so ``"0.2"``, ``"0.2.0"`` and
    ``"v0.2.0rc1"`` all give ``(0, 2, 0)``. Returns ``None`` if no version number is found.
    """
    match = re.match(r"\s*v?(\d+(?:\.\d+)*)", str(version_string))
    if match is None:
        return None
    parts = tuple(int(part) for part in match.group(1).split("."))
    return parts + (0,) * (3 - len(parts))


class BaseDatalabClient(metaclass=AutoPrettyPrint):
    """A base class that implements some of the shared/logistical functionality
    (hopefully) common to all Datalab clients.

    Mainly used to keep the namespace of the 'real' client classes clean and
    readable by users.
    """

    _api_key: Optional[str] = None
    _session: Optional[httpx.Client] = None
    # Class-level default headers; every instance works on its own copy (made in __init__)
    # so that headers such as the API key never leak between clients.
    _headers: dict[str, str] = {}

    bad_server_versions: Optional[tuple[tuple[int, int, int], ...]] = ((0, 2, 0),)
    """Any known server versions that are not supported by this client."""

    min_server_version: tuple[int, int, int] = (0, 1, 0)
    """The minimum server version that this client supports."""

    def __init__(self, datalab_api_url: str, log_level: str = "WARNING"):
        """Creates an authenticated client.

        An API key is required to authenticate requests. The client will attempt to load it from a
        series of environment variables, `DATALAB_API_KEY` and prefixed versions for the given
        requested instance (e.g., `PUBLIC_DATALAB_API_KEY` for the public deployment
        which has prefix `public`).

        Parameters:
            datalab_api_url: The URL of the Datalab API. If no `http://`/`https://` scheme is
                given, `https://` is assumed. If the URL of a datalab *UI* is provided, a request
                is made to resolve the underlying API URL from the page's `x_datalab_api_url`
                meta tag (e.g., `https://public.datalab.odbx.science` will 'redirect' to
                `https://public.api.odbx.science`).
            log_level: The logging level to use for the client. Defaults to "WARNING".

        Raises:
            ValueError: If no URL is provided or no API key can be found in the environment.
        """
        self.datalab_api_url = datalab_api_url
        if not self.datalab_api_url:
            raise ValueError("No Datalab API URL provided.")
        # Check for a real scheme: a bare `startswith("http")` would also accept hostnames
        # such as "httpserver.example.org" and leave them without a scheme.
        if not self.datalab_api_url.lower().startswith(("http://", "https://")):
            self.datalab_api_url = f"https://{self.datalab_api_url}"

        logging.basicConfig(level=log_level, handlers=[RichHandler()])
        self.log = logging.getLogger(__name__)

        self._http_client = httpx.Client
        # Per-instance copy of the default headers; mutating the shared class-level dict
        # would leak the User-Agent/API key into every other client instance.
        self._headers = dict(self._headers)
        self._headers["User-Agent"] = f"Datalab Python API/{__version__}"

        self._detect_api_url()

        info_json = self.get_info()
        attributes = info_json["data"]["attributes"]
        self._datalab_api_versions: list[str] = attributes["available_api_versions"]
        self._datalab_server_version: str = attributes["server_version"]
        self._datalab_instance_prefix: Optional[str] = attributes.get("identifier_prefix")

        self._find_api_key()

    def _detect_api_url(self) -> None:
        """Perform a handshake with the chosen URL to ascertain the correct API URL.

        If a datalab UI URL is passed, the client will attempt to resolve the API URL by
        inspecting the HTML meta tags.

        Do not use the session for this, so we are not passing the API key to arbitrary URLs.
        """
        # httpx does not follow redirects by default; follow them so that e.g. an
        # http -> https hop still reaches the page carrying the meta tag.
        response = httpx.get(self.datalab_api_url, follow_redirects=True)
        match = re.search(
            r'<meta name="x_datalab_api_url" content="(.*?)">',
            response.text,
            re.IGNORECASE,
        )
        if match:
            self.datalab_api_url = match.group(1)
            warnings.warn(
                f"Found API URL {self.datalab_api_url} in HTML meta tag. Creating client with this URL instead."
            )

    def get_info(self) -> dict[str, Any]:
        """Return the server info payload; must be implemented by concrete clients.

        `__init__` reads `available_api_versions`, `server_version` and (optionally)
        `identifier_prefix` from `["data"]["attributes"]` of the returned mapping.
        """
        raise NotImplementedError

    @property
    def session(self) -> httpx.Client:
        """The HTTP client used for API requests, built lazily with the current headers.

        The client is cached and reused. It is rebuilt if it has been closed (e.g. after
        being used as a context manager) or discarded because the headers changed.
        """
        if self._session is None or getattr(self._session, "is_closed", False):
            self._session = self._http_client(headers=self.headers)
        return self._session

    @property
    def headers(self):
        """The headers (User-Agent, API key) used when building the session."""
        return self._headers

    def _version_negotiation(self):
        """Check whether this client is expected to work with this instance.

        Selects the lowest available API version whose major/minor match
        `self.min_api_version` (not defined on this base class; presumably provided by the
        concrete client) and stores it as `self._selected_api_version`. Then checks the server
        version against `bad_server_versions` and `min_server_version`. Server versions that
        cannot be parsed skip these checks with a warning.

        Raises:
            RuntimeError: If the server version is not supported or if no supported API versions are found.
        """
        candidates = []
        for available_api_version in self._datalab_api_versions:
            parsed = _parse_version_tuple(available_api_version)
            if parsed is not None:
                candidates.append((parsed, available_api_version))

        # Sort numerically: as plain strings "0.10.0" would sort before "0.9.0".
        for parsed, available_api_version in sorted(candidates):
            major, minor = parsed[0], parsed[1]
            if major == self.min_api_version[0] and minor == self.min_api_version[1]:
                self._selected_api_version = available_api_version
                break
        else:
            raise RuntimeError(f"No supported API versions found in {self._datalab_api_versions=}")

        # The server reports its version as a string; parse it before comparing it with the
        # integer tuples below (a raw string never equals a tuple, so the check never fired).
        server_version = _parse_version_tuple(self._datalab_server_version)
        if server_version is None:
            self.log.warning(
                "Could not parse server version %r; skipping server version checks.",
                self._datalab_server_version,
            )
            return

        bad_versions = {tuple(v) for v in (self.bad_server_versions or ())}
        if server_version in bad_versions:
            raise RuntimeError(
                f"Server version {self._datalab_server_version} is not supported by this client."
            )
        if server_version < tuple(self.min_server_version):
            raise RuntimeError(
                f"Server version {self._datalab_server_version} is older than the minimum "
                f"supported version {'.'.join(str(p) for p in self.min_server_version)}."
            )

    @property
    def api_key(self) -> str:
        """The API key used to authenticate requests to the Datalab API, passed
        as the `DATALAB-API-KEY` HTTP header.

        This can be retrieved by an authenticated user with the `/get-api-key`
        endpoint of a Datalab API.
        """
        if self._api_key is not None:
            return self._api_key
        return self._find_api_key()

    def _find_api_key(self) -> str:
        """Return the API key, loading it from the environment on first use, and make sure it
        is set as the `DATALAB-API-KEY` header used by the session.

        The instance-prefixed variable (e.g. `PUBLIC_DATALAB_API_KEY`) takes precedence over
        `DATALAB_API_KEY`; empty values are treated as unset, and surrounding single/double
        quotes are stripped.

        Raises:
            ValueError: If no non-empty API key is found.
        """
        if self._api_key is None:
            key_env_var = "DATALAB_API_KEY"
            api_key: Optional[str] = None
            # probe the prefixed environment variable first
            if self._datalab_instance_prefix is not None:
                api_key = os.getenv(f"{self._datalab_instance_prefix.upper()}_{key_env_var}")
            # An empty prefixed variable should not shadow the generic one.
            if not api_key:
                api_key = os.getenv(key_env_var)
            # Remove single and double quotes around API key if present
            if api_key is not None:
                api_key = api_key.strip("'").strip('"')
            if not api_key:
                raise ValueError(
                    f"No API key found in environment variables {key_env_var}/<prefix>_{key_env_var}."
                )
            self._api_key = api_key

        if self._headers.get("DATALAB-API-KEY") != self._api_key:
            self._headers["DATALAB-API-KEY"] = self._api_key
            # A cached session holds a copy of the old headers; drop it so the next access
            # to `session` rebuilds it with the API key.
            self._discard_session()
        return self._api_key

    def _discard_session(self) -> None:
        """Close (best-effort) and forget any cached HTTP session."""
        if self._session is not None:
            try:
                self._session.close()
            except Exception:
                # Closing is best-effort: a failure must not block updating the headers.
                logging.getLogger(__name__).debug(
                    "Failed to close previous HTTP session", exc_info=True
                )
            finally:
                self._session = None

    def __enter__(self) -> "BaseDatalabClient":
        return self

    def __exit__(self, *_):
        if self._session is not None:
            self._session.close()
            self._session = None