"""Base class for downloading and processing GloFAS river discharge data."""
import logging
import time
from abc import abstractmethod
from dataclasses import dataclass
from datetime import date
from pathlib import Path
from typing import List, Tuple, Union
import numpy as np
import xarray as xr
from dateutil import rrule
from ochanticipy.config.countryconfig import CountryConfig
from ochanticipy.datasources.datasource import DataSource
from ochanticipy.utils.check_extra_imports import check_extra_imports
from ochanticipy.utils.dates import get_date_from_user_input
from ochanticipy.utils.geoboundingbox import GeoBoundingBox
# Passed to DataSource as ``datasource_base_dir`` in Glofas.__init__.
_MODULE_BASENAME = "glofas"
# Sent as the "hydrological_model" entry of every CDS query.
_HYDROLOGICAL_MODEL = "lisflood"
# Name of the variable read from the raw dataset when extracting the
# reporting points.
_RIVER_DISCHARGE_VAR = "dis24"
# Largest number of requests a single Glofas instance may generate;
# Glofas.__init__ raises a RuntimeError above this.
_CDS_MAX_REQUESTS = 500
# Pause between two polling rounds in Glofas._download.
_REQUEST_SLEEP_TIME = 60  # seconds
# The GloFAS API on CDS requires coordinates have specific formats.
# For v3, this needs to be x.x5, and v4, either x.x25 or x.x75.
# Keyed by model version; the inner dict is unpacked as keyword arguments
# to GeoBoundingBox.round_coords.
_GBB_ROUND_COORDS_PARAMS = {
    3: {"offset_val": 0.05, "round_val": 0.1},
    4: {"offset_val": 0.025, "round_val": 0.05},
}
# Keyed by model version; passed as ``precision`` to
# GeoBoundingBox.get_filename_repr when building filenames.
_FILENAME_ROUNDING_PRECISION = {3: 2, 4: 3}
# NOTE(review): not referenced in this module; presumably the default
# model version used by callers -- confirm.
DEFAULT_MODEL_VERSION = 4
logger = logging.getLogger(__name__)
# putting on top level to ensure easy mocking in tests
try:
    import cdsapi
except ModuleNotFoundError:
    # Deliberately ignored here: Glofas.__init__ calls check_extra_imports
    # for "cdsapi", so the missing dependency is handled at that point.
    pass
[docs]
class SystemVersions(dict):
    """
    Dictionary restricted to the allowed GloFAS model versions.

    Only keys listed in ``ALLOWED_KEYS`` may be stored. The check is
    applied on item assignment, at construction time, and in ``update``
    and ``setdefault``, so none of those entry points can be used to
    store an unsupported model version.

    Raises
    ------
    KeyError
        If a key that is not in ``ALLOWED_KEYS`` is supplied
    """

    ALLOWED_KEYS = {3, 4}

    def __init__(self, *args, **kwargs):
        """Create the mapping, validating every initial key."""
        super().__init__()
        # dict.__init__ does not go through __setitem__, so route the
        # initial content through the validating update()
        self.update(*args, **kwargs)

    def __setitem__(self, key, value):
        """Set the allowed keys."""
        self._check_key(key)
        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Update the mapping, rejecting any key that is not allowed."""
        incoming = dict(*args, **kwargs)
        # Validate everything first so that a bad key leaves the
        # mapping untouched
        for key in incoming:
            self._check_key(key)
        super().update(incoming)

    def setdefault(self, key, default=None):
        """Return the value for key, inserting default if it is allowed."""
        self._check_key(key)
        return super().setdefault(key, default)

    def _check_key(self, key):
        """Raise KeyError if key is not one of the allowed keys."""
        if key not in self.ALLOWED_KEYS:
            raise KeyError(f"Key '{key}' is not allowed")
@dataclass
class _QueryParams:
    """
    Class to keep track of CDS query input and output.

    Parameters
    ----------
    filepath: Path
        Full filepath of downloaded CDS file
    query: dict
        Output of _get_query() method, to be submitted to the CDS API
    request_id: str, default = None
        Request ID from CDS, only set after request has been made
    downloaded: bool, default = False
        Whether or not the file has yet been downloaded
    """

    filepath: Path
    query: dict
    # None until the request has been submitted, hence not a plain ``str``
    request_id: Union[str, None] = None
    downloaded: bool = False
[docs]
class Glofas(DataSource):
"""
Base class for all GloFAS data downloading and processing.
Parameters
----------
country_config : CountryConfig
Country configuration
geo_bounding_box: GeoBoundingBox
The bounding coordinates of the area that should be included
cds_name : str
The name of the dataset in CDS
model_version : int
The version of the model to use, can only be 3 or 4.
Converted to system_version for the CDS query.
product_type : str or list
Which product types from the dataset are requested
date_variable_prefix : str
Some dates require a prefix for the CDS API query
frequency : int
How to split the query (and thus files): in years, months, or days.
Given as a dateutil rrule frequency constant (e.g. rrule.MONTHLY or
rrule.DAILY). Depends on the maximum query size of the product
coord_names : list
Coordinate names in the xarray dataset
start_date_min: date
The minimum allowed start date
end_date_max: date, default = None
The maximum allowed end date
start_date : Union[date, str], default = None
The starting date for the dataset
end_date : Union[date, str], default = None
The ending date for the dataset
leadtime_max : int, default = None
The maximum lead time in days, for forecast or reforecast data
month_list : List[int], default = None
Limit to specific months, required for the version 4 reforecast
"""
def __init__(
    self,
    country_config: CountryConfig,
    geo_bounding_box: GeoBoundingBox,
    cds_name: str,
    model_version: int,
    product_type: Union[str, List[str]],
    date_variable_prefix: str,
    frequency: int,
    coord_names: List[str],
    start_date_min: date,
    end_date_max: date = None,
    start_date: Union[date, str] = None,
    end_date: Union[date, str] = None,
    leadtime_max: int = None,
    month_list: List[int] = None,
):
    """
    Store the query settings and build the list of request dates.

    Raises
    ------
    ValueError
        If the model version is not supported, or if the resulting date
        range contains no dates
    RuntimeError
        If the date range would need more than the maximum number of
        CDS requests
    """
    super().__init__(
        country_config=country_config,
        datasource_base_dir=_MODULE_BASENAME,
        is_public=True,
    )
    # check that extra dependencies are installed
    check_extra_imports(
        libraries=["cdsapi", "cfgrib"], subpackage="glofas"
    )
    checked_dates = _set_dates(
        start_date_min=start_date_min,
        end_date_max=end_date_max,
        start_date=start_date,
        end_date=end_date,
    )
    self._start_date, self._end_date = checked_dates
    self._cds_name = cds_name
    self._model_version = model_version
    # Validates the model version before it is used as a lookup key below
    self._system_version = self._get_system_version(model_version)
    rounding_kwargs = _GBB_ROUND_COORDS_PARAMS[self._model_version]
    self._geo_bounding_box = geo_bounding_box.round_coords(
        **rounding_kwargs
    )
    self._product_type = product_type
    self._date_variable_prefix = date_variable_prefix
    self._frequency = frequency
    self._coord_names = coord_names
    self._leadtime_max = leadtime_max
    # e.g. a subclass called GlofasForecast gives "forecast"
    class_name = type(self).__name__
    self._forecast_type = class_name.replace("Glofas", "").lower()
    # One entry per file / CDS request
    self._date_range = rrule.rrule(
        freq=self._frequency,
        dtstart=self._start_date,
        until=self._end_date,
        bymonth=month_list,
    )
    n_requests = self._date_range.count()
    if n_requests > _CDS_MAX_REQUESTS:
        msg = (
            f"Your parameters would result in "
            f"{n_requests} requests, however we "
            f"currently only support the CDS maximum of "
            f"{_CDS_MAX_REQUESTS} at this time. Please divide your "
            f"query into multiple instances."
        )
        raise RuntimeError(msg)
    if n_requests == 0:
        raise ValueError("Date range is empty, check start and end dates")
[docs]
def download(  # type: ignore
    self,
    clobber: bool = False,
) -> List[Path]:
    """
    Download the GloFAS data by querying CDS.

    The raw GloFAS data is available as a global raster in CDS. This method
    downloads the raster files for the specified region of interest
    and date range. The files are in GRIB format and are split up either
    by day, month, or year depending on the GloFAS product.

    Parameters
    ----------
    clobber : bool, default = False
        Overwrite files that were already downloaded

    Returns
    -------
    A list of paths of the files downloaded by this call. Files that
    already existed and were skipped (clobber is False) are not included.
    """
    msg = (
        f"Downloading GloFAS {self._forecast_type} "
        f"for {self._start_date} - {self._end_date}"
    )
    if self._leadtime_max is not None:
        # Leading space so the clause does not run into the end date
        msg += f" and up to {self._leadtime_max} day lead time"
    logger.info(msg)
    # Make directory
    output_directory = self._get_directory()
    output_directory.mkdir(parents=True, exist_ok=True)
    # Build one query per date whose file still needs to be fetched
    query_params_list = []
    for file_date in self._date_range:
        output_filepath = self._get_filepath(
            year=file_date.year,
            month=file_date.month,
            day=file_date.day,
        )
        if not clobber and output_filepath.exists():
            continue
        query_params_list.append(
            _QueryParams(
                filepath=output_filepath,
                query=self._get_query(
                    year=file_date.year,
                    month=file_date.month,
                    day=file_date.day,
                ),
            )
        )
    download_filepaths = self._download(
        query_params_list=query_params_list
    )
    logger.info(
        f"Downloaded {len(download_filepaths)} files to {output_directory}"
    )
    logger.debug(f"Files downloaded: {download_filepaths}")
    return download_filepaths
[docs]
def process(  # type: ignore
    self,
    clobber: bool = False,
) -> List[Path]:
    """
    Process the downloaded GloFAS files.

    For each raw GRIB file, read it in and extract the river discharge
    from the reporting point coordinates specified in the configuration
    file. Saves the output as a NetCDF file, where files are split by day,
    month or year depending on the GloFAS product.

    Parameters
    ----------
    clobber : bool, default = False
        Overwrite files that were already processed

    Returns
    -------
    A list of paths of the files processed by this call. Files that
    already existed and were skipped (clobber is False) are not included.
    """
    msg = (
        f"Processing GloFAS {self._forecast_type} for "
        f"{self._start_date} - {self._end_date}"
    )
    if self._leadtime_max is not None:
        # Only mention the lead time for products that have one
        msg += f" and up to {self._leadtime_max} day lead time"
    logger.info(msg)
    # Make the directory
    output_directory = self._get_directory(is_processed=True)
    output_directory.mkdir(parents=True, exist_ok=True)
    # Get list of files to open
    processed_filepaths = []
    for file_date in self._date_range:
        input_filepath = self._get_filepath(
            year=file_date.year,
            month=file_date.month,
            day=file_date.day,
        )
        output_filepath = self._get_filepath(
            year=file_date.year,
            month=file_date.month,
            day=file_date.day,
            is_processed=True,
        )
        if not clobber and output_filepath.exists():
            continue
        logger.debug(f"Processing {input_filepath}")
        ds_raw = self._load_single_file(
            input_filepath=input_filepath,
            filepath=output_filepath,
            clobber=clobber,
        )
        ds_processed = self._get_reporting_point_dataset(ds=ds_raw)
        # NetCDF doesn't like to overwrite files
        if output_filepath.exists():
            output_filepath.unlink()
        ds_processed.to_netcdf(output_filepath)
        processed_filepaths.append(output_filepath)
        logger.debug(f"Wrote file to {output_filepath}")
    logger.info(
        f"Processed {len(processed_filepaths)} files to {output_directory}"
    )
    logger.debug(f"Files processed: {processed_filepaths}")
    return processed_filepaths
[docs]
def load(
    self,
) -> xr.Dataset:
    """
    Load the processed GloFAS data as an xarray.DataSet.

    Opens every processed file of the date range as one multi-file
    dataset. Each file is passed through ``_preprocess_load`` before the
    files are combined; the base implementation of that hook returns the
    dataset unchanged.

    Returns
    -------
    A single xarray dataset containing all GloFAS reporting points
    and their associated river discharge
    """
    filepath_list = [
        self._get_filepath(
            year=dataset_date.year,
            month=dataset_date.month,
            day=dataset_date.day,
            is_processed=True,
        )
        for dataset_date in self._date_range
    ]
    # NOTE(review): the dataset is returned from inside the ``with``
    # block, as before; presumably it is closed on return -- confirm this
    # is intended.
    with xr.open_mfdataset(
        filepath_list, preprocess=self._preprocess_load
    ) as ds:
        return ds
@staticmethod
@abstractmethod
def _system_version_dict() -> SystemVersions:
    """
    Map each supported model version to its system version.

    Abstract: concrete products must provide the mapping, which
    ``_get_system_version`` indexes by model version.
    """
@classmethod
def _get_system_version(cls, model_version: int):
try:
return cls._system_version_dict()[model_version]
except KeyError:
raise ValueError("Model version must be 3 or 4")
def _get_filepath(
    self,
    year: int,
    month: int = None,
    day: int = None,
    is_processed: bool = False,
) -> Path:
    """
    Build the path of the raw or processed file for a given date.

    The filename is made of the country ISO3, the CDS dataset name, the
    model version and the year, followed by the month (monthly and daily
    products), the day (daily products), the maximum lead time (if set)
    and the bounding box. Raw files end in ``.grib``, processed files in
    ``_processed.nc``.
    """
    name_parts = [
        f"{self._country_config.iso3}_{self._cds_name}_",
        f"v{self._model_version}_{year}",
    ]
    if self._frequency in (rrule.MONTHLY, rrule.DAILY):
        name_parts.append("-" + str(month).zfill(2))
    if self._frequency == rrule.DAILY:
        name_parts.append("-" + str(day).zfill(2))
    if self._leadtime_max is not None:
        padded_leadtime = str(self._leadtime_max).zfill(2)
        name_parts.append(f"_ltmax{padded_leadtime}d")
    bbox_repr = self._geo_bounding_box.get_filename_repr(
        precision=_FILENAME_ROUNDING_PRECISION[self._model_version]
    )
    name_parts.append(f"_{bbox_repr}")
    name_parts.append("_processed.nc" if is_processed else ".grib")
    directory = self._get_directory(is_processed=is_processed)
    return directory / "".join(name_parts)
def _get_directory(self, is_processed: bool = False) -> Path:
"""Get download / processed directory for GloFAS product."""
return (
self._processed_base_dir
if is_processed
else self._raw_base_dir / self._cds_name
)
def _download(
    self,
    query_params_list: List[_QueryParams],
) -> List[Path]:
    """
    Download the GloFAS data from CDS.

    Uses query_params_list, which is a list of API request input dicts,
    to query the CDS API, and for each query, store the request
    ID that is returned, and the downloaded state.

    Then loops through the list of requests, checking each one to see
    if it has been completed on the CDS side. If so, it's downloaded,
    and then removed from the list of requests. The process continues
    until the request list is empty.

    Parameters
    ----------
    query_params_list : List[_QueryParams]
        One entry per file to fetch; ``request_id`` and ``downloaded``
        are updated in place

    Returns
    -------
    A list of paths of the downloaded files

    Raises
    ------
    RuntimeError
        If CDS reports a request as failed
    """
    # Nothing to fetch: return before building a client
    if not query_params_list:
        return []
    # A single client is reused for all submissions and status checks
    # instead of creating a new one per request and per poll
    client = cdsapi.Client(wait_until_complete=False, delete=False)
    # First make the requests to the CDS client and store request number
    for query_params in query_params_list:
        logger.debug(f"Making request {query_params.query}")
        query_params.request_id = client.retrieve(
            name=self._cds_name, request=query_params.query
        ).reply["request_id"]
    # Loop through the request list and check status until all requests
    # are downloaded
    downloaded_filepaths = []
    pending = list(query_params_list)
    while pending:
        for query_params in pending:
            result = cdsapi.api.Result(
                client=client,
                reply={"request_id": query_params.request_id},
            )
            result.update()
            state = result.reply["state"]
            logger.debug(
                f"For request {query_params.request_id} and filename "
                f"{query_params.filepath}, state is {state}"
            )
            if state == "completed":
                result.download(query_params.filepath)
                query_params.downloaded = True
                downloaded_filepaths.append(query_params.filepath)
            elif state == "failed":
                raise RuntimeError("Query has failed, try again")
        # Remove requests that have been downloaded
        pending = [
            query_params
            for query_params in pending
            if not query_params.downloaded
        ]
        # Sleep a bit before the next loop so that we're not
        # hammering on cds
        if pending:
            # Log before sleeping so the message explains the pause
            logger.info(f"Sleeping for {_REQUEST_SLEEP_TIME} s")
            time.sleep(_REQUEST_SLEEP_TIME)
    return downloaded_filepaths
def _get_query(
    self,
    year: int,
    month: int = None,
    day: int = None,
) -> dict:
    """
    Assemble the request dictionary submitted to the CDS API.

    Monthly and daily products request the single given month, other
    products all twelve months. Daily products request the single given
    day, other products days 01 to 31. When a maximum lead time is set,
    one ``leadtime_hour`` entry per day is added.
    """
    prefix = self._date_variable_prefix
    if self._frequency in (rrule.MONTHLY, rrule.DAILY):
        month_value = str(month).zfill(2)
    else:
        month_value = [str(number).zfill(2) for number in range(1, 13)]
    if self._frequency == rrule.DAILY:
        day_value = str(day).zfill(2)
    else:
        day_value = [str(number).zfill(2) for number in range(1, 32)]
    bbox = self._geo_bounding_box
    query = {
        "variable": "river_discharge_in_the_last_24_hours",
        "format": "grib",
        "product_type": self._product_type,
        "system_version": self._system_version,
        "hydrological_model": _HYDROLOGICAL_MODEL,
        f"{prefix}year": str(year),
        f"{prefix}month": month_value,
        f"{prefix}day": day_value,
        "area": [bbox.lat_max, bbox.lon_min, bbox.lat_min, bbox.lon_max],
    }
    if self._leadtime_max is not None:
        # Lead times are given in hours: 24, 48, ... one per day
        lead_days = np.arange(self._leadtime_max) + 1
        query["leadtime_hour"] = [
            str(lead_day * 24) for lead_day in list(lead_days)
        ]
    logger.debug(f"Query: {query}")
    return query
@abstractmethod
def _load_single_file(self, *args, **kwargs) -> xr.Dataset:
    """
    Read one raw raster file into a dataset.

    Abstract: each product supplies its own reader. ``process`` calls it
    with the keyword arguments ``input_filepath``, ``filepath`` and
    ``clobber``.
    """
def _get_reporting_point_dataset(self, ds: xr.Dataset) -> xr.Dataset:
    """
    Convert raw raster to processed that uses reporting points.

    For every reporting point in the country configuration, the river
    discharge of the nearest raster cell is extracted and stored as a
    data variable named after the reporting point.

    Parameters
    ----------
    ds : xr.Dataset
        Raw raster dataset with ``longitude`` and ``latitude``
        coordinates and the river discharge variable

    Returns
    -------
    Dataset with one variable per reporting point, on the coordinates
    listed in ``self._coord_names``

    Raises
    ------
    KeyError
        If the country configuration has no GloFAS section
    IndexError
        If a reporting point is not strictly inside the longitude or
        latitude range of the data
    """
    if self._country_config.glofas is None:
        raise KeyError(
            "The country configuration file does not contain "
            "any reporting point coordinates. Please update the "
            "configuration file and try again."
        )
    reporting_points = self._country_config.glofas.reporting_points
    # The data bounds are the same for every reporting point, so reduce
    # the coordinates once instead of on every iteration
    lon_min, lon_max = ds.longitude.min(), ds.longitude.max()
    lat_min, lat_max = ds.latitude.min(), ds.latitude.max()
    # Check that lat and lon of reporting points are in the bounds
    for reporting_point in reporting_points:
        if not lon_min < reporting_point.lon < lon_max:
            raise IndexError(
                f"ReportingPoint {reporting_point.id} has out-of-bounds "
                f"lon value of {reporting_point.lon} (data lon ranges "
                f"from {lon_min.values} "
                f"to {lon_max.values})"
            )
        if not lat_min < reporting_point.lat < lat_max:
            raise IndexError(
                f"ReportingPoint {reporting_point.id} has out-of-bounds "
                f"lat value of {reporting_point.lat} (data lat ranges "
                f"from {lat_min.values} "
                f"to {lat_max.values})"
            )
    # If reporting points fit then return processed dataset
    return xr.Dataset(
        data_vars={
            reporting_point.name: (
                self._coord_names,
                ds.sel(
                    longitude=reporting_point.lon,
                    latitude=reporting_point.lat,
                    method="nearest",
                )[_RIVER_DISCHARGE_VAR].data,
            )
            for reporting_point in reporting_points
        },
        coords={
            coord_name: ds[coord_name] for coord_name in self._coord_names
        },
    )
@staticmethod
def _preprocess_load(ds: xr.Dataset) -> xr.Dataset:
    """
    Preprocessing to do before loading the processed data.

    The base implementation is a no-op that returns ``ds`` unchanged.
    NOTE(review): presumably meant to be overridden by subclasses that
    need to adjust each dataset -- confirm against the subclasses.
    """
    return ds
def _set_dates(
    start_date_min: date,
    end_date_max: date = None,
    start_date: Union[date, str] = None,
    end_date: Union[date, str] = None,
) -> Tuple[date, date]:
    """
    Adjust date types and check against limits.

    Missing dates default to the bounds (the upper bound itself defaults
    to today). Dates outside the bounds are clamped to them with a
    warning.

    Parameters
    ----------
    start_date_min : date
        Earliest allowed start date
    end_date_max : date, default = None
        Latest allowed end date, today if None
    start_date : Union[date, str], default = None
        Requested start date, ``start_date_min`` if None
    end_date : Union[date, str], default = None
        Requested end date, ``end_date_max`` if None

    Returns
    -------
    The start and end date, both within the bounds

    Raises
    ------
    ValueError
        If the start date is after the end date, either as requested or
        after clamping to the bounds
    """
    # TODO: perhaps this is general enough to be useful for other data sources?
    # Set any missing dates
    if end_date_max is None:
        end_date_max = date.today()
    if start_date is None:
        start_date = start_date_min
    if end_date is None:
        end_date = end_date_max
    # Make sure that dates are all date type
    start_date = get_date_from_user_input(start_date)
    end_date = get_date_from_user_input(end_date)
    # Check against bounds
    if start_date > end_date:
        raise ValueError(
            "Please ensure that the start date is <= the end_date"
        )
    if start_date < start_date_min:
        logger.warning(
            f"Start date {start_date} is too far in the past, "
            f"setting to {start_date_min}"
        )
        start_date = start_date_min
    if end_date > end_date_max:
        logger.warning(
            f"End date {end_date} is too far in the future, "
            f"setting to {end_date_max}"
        )
        end_date = end_date_max
    # Clamping inverts the range when the whole requested window lies
    # outside the bounds (e.g. both dates before start_date_min)
    if start_date > end_date:
        raise ValueError(
            f"The requested dates do not overlap with the available "
            f"range {start_date_min} - {end_date_max}"
        )
    return start_date, end_date