Source code for picterra.base_client

from __future__ import annotations

import json
import logging
import os
import sys
import time
from collections.abc import Callable, Iterator

if sys.version_info >= (3, 8):
    from typing import Literal, TypedDict
else:
    from typing_extensions import Literal, TypedDict

from typing import Any, Generic, TypeVar
from urllib.parse import urlencode, urljoin

import requests
from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from urllib3.util.retry import Retry

logger = logging.getLogger()

CHUNK_SIZE_BYTES = 8192  # 8 KiB


# allow injecting an non-existing package name to test the fallback behavior
# of _get_ua in tests (see test_headers_user_agent_version__fallback)
def _get_distr_name():
    return "picterra"


def _get_ua():
    import platform

    pkg = _get_distr_name()
    if sys.version_info >= (3, 8):
        from importlib.metadata import PackageNotFoundError, version

        try:
            ver = version(pkg)
        except PackageNotFoundError:
            ver = "no_version"
    else:
        import pkg_resources  # type: ignore[import]

        try:
            ver = pkg_resources.require(pkg)[0].version
        except pkg_resources.DistributionNotFound:
            ver = "no_version"
    o_s = " ".join([os.name, platform.system(), platform.release()])
    v_info = sys.version_info
    py = "Python " + str(v_info.major) + "." + str(v_info.minor)
    return "picterra-python/%s (%s %s)" % (
        ver,
        py,
        o_s,
    )


[docs] class APIError(Exception): """Generic API error exception""" pass
class _RequestsSession(requests.Session): """ Override requests session to to implement a global session timeout """ def __init__(self, *args, **kwargs): self.timeout = kwargs.pop("timeout") super().__init__(*args, **kwargs) self.headers.update( {"User-Agent": "%s - %s" % (_get_ua(), self.headers["User-Agent"])} ) def request(self, *args, **kwargs): kwargs.setdefault("timeout", self.timeout) return super().request(*args, **kwargs) def _download_to_file(url: str, filename: str): # Given we do not use self.sess the timeout is disabled (requests default), and this # is good as file download can take a long time with requests.get(url, stream=True) as r: r.raise_for_status() with open(filename, "wb+") as f: logger.debug("Downloading to file %s.." % filename) for chunk in r.iter_content(chunk_size=CHUNK_SIZE_BYTES): if chunk: # filter out keep-alive new chunks f.write(chunk) def _upload_file_to_blobstore(upload_url: str, filename: str): if not (os.path.exists(filename) and os.path.isfile(filename)): raise ValueError("Invalid file: " + filename) with open( filename, "rb" ) as f: # binary recommended by requests stream upload (see link below) logger.debug("Opening and streaming to upload file %s" % filename) # Given we do not use self.sess the timeout is disabled (requests default), and this # is good as file upload can take a long time. Also we use requests streaming upload # (https://requests.readthedocs.io/en/latest/user/advanced/#streaming-uploads) to avoid # reading the (potentially large) layer GeoJSON in memory resp = requests.put(upload_url, data=f) if not resp.ok: logger.error("Error when uploading to blobstore %s" % upload_url) raise APIError(resp.text) def multipolygon_to_polygon_feature_collection(mp): return { "type": "FeatureCollection", "features": [ { "type": "Feature", "properties": {}, "geometry": {"type": "Polygon", "coordinates": p}, } for p in mp["coordinates"] ], } def _check_resp_is_ok(resp: requests.Response, msg: str) -> None: if not resp.ok: raise APIError( "%s (url %s, status %d): %s" % (msg, resp.url, resp.status_code, resp.text) ) T = TypeVar("T")
[docs] class ResultsPage(Generic[T]): """ Interface for a paginated response from the API Typically the endpoint returning list of objects return them splitted in pages (page 1, page 2, etc..) of a fixed dimension (eg 20). Thus each `list_XX` function returns a ResultsPage (by default the first one); once you have a ResultsPage for a given list of objects, you can: * check its length with ``len()``; - example: ``len(page)`` * access a single element with the index operator ``[]``; - example: ``page[5]`` * turn it into a list of dictionaries with ``list()``; - example: ``list(page)`` * get the next page with ``.next()``; this could return None if the list is finished; - example: ``page.next()`` You can also get a specific page passing the page number to the ``list_XX`` function """ _fetch: Callable[[str], requests.Response] _next_url: str | None _prev_url: str | None _results: list[T] _url: str def __init__(self, url: str, fetch: Callable[[str], requests.Response]): resp = fetch(url) _check_resp_is_ok(resp, "Failed to get page") r: dict[str, Any] = resp.json() next_url: str | None = r["next"] prev_url: str | None = r["previous"] results: list[T] = r["results"] self._fetch = fetch self._next_url = next_url self._prev_url = prev_url self._results = results self._url = url def next(self): return ResultsPage(self._next_url, self._fetch) if self._next_url else None def previous(self): return ResultsPage(self._prev_url, self._fetch) if self._prev_url else None def __len__(self) -> int: return len(self._results) def __getitem__(self, key: int) -> T: return self._results[key] def __iter__(self) -> Iterator[T]: return iter([self._results[i] for i in range(len(self._results))]) def __str__(self) -> str: return f"{len(self._results)} results from {self._url}"
class Feature(TypedDict): type: Literal["Feature"] properties: dict[str, Any] geometry: dict[str, Any] class FeatureCollection(TypedDict): type: Literal["FeatureCollection"] features: list[Feature] class ApiKeyAuth(AuthBase): api_key: str def __init__(self): api_key = os.environ.get("PICTERRA_API_KEY", None) if api_key is None: raise APIError("PICTERRA_API_KEY environment variable is not defined") self.api_key = api_key def __call__(self, r): r.headers["X-Api-Key"] = self.api_key return r class BaseAPIClient: """ Base class for Picterra API clients. This is subclassed for the different products we have. """ def __init__( self, api_url: str, timeout: int = 30, max_retries: int = 3, backoff_factor: int = 10, ): """ Args: api_url: the api's base url. This is different based on the Picterra product used and is typically defined by implementations of this client timeout: number of seconds before the request times out max_retries: max attempts when encountering gateway issues or throttles; see retry_strategy comment below backoff_factor: factor used nin the backoff algorithm; see retry_strategy comment below """ base_url = os.environ.get("PICTERRA_BASE_URL", "https://app.picterra.ch/") logger.info( "Using base_url=%s, api_url=%s; %d max retries, %d backoff and %s timeout.", base_url, api_url, max_retries, backoff_factor, timeout, ) self.base_url = urljoin(base_url, api_url) # Create the session with a default timeout (30 sec) and auth, that we can then # override on a per-endpoint basis (will be disabled for file uploads and downloads) self.sess = _RequestsSession(timeout=timeout) self.sess.auth = ApiKeyAuth() # Authentication # Retry: we set the HTTP codes for our throttle (429) plus possible gateway problems (50*), # and for polling methods (GET), as non-idempotent ones should be addressed via idempotency # key mechanism; given the algorithm is {<backoff_factor> * (2 **<retries-1>}, and we # default to 30s for polling and max 30 req/min, the default 5-10-20 sequence should # provide enough room for recovery retry_strategy = Retry( total=max_retries, status_forcelist=[429, 502, 503, 504], backoff_factor=backoff_factor, allowed_methods=["GET"], ) adapter = HTTPAdapter(max_retries=retry_strategy) self.sess.mount("https://", adapter) self.sess.mount("http://", adapter) def _full_url(self, path: str, params: dict[str, Any] | None = None): url = urljoin(self.base_url, path) if not params: return url else: qstr = urlencode(params) return "%s?%s" % (url, qstr) def _wait_until_operation_completes( self, operation_response: dict[str, Any] ) -> dict[str, Any]: """Polls an operation an returns its data""" operation_id = operation_response["operation_id"] poll_interval = operation_response["poll_interval"] # Just sleep for a short while the first time time.sleep(poll_interval * 0.1) while True: logger.info("Polling operation id %s" % operation_id) resp = self.sess.get( self._full_url("operations/%s/" % operation_id), ) if not resp.ok: raise APIError(resp.text) status = resp.json()["status"] logger.info("status=%s" % status) if status == "success": break if status == "failed": errors = resp.json()["errors"] raise APIError( "Operation %s failed: %s" % (operation_id, json.dumps(errors)) ) time.sleep(poll_interval) return resp.json() def _return_results_page( self, resource_endpoint: str, params: dict[str, Any] | None = None ) -> ResultsPage: if params is None: params = {} if "page_number" not in params: params["page_number"] = 1 url = self._full_url("%s/" % resource_endpoint, params=params) return ResultsPage(url, self.sess.get) def get_operation_results(self, operation_id: str) -> dict[str, Any]: """ Return the 'results' dict of an operation This a **beta** function, subject to change. Args: operation_id: The id of the operation """ resp = self.sess.get( self._full_url("operations/%s/" % operation_id), ) return resp.json()["results"]