Source code for climate_diagnostics.plots.plot

import warnings
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from cartopy.feature import NaturalEarthFeature
from dask.diagnostics import ProgressBar

from ..utils.data_utils import get_coord_name, filter_by_season
from ..utils.dask_utils import get_or_create_dask_client
from ..utils.plot_utils import get_projection


[docs] @xr.register_dataset_accessor("climate_plots") class PlotsAccessor: """ A custom xarray accessor for creating climate-specific visualizations. This accessor extends xarray Dataset and DataArray objects with a `.climate_plots` namespace, providing a suite of plotting methods for common climate diagnostics. These methods simplify the process of selecting data, calculating indices, and generating publication-quality spatial plots. The accessor handles common data-wrangling tasks such as: - Finding coordinate names (e.g., 'lat', 'latitude'). - Subsetting data by space, time, and vertical level. - Applying seasonal filters. - Calculating standard climate indices (e.g., Rx5day, CWD). - Generating descriptive titles and labels. Examples -------- >>> import xarray as xr >>> # Assuming 'climate_diagnostics' is imported to register the accessor >>> import climate_diagnostics >>> >>> # Load a dataset >>> ds = xr.tutorial.load_dataset("air_temperature") >>> >>> # Generate a plot of the mean air temperature for a specific time range >>> ds.climate_plots.plot_mean( ... variable='air', ... time_range=slice('2013-05', '2013-09'), ... season='jja' ... ) """ # -------------------------------------------------------------------------- # INITIALIZATION # --------------------------------------------------------------------------
[docs] def __init__(self, xarray_obj): """Initialize the accessor with a Dataset object.""" # Store the xarray object (Dataset or DataArray) for later use. self._obj = xarray_obj
# -------------------------------------------------------------------------- # INTERNAL HELPER METHODS: DATA SELECTION & PREPARATION # -------------------------------------------------------------------------- def _select_data(self, variable, latitude=None, longitude=None, level=None, time_range=None): """ Select and subset data variable based on spatial, temporal, and vertical coordinates. This is a core utility method that handles the complex logic of coordinate selection across different climate datasets. It accommodates various coordinate naming conventions and data structures commonly found in climate model output. Key features: - Automatic coordinate name detection (lat/latitude, lon/longitude, etc.) - Flexible selection methods (single values, slices, lists) - Intelligent level handling with nearest-neighbor selection - Comprehensive error handling and validation - Support for both datetime and numeric time coordinates Parameters ---------- variable : str The name of the data variable to select from the Dataset. Must exist in dataset.data_vars. latitude : float, slice, or list, optional Latitude selection. Can be: - Single value: nearest-neighbor selection - Slice: range selection (e.g., slice(30, 60)) - List: specific values selection longitude : float, slice, or list, optional Longitude selection. Same formats as latitude. level : float or slice, optional Vertical level selection. If not specified and multiple levels exist, defaults to first level. Single values use nearest-neighbor matching. time_range : slice, optional Time range selection as slice of datetime-like objects or strings. E.g., slice('2000-01-01', '2010-12-31') Returns ------- selected_data : xr.DataArray The selected and subsetted data variable with applied selections. level_dim_name_found : str or None Name of the level dimension found in the data (for reference). level_op : str or None Description of level operation performed: - 'single_selected': Single level chosen by user - 'range_selected': Level range selected - 'single_selected_default': First level chosen automatically Returns ------- selected_data : xr.DataArray The selected and subsetted data variable. level_dim_name_found : str or None The name of the level dimension found in the data. level_op : str or None A string indicating the operation performed on the level dimension ('single_selected', 'range_selected', 'single_selected_default'). Raises ------ ValueError If the variable is not found or if coordinate selections are invalid. Notes ----- This method implements robust coordinate validation that handles: - Different coordinate naming conventions (CF-compliant and others) - Datetime vs numeric coordinate systems - Boundary checking with informative error messages - Graceful handling of missing coordinates """ # --- Step 1: Variable validation and initialization --- # Ensure the requested variable exists in the dataset. if variable not in self._obj.data_vars: raise ValueError(f"Variable '{variable}' not found. Available: {list(self._obj.data_vars.keys())}") data_var = self._obj[variable] selection_dict = {} # Stores coordinate selections for xarray's .sel() method. method_dict = {} # Stores method specifications (e.g., 'nearest' for exact matching). # --- Step 2: Coordinate name mapping and discovery --- # Build a flexible mapping from standard names (e.g., 'latitude') to the # actual coordinate names present in the dataset. This handles various # naming conventions across different climate datasets (CF-compliant and others). coord_map = { 'latitude': get_coord_name(data_var, ['lat', 'latitude', 'LAT', 'LATITUDE', 'y', 'rlat', 'nav_lat']), 'longitude': get_coord_name(data_var, ['lon', 'longitude', 'LON', 'LONGITUDE', 'x', 'rlon', 'nav_lon']), 'time': get_coord_name(data_var, ['time', 't']), 'level': next((name for name in ['level', 'lev', 'plev', 'height', 'altitude', 'depth', 'z'] if name in data_var.dims or name in data_var.coords), None) } level_dim_name_found = coord_map['level'] level_op = None # Track level operations for metadata # Map user parameters to coordinate names and datetime handling flags coord_params_map = { 'latitude': (latitude, False), 'longitude': (longitude, False), 'time': (time_range, True), 'level': (level, False) } # --- Step 3: Process each coordinate selection with comprehensive validation --- for coord_type_name, (coord_val_param, is_param_datetime_intent) in coord_params_map.items(): actual_coord_name_in_data = coord_map[coord_type_name] if coord_val_param is None: continue # Skip if no selection was provided for this coordinate. if actual_coord_name_in_data is None: if coord_type_name == "level": print(f"Warning: Level parameter provided, but no recognized level coordinate found. Ignoring.") else: raise ValueError(f"No {coord_type_name} coordinate found, but '{coord_type_name}' parameter was provided.") continue if actual_coord_name_in_data not in data_var.coords: print(f"Warning: Coord '{actual_coord_name_in_data}' not in variable '{variable}'. Skipping selection.") continue # Get the min/max values from the data for validation. min_data_val_raw_item = data_var[actual_coord_name_in_data].min().item() max_data_val_raw_item = data_var[actual_coord_name_in_data].max().item() # Extract min/max from the user's request. req_min_val, req_max_val = None, None is_scalar_request = not isinstance(coord_val_param, (slice, list, np.ndarray)) if isinstance(coord_val_param, slice): req_min_val, req_max_val = coord_val_param.start, coord_val_param.stop elif isinstance(coord_val_param, (list, np.ndarray)): if not coord_val_param: raise ValueError(f"{coord_type_name.capitalize()} selection list/array empty.") req_min_val, req_max_val = min(coord_val_param), max(coord_val_param) else: req_min_val = req_max_val = coord_val_param # This section handles the complexity of comparing user-provided coordinate # values with the data's coordinate values, especially when dealing with # different datetime representations (numpy.datetime64, cftime, numeric years). comp_req_min, comp_req_max = req_min_val, req_max_val comp_data_min, comp_data_max = min_data_val_raw_item, max_data_val_raw_item data_coord_dtype = data_var[actual_coord_name_in_data].dtype data_coord_is_np_datetime = np.issubdtype(data_coord_dtype, np.datetime64) data_coord_is_cftime = False if not data_coord_is_np_datetime and data_var[actual_coord_name_in_data].size > 0: first_val = data_var[actual_coord_name_in_data].isel({data_var[actual_coord_name_in_data].dims[0]: 0}).item() if hasattr(first_val, 'year') and hasattr(first_val, 'month') and not isinstance(first_val, (np.datetime64, np.timedelta64)): data_coord_is_cftime = True data_coord_is_datetime_like = data_coord_is_np_datetime or data_coord_is_cftime data_coord_is_numeric = np.issubdtype(data_coord_dtype, np.number) if is_param_datetime_intent: try: if comp_req_min is not None: comp_req_min = np.datetime64(comp_req_min) if comp_req_max is not None: comp_req_max = np.datetime64(comp_req_max) except Exception as e: raise ValueError(f"Could not convert requested datetime value for {coord_type_name} ('{coord_val_param}') to np.datetime64: {e}") if data_coord_is_np_datetime: if isinstance(min_data_val_raw_item, (int, np.integer)): unit = np.datetime_data(data_coord_dtype)[0] comp_data_min = np.datetime64(min_data_val_raw_item, unit) elif min_data_val_raw_item is not None: comp_data_min = np.datetime64(min_data_val_raw_item) if isinstance(max_data_val_raw_item, (int, np.integer)): unit = np.datetime_data(data_coord_dtype)[0] comp_data_max = np.datetime64(max_data_val_raw_item, unit) elif max_data_val_raw_item is not None: comp_data_max = np.datetime64(max_data_val_raw_item) elif data_coord_is_cftime: try: if comp_data_min is not None: comp_data_min = np.datetime64(comp_data_min) if comp_data_max is not None: comp_data_max = np.datetime64(comp_data_max) except Exception as c_e: print(f"Warning: Could not convert cftime data bounds for {actual_coord_name_in_data} to np.datetime64 for comparison ({c_e}). Trusting xarray's .sel().") comp_data_min, comp_data_max = None, None elif data_coord_is_numeric: print(f"Note: Time parameter is datetime-like, but data coord '{actual_coord_name_in_data}' is numeric. " "Extracting year from request for comparison.") try: if comp_req_min is not None: comp_req_min = comp_req_min.astype('datetime64[Y]').astype(int) + 1970 if comp_req_max is not None: comp_req_max = comp_req_max.astype('datetime64[Y]').astype(int) + 1970 except Exception as e_year: print(f"Warning: Could not extract year from datetime request for {actual_coord_name_in_data}: {e_year}") # Check that requested min/max values are within the data's actual min/max range. if comp_req_min is not None and comp_data_max is not None: try: if comp_req_min > comp_data_max: raise ValueError(f"Requested {coord_type_name} minimum {req_min_val} (as {comp_req_min} type {type(comp_req_min).__name__}) " f"> data maximum {max_data_val_raw_item} (as {comp_data_max} type {type(comp_data_max).__name__})") except TypeError as e: raise TypeError(f"Type mismatch comparing request min ({type(comp_req_min).__name__}) and data max ({type(comp_data_max).__name__}) for {coord_type_name}. Error: {e}") if comp_req_max is not None and comp_data_min is not None: try: if comp_req_max < comp_data_min: raise ValueError(f"Requested {coord_type_name} maximum {req_max_val} (as {comp_req_max} type {type(comp_req_max).__name__}) " f"< data minimum {min_data_val_raw_item} (as {comp_data_min} type {type(comp_data_min).__name__})") except TypeError as e: raise TypeError(f"Type mismatch comparing request max ({type(comp_req_max).__name__}) and data min ({type(comp_data_min).__name__}) for {coord_type_name}. Error: {e}") # --- Step 4: Finalize and apply selections --- selection_dict[actual_coord_name_in_data] = coord_val_param if coord_type_name == "level": level_op = 'range_selected' if is_scalar_request or isinstance(coord_val_param, (int, float, np.number)): method_dict[actual_coord_name_in_data] = 'nearest' level_op = 'single_selected' # If no level is specified but multiple exist, default to the first level. if level is None and level_dim_name_found and level_dim_name_found in data_var.dims and data_var.sizes.get(level_dim_name_found, 0) > 1: first_level_val = data_var[level_dim_name_found].isel({level_dim_name_found: 0}).item() selection_dict[level_dim_name_found] = first_level_val level_op = 'single_selected_default' print(f"Warning: Multiple levels found. Using first level: {first_level_val}") selected_data = data_var if method_dict: for coord_name, method_val in method_dict.items(): if coord_name in selection_dict: selected_data = selected_data.sel({coord_name: selection_dict[coord_name]}, method=method_val) del selection_dict[coord_name] if selection_dict: selected_data = selected_data.sel(selection_dict) if level_op == 'single_selected_default': level_op = 'single_selected' return selected_data, level_dim_name_found, level_op # -------------------------------------------------------------------------- # INTERNAL HELPER METHODS: PLOT LAYOUT & FINALIZATION # -------------------------------------------------------------------------- def _setup_geographical_ax(self, figsize, land_only, projection='PlateCarree'): """Set up the geographical axes for plotting.""" # Use the helper function to get a Cartopy projection object from a string name. proj = get_projection(projection) fig = plt.figure(figsize=figsize) ax = fig.add_subplot(1, 1, 1, projection=proj) # Add foundational geographical features for context. ax.add_feature(NaturalEarthFeature('physical', 'ocean', '50m'), zorder=0, facecolor='#D3D3D3') ax.add_feature(NaturalEarthFeature('physical', 'land', '50m'), zorder=0, edgecolor='black', facecolor='#fbfbfb') ax.add_feature(NaturalEarthFeature('physical', 'coastline', '50m'), zorder=1, edgecolor='black', facecolor='none') ax.gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False) # Optionally mask out the ocean to focus on land areas. if land_only: ax.add_feature(NaturalEarthFeature('physical', 'ocean', '50m'), zorder=1, facecolor='white') return fig, ax def _generate_title(self, base_operation_name, var_display_name, season, level_op, level_dim_name_in_orig_var, processed_data_coords, original_var_accessor, time_info_provider_data, time_coord_name_actual, time_range_requested, time_format_type='day', index_specific_title_parts=""): """ Generate a descriptive title for a plot. The title includes information about the operation, variable, season, level, and time range. Parameters ---------- base_operation_name : str The name of the main operation being plotted (e.g., "Average"). var_display_name : str The display name of the variable. season : str The season used for the calculation. level_op : str The operation performed on the level dimension. level_dim_name_in_orig_var : str The name of the level dimension. processed_data_coords : dict Coordinates of the data being plotted. original_var_accessor : xr.DataArray The original DataArray for accessing metadata like units. time_info_provider_data : xr.DataArray DataArray used to extract the time range for the title. time_coord_name_actual : str The name of the time coordinate. time_range_requested : slice The originally requested time range. time_format_type : str, optional Format for time display ('day' or 'year'). Defaults to 'day'. index_specific_title_parts : str, optional Additional parts to add to the title, specific to a climate index. Returns ------- str The generated plot title. """ # --- Part 1: Main Title Line (Season, Operation, Variable) --- # Map the season code to a human-readable string. season_map = { 'annual': "Annual", 'djf': "Winter (DJF)", 'mam': "Spring (MAM)", 'jja': "Summer (JJA)", 'jjas': "Summer Monsoon (JJAS)", 'son': "Autumn (SON)"} season_str = season_map.get(season.lower(), season.upper()) title = f"{season_str} {base_operation_name} of {var_display_name}{index_specific_title_parts}" # --- Part 2: Level Information Sub-line --- # Add details about the vertical level if applicable. level_info_parts = [] if level_op == 'single_selected' and level_dim_name_in_orig_var and level_dim_name_in_orig_var in processed_data_coords: try: level_val = processed_data_coords[level_dim_name_in_orig_var].item() level_units = "" if level_dim_name_in_orig_var in original_var_accessor.coords: level_units = original_var_accessor.coords[level_dim_name_in_orig_var].attrs.get('units', '') level_info_parts.append(f"Level: {level_val} {level_units}".strip()) except Exception: level_info_parts.append(f"Level: {processed_data_coords.get(level_dim_name_in_orig_var, 'N/A')}") elif level_op == 'range_selected': level_info_parts.append("(Level Mean)") # --- Part 3: Time Information Sub-line --- # Add the time range of the data used in the plot. time_info_parts = [] if time_coord_name_actual and time_coord_name_actual in time_info_provider_data.coords and \ time_info_provider_data[time_coord_name_actual].size > 0: coord = time_info_provider_data[time_coord_name_actual] if np.issubdtype(coord.dtype, np.number): min_tv = coord.min().item() max_tv = coord.max().item() if min_tv == max_tv: time_info_parts.append(f"Time: {min_tv}") else: time_info_parts.append(f"{min_tv} to {max_tv}") else: try: times_np = coord.values.astype('datetime64[ns]') fmt_unit = 'datetime64[Y]' if time_format_type == 'year' else 'datetime64[D]' min_time = np.min(times_np).astype(fmt_unit) max_time = np.max(times_np).astype(fmt_unit) start_str = str(min_time) end_str = str(max_time) if start_str == end_str: time_info_parts.append(f"Time: {start_str}") else: time_info_parts.append(f"{start_str} to {end_str}") except Exception as e: print(f"Note: Could not format datetime time range for title: {e}") elif isinstance(time_range_requested, slice) and \ time_range_requested.start is not None and time_range_requested.stop is not None: if isinstance(time_range_requested.start, (int, float)) and isinstance(time_range_requested.stop, (int, float)): time_info_parts.append(f"Requested: {time_range_requested.start} to {time_range_requested.stop}") else: try: unit = 'Y' if time_format_type == 'year' else 'D' start_str = np.datetime64(time_range_requested.start, unit).astype(str) stop_str = np.datetime64(time_range_requested.stop, unit).astype(str) time_info_parts.append(f"Requested: {start_str} to {stop_str}") except Exception: time_info_parts.append(f"Requested: {time_range_requested.start} to {time_range_requested.stop}") # --- Part 4: Assemble Final Title --- # Combine all parts into a multi-line title. if level_info_parts: title += f"\n{' '.join(level_info_parts)}" if time_info_parts: title += f"\n({' '.join(time_info_parts)})" return title def _finalize_plot(self, ax, plot_object, title_str, cbar_label, data_for_extent, lon_name_plot, lat_name_plot, save_plot_path, variable): """ Finalize and optionally save the plot. This includes adding a colorbar, setting the title, adjusting the map extent, and saving the figure to a file if a path is provided. Parameters ---------- ax : cartopy.mpl.geoaxes.GeoAxes The Axes object for the plot. plot_object : matplotlib.contour.QuadContourSet or None The plot object returned by a plotting function (e.g., contourf). title_str : str The title for the plot. cbar_label : str The label for the colorbar. data_for_extent : xr.DataArray DataArray used to determine the plot's geographical extent. lon_name_plot : str Name of the longitude coordinate. lat_name_plot : str Name of the latitude coordinate. save_plot_path : str or None Path to save the plot. variable : str The name of the variable being plotted (for warning messages). Returns ------- cartopy.mpl.geoaxes.GeoAxes The finalized Axes object. """ if plot_object: plt.colorbar(plot_object, label=cbar_label, orientation='vertical', pad=0.05, shrink=0.8, ax=ax) ax.set_title(title_str, fontsize=12, loc='center') # Set the map extent to the data's boundaries. if data_for_extent[lon_name_plot].size > 0 and data_for_extent[lat_name_plot].size > 0: try: min_lon = data_for_extent[lon_name_plot].min().item() max_lon = data_for_extent[lon_name_plot].max().item() min_lat = data_for_extent[lat_name_plot].min().item() max_lat = data_for_extent[lat_name_plot].max().item() if min_lon != max_lon and min_lat != max_lat : ax.set_extent([min_lon, max_lon, min_lat, max_lat], crs=ccrs.PlateCarree()) except Exception as e: print(f"Warning: Could not set extent for '{variable}': {e}") # Save the plot to a file if a path is provided. if save_plot_path: plt.savefig(save_plot_path, bbox_inches='tight', dpi=300); print(f"Plot saved to: {save_plot_path}") return ax def _plot_spatial_data(self, processed_data_to_plot, original_variable_name, original_selected_data_attrs, original_var_accessor_for_coords, data_season_for_time_info, time_coord_name_actual, time_range_requested, level_op, level_dim_name_in_orig_var, season, contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name, cbar_prefix="", time_format_type='day', index_specific_title_parts="", title=None, projection='PlateCarree'): """ A generic helper function for creating spatial plots. This function orchestrates the entire plotting process by calling other internal helpers. It sets up the map, generates the title, plots the data (as contours or filled contours), and finalizes the plot, serving as the core engine for all public plotting methods in this accessor. Parameters ---------- processed_data_to_plot : xr.DataArray The 2D data to plot. original_variable_name : str The name of the original variable. original_selected_data_attrs : dict Attributes of the original selected data variable. original_var_accessor_for_coords : xr.DataArray The original DataArray, used for coordinate and metadata access. data_season_for_time_info : xr.DataArray The seasonally filtered data, used to get time info for the title. time_coord_name_actual : str The actual name of the time coordinate. time_range_requested : slice The time range originally requested by the user. level_op : str The operation performed on the level dimension. level_dim_name_in_orig_var : str The name of the level dimension in the original variable. season : str The season string. contour : bool Use contour lines if True, otherwise use filled contours. figsize : tuple Figure size. cmap : str Colormap name. land_only : bool Mask oceans if True. levels : int Number of contour levels. save_plot_path : str or None Path to save the plot. plot_operation_name : str Name of the operation performed on the data (e.g., "Average"). cbar_prefix : str, optional A prefix for the colorbar label. Defaults to "". time_format_type : str, optional Time format for the title. Defaults to 'day'. index_specific_title_parts : str, optional Additional title parts for specific indices. Defaults to "". title : str, optional The title for the plot. If not provided, a descriptive title will be generated automatically. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. """ # Step 1: Set up the cartopy map axes. fig, ax = self._setup_geographical_ax(figsize, land_only, projection) # Step 2: Determine longitude and latitude coordinate names for plotting. lon_name = get_coord_name(processed_data_to_plot, ['lon', 'longitude', 'x', 'rlon']) lat_name = get_coord_name(processed_data_to_plot, ['lat', 'latitude', 'y', 'rlat']) if not lat_name or not lon_name: raise ValueError(f"Lat/Lon coordinates not found in processed data for '{original_variable_name}'.") # Step 3: Generate a descriptive title for the plot. if title is None: title = self._generate_title( base_operation_name=plot_operation_name, var_display_name=original_selected_data_attrs.get('long_name', original_variable_name.replace('_', ' ').capitalize()), season=season, level_op=level_op, level_dim_name_in_orig_var=level_dim_name_in_orig_var, processed_data_coords=processed_data_to_plot.coords, original_var_accessor=original_var_accessor_for_coords, time_info_provider_data=data_season_for_time_info, time_coord_name_actual=time_coord_name_actual, time_range_requested=time_range_requested, time_format_type=time_format_type, index_specific_title_parts=index_specific_title_parts ) # Step 4: Plot the data using either contour or contourf. plot_obj = None if contour: plot_obj = ax.contour( processed_data_to_plot[lon_name], processed_data_to_plot[lat_name], processed_data_to_plot, transform=ccrs.PlateCarree(), levels=levels ) else: plot_obj = ax.contourf( processed_data_to_plot[lon_name], processed_data_to_plot[lat_name], processed_data_to_plot, transform=ccrs.PlateCarree(), levels=levels ) # Step 5: Finalize the plot with a colorbar, title, and save if requested. cbar_label = f"{cbar_prefix}{original_selected_data_attrs.get('units', '')}".strip() self._finalize_plot(ax, plot_obj, title, cbar_label, processed_data_to_plot, lon_name, lat_name, save_plot_path, original_variable_name) return fig # -------------------------------------------------------------------------- # INTERNAL HELPER METHODS: CLIMATE INDEX CALCULATIONS # -------------------------------------------------------------------------- def _vectorized_consecutive_true_count(self, da, dim='time'): """ Calculate the length of consecutive `True` runs in a boolean DataArray. This is a vectorized operation that is much faster than looping. """ # Get cumulative sum of `da` along `dim`. This increments for each `True` # in a consecutive block. We reset the count when `da` is `False`. cumulative_sum = da.cumsum(dim=dim) # Where `da` is `False`, the consecutive count is 0. # We find the `cumulative_sum` just before each `False` block. # This value needs to be subtracted from the `cumulative_sum` in the next `True` block. reset_points = xr.where(da, 0, cumulative_sum).ffill(dim=dim) # Subtract the `reset_points` to get the length of each consecutive `True` run. consecutive_counts = cumulative_sum - reset_points return consecutive_counts def _apply_yearly_op_then_mean(self, data_for_yearly_op, time_coord_name, operation, op_kwargs=None, dask_op_name=""): """ Apply a yearly operation (e.g., sum, max) and then compute the mean over the years. This is a helper function used for climate indices like Rx1day. It first groups the data by year, applies an operation within each year, and then calculates the mean of these yearly results. Parameters ---------- data_for_yearly_op : xr.DataArray The input data array with a time dimension. time_coord_name : str The name of the time coordinate. operation : str The name of the operation to apply yearly (e.g., 'sum', 'max', 'mean'). op_kwargs : dict, optional Additional keyword arguments for the operation. dask_op_name : str, optional A display name for the operation when printing progress for Dask computations. Returns ------- xr.DataArray A DataArray containing the mean of the yearly operation results. """ if op_kwargs is None: op_kwargs = {} year_coord_da = data_for_yearly_op[time_coord_name].dt.year grouped_data = data_for_yearly_op.groupby(year_coord_da.rename("year_for_grouping")) if data_for_yearly_op.chunks: print(f"Computing yearly {dask_op_name or operation} for Dask...") with ProgressBar(): yearly_data = getattr(grouped_data, operation)(dim=time_coord_name, skipna=True, **op_kwargs).compute() print(f"Computing mean of yearly {dask_op_name or operation} for Dask...") with ProgressBar(): mean_yearly_data = yearly_data.mean(dim='year_for_grouping', skipna=True).compute() else: yearly_data = getattr(grouped_data, operation)(dim=time_coord_name, skipna=True, **op_kwargs) mean_yearly_data = yearly_data.mean(dim='year_for_grouping', skipna=True) return mean_yearly_data def _calc_spell_counts(self, data_in, time_coord_name, threshold_val, min_consecutive_days, spell_type_is_above_thresh): """ Calculate the average number of spells per year. Vectorized implementation. A "spell" is a period of consecutive days meeting a condition for at least a minimum number of days. """ condition = (data_in > threshold_val) if spell_type_is_above_thresh else (data_in < threshold_val) # Calculate the length of each consecutive run consecutive_lengths = self._vectorized_consecutive_true_count(condition, dim=time_coord_name) # A spell of required duration is "born" when its length first equals the minimum duration. # This counts each spell exactly once. spell_is_born = (consecutive_lengths == min_consecutive_days) # The number of spells per year is the sum of these "births" return self._apply_yearly_op_then_mean(spell_is_born.astype(int), time_coord_name, 'sum', dask_op_name="spell counts") def _calc_days_above_or_below_threshold_mean(self, data_in, time_coord_name, threshold_val, is_above_op): """ Calculate the mean annual number of days above or below a threshold. Helper function that counts days per year meeting a condition and then averages these counts over all years. Parameters ---------- data_in : xr.DataArray Input data with a time dimension. time_coord_name : str Name of the time coordinate. threshold_val : float or xr.DataArray The threshold value. is_above_op : bool If True, counts days *above* the threshold. If False, counts days *below*. Returns ------- xr.DataArray A DataArray with the mean annual number of days meeting the condition. """ condition_met = (data_in >= threshold_val) if is_above_op else (data_in < threshold_val) return self._apply_yearly_op_then_mean(condition_met.astype(int), time_coord_name, 'sum', dask_op_name="days matching condition") def _calc_max_consecutive_days(self, data_in, time_coord_name, threshold_val, spell_type_is_above_thresh): """ Calculate the mean of the annual maximum number of consecutive days above or below a threshold. Vectorized implementation. """ condition = (data_in >= threshold_val) if spell_type_is_above_thresh else (data_in < threshold_val) # Get the length of each consecutive run of True values consecutive_lengths = self._vectorized_consecutive_true_count(condition, dim=time_coord_name) # Group by year and find the maximum length within each year, then average the maxima return self._apply_yearly_op_then_mean(consecutive_lengths, time_coord_name, 'max', dask_op_name="max consecutive days") def _calc_days_in_spell(self, data_in, time_coord_name, threshold_val, min_consecutive_days, spell_type_is_above_thresh): """ Calculate the mean annual number of days in spells (e.g., WSDI). A "spell" is a period of consecutive days meeting a condition for at least a minimum number of days. This function counts the total number of days within such spells. """ condition = (data_in >= threshold_val) if spell_type_is_above_thresh else (data_in < threshold_val) # Calculate the length of each consecutive run up to the current point consecutive_lengths = self._vectorized_consecutive_true_count(condition, dim=time_coord_name) # Identify the end of each consecutive run of True values. # A run ends if the current value is True and the next is False. is_spell_end = (condition & ~condition.shift({time_coord_name: -1}, fill_value=False)) # Get the total length of each spell at the point where the spell ends. # Where it's not a spell end, this will be NaN. spell_end_lengths = consecutive_lengths.where(is_spell_end) # Back-fill the total spell length over the duration of each spell. # This propagates the final length of a spell to all days within that spell. total_spell_lengths = spell_end_lengths.bfill(dim=time_coord_name) # Mask out days that were not part of any spell to begin with. total_spell_lengths = total_spell_lengths.where(condition, 0) # Identify which days are part of a spell that meets the minimum duration requirement. is_in_qualifying_spell = (total_spell_lengths >= min_consecutive_days) # Sum the number of qualifying days per year and then average over the years. return self._apply_yearly_op_then_mean(is_in_qualifying_spell.astype(int), time_coord_name, 'sum', dask_op_name="days in spell") # ============================================================================== # PUBLIC PLOTTING METHODS # ============================================================================== # -------------------------------------------------------------------------- # A. Basic Statistical Plots # --------------------------------------------------------------------------
[docs] def plot_mean(self, variable='air', latitude=None, longitude=None, level=None, time_range=None, season='annual', contour=False, figsize=(16, 10), cmap='coolwarm', land_only=False, levels=30, save_plot_path=None, title=None, projection='PlateCarree'): """ Plot the temporal mean of a variable over a specified period. Calculates and plots the mean of a given variable over the specified time, space, and level dimensions. This is a fundamental plot for understanding the basic climate state. Parameters ---------- variable : str, optional Name of the variable to plot. Defaults to 'air'. latitude : float, slice, or list, optional Latitude range for selection. Can be a single value, a list of values, or a slice object (e.g., slice(30, 60)). longitude : float, slice, or list, optional Longitude range for selection. Can be a single value, a list, or a slice (e.g., slice(-120, -80)). level : float or slice, optional Vertical level for data selection. A single value selects the nearest level. A slice (e.g., slice(500, 200)) will result in the data being averaged over that level range before the temporal mean is computed. time_range : slice, optional Time range for selection, specified as a slice of datetime-like objects or strings (e.g., slice('2000-01-01', '2010-12-31')). season : str, optional Season to calculate the mean for. Supported options are 'annual', 'jjas', 'djf', 'mam', 'son', 'jja'. Defaults to 'annual'. contour : bool, optional If True, use contour lines instead of filled contours. Defaults to False. figsize : tuple, optional Figure size in inches (width, height). Defaults to (16, 10). cmap : str, optional Colormap name for the plot. Defaults to 'coolwarm'. land_only : bool, optional If True, mask out ocean areas, plotting data only over land. Defaults to False. levels : int, optional Number of contour levels for the plot. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. title : str, optional The title for the plot. If not provided, a descriptive title will be generated automatically. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot, allowing for further customization. See Also -------- plot_std_time : Plot the temporal standard deviation. plot_percentile_spatial : Plot a specific temporal percentile. Examples -------- >>> import xarray as xr >>> import climate_diagnostics >>> ds = xr.tutorial.load_dataset("air_temperature") >>> ds.climate_plots.plot_mean( ... variable='air', ... level=850, ... time_range=slice('2013-01', '2013-12'), ... season='djf' ... ) """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) # Step 2: If a level range was selected, average over it first current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging over selected levels for '{variable}'.") # Step 3: Apply seasonal filter data_season = filter_by_season(current_data_for_ops, season) if data_season.size == 0: raise ValueError(f"No data after selections and season filter ('{season}') for '{variable}'.") # Step 4: Calculate the temporal mean time_coord_name_actual = get_coord_name(data_season, ['time', 't']) mean_data = data_season if time_coord_name_actual and time_coord_name_actual in data_season.dims: if data_season.chunks: print(f"Computing time mean for '{variable}' using Dask...") with ProgressBar(): mean_data = data_season.mean(dim=time_coord_name_actual, skipna=True).compute() else: mean_data = data_season.mean(dim=time_coord_name_actual, skipna=True) elif time_coord_name_actual: print(f"Warning: Time coord '{time_coord_name_actual}' not a dimension for averaging. Plotting as is.") else: print(f"Warning: No time coord for averaging. Plotting as is.") # Step 5: Pass to the generic spatial plotting function return self._plot_spatial_data( mean_data, variable, selected_data.attrs, self._obj[variable], data_season, time_coord_name_actual, time_range, level_op, level_dim_name, season, contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Average", title=title, projection=projection )
[docs] def plot_std_time(self, variable='air', latitude=None, longitude=None, level=None, time_range=None, season='annual', contour=False, figsize=(16,10), cmap='viridis', land_only = False, levels=30, save_plot_path = None, title=None, projection='PlateCarree'): """ Plot the temporal standard deviation of a variable. Calculates and plots the standard deviation of a given variable over time, which is a key measure of climate variability. Parameters ---------- variable : str, optional Name of the variable to plot. Defaults to 'air'. latitude : float, slice, or list, optional Latitude range for selection. Can be a single value, a list of values, or a slice object (e.g., slice(30, 60)). longitude : float, slice, or list, optional Longitude range for selection. Can be a single value, a list, or a slice (e.g., slice(-120, -80)). level : float or slice, optional Vertical level for data selection. A single value selects the nearest level. A slice (e.g., slice(500, 200)) will result in the data being averaged over that level range before the standard deviation is computed. time_range : slice, optional Time range for selection, specified as a slice of datetime-like objects or strings (e.g., slice('2000-01-01', '2010-12-31')). season : str, optional Season to calculate the standard deviation for. Supported options are 'annual', 'jjas', 'djf', 'mam', 'son', 'jja'. Defaults to 'annual'. contour : bool, optional If True, use contour lines instead of filled contours. Defaults to False. figsize : tuple, optional Figure size in inches (width, height). Defaults to (16, 10). cmap : str, optional Colormap name for the plot. Defaults to 'viridis'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. title : str, optional The title for the plot. If not provided, a descriptive title will be generated automatically. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_mean : Plot the temporal mean. """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) # Step 2: If a level range was selected, average over it first current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging across selected levels for '{variable}' before calculating std dev.") # Step 3: Apply seasonal filter time_coord_name_actual = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name_actual or time_coord_name_actual not in current_data_for_ops.dims: raise ValueError(f"Std dev requires time dimension for '{variable}'.") data_season = filter_by_season(current_data_for_ops, season) if data_season.size == 0: raise ValueError(f"No data after selections and season filter ('{season}') for '{variable}'.") if data_season.sizes[time_coord_name_actual] < 2: raise ValueError(f"Std dev requires at least 2 time points (found {data_season.sizes[time_coord_name_actual]}).") # Step 4: Calculate the temporal standard deviation if data_season.chunks: print(f"Computing std dev over time for '{variable}' using Dask...") with ProgressBar(): std_data = data_season.std(dim=time_coord_name_actual, skipna=True).compute() else: std_data = data_season.std(dim=time_coord_name_actual, skipna=True) # Step 5: Pass to the generic spatial plotting function return self._plot_spatial_data( std_data, variable, selected_data.attrs, self._obj[variable], data_season, time_coord_name_actual, time_range, level_op, level_dim_name, season, contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Temporal Standard Deviation", cbar_prefix="Std. Dev. of ", title=title, projection=projection )
[docs] def plot_percentile_spatial(self, variable='prate', percentile=95, latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='Blues', land_only=False, levels=30, save_plot_path=None, title=None, projection='PlateCarree'): """ Plot the spatial distribution of a temporal percentile for a variable. Calculates a given percentile (e.g., 95th) at each grid point over the time dimension and plots the resulting map. This is useful for identifying areas with extreme values. Parameters ---------- variable : str, optional Name of the variable. Defaults to 'prate'. percentile : int, optional The percentile to calculate (0-100). Defaults to 95. latitude : float, slice, or list, optional Latitude range for selection. Can be a single value, a list of values, or a slice object. longitude : float, slice, or list, optional Longitude range for selection. Can be a single value, a list, or a slice. level : float or slice, optional Vertical level for data selection. A single value selects the nearest level. A slice will result in the data being averaged over that range. time_range : slice, optional Time range for selection as a slice of datetime-like objects or strings. contour : bool, optional If True, use contour lines instead of filled contours. Defaults to False. figsize : tuple, optional Figure size in inches (width, height). Defaults to (16, 10). cmap : str, optional Colormap name for the plot. Defaults to 'Blues'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels for the plot. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. title : str, optional The title for the plot. If not provided, a descriptive title will be generated automatically. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_mean : Plot the temporal mean of a variable. """ get_or_create_dask_client() # Step 1: Validate input if not 0 <= percentile <= 100: raise ValueError(f"Percentile must be 0-100, got {percentile}") # Step 2: Select the data and handle level-based averaging selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging across selected levels for '{variable}' before calculating percentile.") # Step 3: Calculate the percentile time_coord_name_actual = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name_actual or time_coord_name_actual not in current_data_for_ops.dims: raise ValueError(f"Percentile calculation requires a time dimension for '{variable}'.") if current_data_for_ops.chunks: print(f"Computing {percentile}th percentile for '{variable}' using Dask...") with ProgressBar(): percentile_data = current_data_for_ops.quantile(percentile / 100.0, dim=time_coord_name_actual, skipna=True).compute() else: percentile_data = current_data_for_ops.quantile(percentile / 100.0, dim=time_coord_name_actual, skipna=True) # Step 4: Pass to the generic spatial plotting function return self._plot_spatial_data( percentile_data, variable, selected_data.attrs, self._obj[variable], current_data_for_ops, time_coord_name_actual, time_range, level_op, level_dim_name, 'annual', contour, # Percentiles are season-agnostic figsize, cmap, land_only, levels, save_plot_path, plot_operation_name=f"{percentile}th Percentile", title=title, projection=projection )
# -------------------------------------------------------------------------- # B. Precipitation and Climate Indices (ETCCDI-style) # --------------------------------------------------------------------------
[docs] def plot_annual_sum_mean(self, variable='prate', latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='Blues', land_only=False, levels=30, save_plot_path=None, projection='PlateCarree'): """ Plot the mean of the annual total precipitation (PRCPTOT index). This function calculates the total precipitation for each year and then computes the mean of these annual totals. It is useful for visualizing changes in total precipitation over time. Parameters ---------- variable : str, optional Name of the variable. Defaults to 'prate'. latitude : float, slice, or list, optional Latitude range for selection. longitude : float, slice, or list, optional Longitude range for selection. level : float or slice, optional Vertical level for selection. If a slice is given, data is averaged over the level range. time_range : slice, optional Time range for selection. contour : bool, optional Use contour lines if True. Defaults to False. figsize : tuple, optional Figure size. Defaults to (16, 10). cmap : str, optional Colormap. Defaults to 'Blues'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_max_1day_precip_mean : Plot the mean annual maximum 1-day precipitation. """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging mean annual sum across selected levels for '{variable}'.") # Step 2: Calculate the mean of annual sums time_coord_name = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name or time_coord_name not in current_data_for_ops.dims: raise ValueError(f"Annual sum mean requires time dimension for '{variable}'.") mean_annual_sum = self._apply_yearly_op_then_mean(current_data_for_ops, time_coord_name, 'sum', dask_op_name="sums") # Step 3: Pass to the generic spatial plotting function return self._plot_spatial_data( mean_annual_sum, variable, selected_data.attrs, self._obj[variable], selected_data, time_coord_name, time_range, level_op, level_dim_name, 'annual', contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Mean of Annual Total", cbar_prefix="Mean Annual ", projection=projection )
[docs] def plot_max_1day_precip_mean(self, variable='prate', latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='viridis', land_only=False, levels=30, save_plot_path=None, projection='PlateCarree'): """ Plot the mean of the annual maximum 1-day precipitation (Rx1day index). This function finds the highest precipitation amount in a single day for each year, averages these maxima, and plots the result. It is useful for analyzing changes in extreme precipitation events. Parameters ---------- variable : str, optional Name of the variable. Defaults to 'prate'. latitude : float, slice, or list, optional Latitude range for selection. longitude : float, slice, or list, optional Longitude range for selection. level : float or slice, optional Vertical level for selection. If a slice is given, data is averaged over the level range. time_range : slice, optional Time range for selection. contour : bool, optional Use contour lines if True. Defaults to False. figsize : tuple, optional Figure size. Defaults to (16, 10). cmap : str, optional Colormap. Defaults to 'viridis'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_annual_sum_mean : Plot the mean annual total precipitation. """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging Rx1day across selected levels for '{variable}'.") # Step 2: Calculate the mean of annual maxima time_coord_name = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name or time_coord_name not in current_data_for_ops.dims: raise ValueError(f"Rx1day requires time dimension for '{variable}'.") mean_rx1day = self._apply_yearly_op_then_mean(current_data_for_ops, time_coord_name, 'max', dask_op_name="maxima") # Step 3: Pass to the generic spatial plotting function return self._plot_spatial_data( mean_rx1day, variable, selected_data.attrs, self._obj[variable], selected_data, time_coord_name, time_range, level_op, level_dim_name, 'annual', contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Mean of Annual Max 1-day", cbar_prefix="Mean Max 1-day ", projection=projection )
[docs] def plot_simple_daily_intensity_mean(self, variable='prate', latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='YlGnBu', land_only=False, levels=30, save_plot_path=None, projection='PlateCarree'): """ Plot the mean Simple Daily Intensity Index (SDII). SDII is the total annual precipitation divided by the number of wet days (days with precipitation above 1 mm). This index provides insight into changes in precipitation patterns and intensity. Parameters ---------- variable : str, optional Name of the variable. Defaults to 'prate'. latitude : float, slice, or list, optional Latitude range for selection. longitude : float, slice, or list, optional Longitude range for selection. level : float or slice, optional Vertical level for selection. If a slice is given, data is averaged over the level range. time_range : slice, optional Time range for selection. contour : bool, optional Use contour lines if True. Defaults to False. figsize : tuple, optional Figure size. Defaults to (16, 10). cmap : str, optional Colormap. Defaults to 'YlGnBu'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_days_above_threshold_mean : Plot the mean annual number of days above a temperature threshold. """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging SDII across selected levels for '{variable}'.") # Step 2: Calculate the SDII time_coord_name = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name or time_coord_name not in current_data_for_ops.dims: raise ValueError(f"SDII calculation requires time dimension for '{variable}'.") # Count wet days (above 1 mm) wet_days_count = (current_data_for_ops > 1e-5).astype(int) if wet_days_count.chunks: print(f"Computing annual wet day count for SDII using Dask...") with ProgressBar(): annual_wet_days = self._apply_yearly_op_then_mean(wet_days_count, time_coord_name, 'sum', dask_op_name="wet days").compute() with ProgressBar(): total_precipitation = self._apply_yearly_op_then_mean(current_data_for_ops, time_coord_name, 'sum', dask_op_name="total precip").compute() else: annual_wet_days = self._apply_yearly_op_then_mean(wet_days_count, time_coord_name, 'sum', dask_op_name="wet days") total_precipitation = self._apply_yearly_op_then_mean(current_data_for_ops, time_coord_name, 'sum', dask_op_name="total precip") sdii = total_precipitation / annual_wet_days.where(annual_wet_days > 0, np.nan) # Step 3: Pass to the generic spatial plotting function return self._plot_spatial_data( sdii, variable, selected_data.attrs, self._obj[variable], selected_data, time_coord_name, time_range, level_op, level_dim_name, 'annual', contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Simple Daily Intensity Index (SDII)", cbar_prefix="Mean ", projection=projection )
[docs] def plot_days_above_threshold_mean(self, variable='tasmax', threshold=25, latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='Reds', land_only=False, levels=30, save_plot_path=None, projection='PlateCarree'): """ Plot the mean annual number of days where a variable is above a threshold. For temperature, this can represent "summer days" (e.g., tasmax > 25°C). Parameters ---------- variable : str, optional Name of the variable. Defaults to 'tasmax'. threshold : float, optional Threshold value. Defaults to 25. latitude : float, slice, or list, optional Latitude range for selection. longitude : float, slice, or list, optional Longitude range for selection. level : float or slice, optional Vertical level for selection. If a slice is given, data is averaged over the level range. time_range : slice, optional Time range for selection. contour : bool, optional Use contour lines if True. Defaults to False. figsize : tuple, optional Figure size. Defaults to (16, 10). cmap : str, optional Colormap. Defaults to 'Reds'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_consecutive_dry_days_max_mean : Plot the mean annual maximum number of consecutive dry days. """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging across selected levels for '{variable}' before calculating days above threshold.") # Step 2: Calculate the mean annual number of days above the threshold time_coord_name = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name or time_coord_name not in current_data_for_ops.dims: raise ValueError(f"Days above threshold calculation requires time dimension for '{variable}'.") # Count days above threshold days_above_threshold = (current_data_for_ops > threshold).astype(int) if days_above_threshold.chunks: print(f"Computing annual days above threshold for Dask...") with ProgressBar(): mean_days_above = self._apply_yearly_op_then_mean(days_above_threshold, time_coord_name, 'sum', dask_op_name="days above threshold").compute() else: mean_days_above = self._apply_yearly_op_then_mean(days_above_threshold, time_coord_name, 'sum', dask_op_name="days above threshold") # Step 3: Pass to the generic spatial plotting function units_str = f" ({variable_units})" if (variable_units := selected_data.attrs.get('units')) else "" return self._plot_spatial_data( mean_days_above, variable, selected_data.attrs, self._obj[variable], selected_data, time_coord_name, time_range, level_op, level_dim_name, 'annual', contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name=f"Mean Annual Days > {threshold}{units_str}", cbar_prefix="Days > Threshold ", projection=projection )
[docs] def plot_consecutive_dry_days_max_mean(self, variable='prate', threshold=1, latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='YlOrBr', land_only=False, levels=30, save_plot_path=None, projection='PlateCarree'): """ Plot the mean of the annual maximum number of consecutive dry days (CDD). A "dry day" is defined as a day with precipitation below a certain threshold. This index is useful for identifying changes in dry spell patterns and durations. Parameters ---------- variable : str, optional Name of the variable. Defaults to 'prate'. threshold : float, optional Threshold value. Defaults to 1 mm/day (converted to appropriate units). latitude : float, slice, or list, optional Latitude range for selection. longitude : float, slice, or list, optional Longitude range for selection. level : float or slice, optional Vertical level for selection. If a slice is given, data is averaged over the level range. time_range : slice, optional Time range for selection. contour : bool, optional Use contour lines if True. Defaults to False. figsize : tuple, optional Figure size. Defaults to (16, 10). cmap : str, optional Colormap. Defaults to 'YlOrBr'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_consecutive_wet_days : Plot the mean annual maximum number of consecutive wet days. """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging CDD across selected levels for '{variable}'.") # Step 2: Calculate the mean of the annual maximum consecutive dry days time_coord_name = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name or time_coord_name not in current_data_for_ops.dims: raise ValueError(f"CDD calculation requires time dimension for '{variable}'.") mean_cdd = self._apply_yearly_op_then_mean(current_data_for_ops, time_coord_name, 'max', dask_op_name="maxima") # Step 3: Pass to the generic spatial plotting function return self._plot_spatial_data( mean_cdd, variable, selected_data.attrs, self._obj[variable], selected_data, time_coord_name, time_range, level_op, level_dim_name, "Annual", contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Max Consecutive Dry Days (CDD)", cbar_prefix="Mean Max ", projection=projection )
[docs] def plot_warm_spell_duration_mean(self, variable='tasmax', threshold=25, min_consecutive_days=6, latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='Oranges', land_only=False, levels=30, save_plot_path=None, projection='PlateCarree'): """ Plot the mean Warm Spell Duration Index (WSDI). WSDI is the total number of days per year that are part of a "warm spell". A warm spell is defined as a period of at least `min_consecutive_days` where the temperature is above a certain threshold. Parameters ---------- variable : str, optional Name of the temperature variable. Defaults to 'tasmax'. threshold : float, optional Temperature threshold. Defaults to 25°C. min_consecutive_days : int, optional Minimum number of consecutive days to qualify as a warm spell. Defaults to 6. latitude : float, slice, or list, optional Latitude range for selection. longitude : float, slice, or list, optional Longitude range for selection. level : float or slice, optional Vertical level for selection. If a slice is given, data is averaged over the level range. time_range : slice, optional Time range for selection. contour : bool, optional Use contour lines if True. Defaults to False. figsize : tuple, optional Figure size. Defaults to (16, 10). cmap : str, optional Colormap. Defaults to 'Oranges'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_cold_spell_duration_mean : Plot the mean Cold Spell Duration Index (CSDI). """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging WSDI across selected levels for '{variable}'.") # Step 2: Calculate the mean annual number of days in warm spells time_coord_name = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name or time_coord_name not in current_data_for_ops.dims: raise ValueError(f"WSDI calculation requires time dimension for '{variable}'.") # Identify warm spells warm_spell_condition = (current_data_for_ops > threshold) if warm_spell_condition.chunks: print(f"Computing annual warm spell count for Dask...") with ProgressBar(): annual_warm_spells = self._apply_yearly_op_then_mean(warm_spell_condition.astype(int), time_coord_name, 'sum', dask_op_name="warm spells").compute() else: annual_warm_spells = self._apply_yearly_op_then_mean(warm_spell_condition.astype(int), time_coord_name, 'sum', dask_op_name="warm spells") # Step 3: Pass to the generic spatial plotting function return self._plot_spatial_data( annual_warm_spells, variable, selected_data.attrs, self._obj[variable], selected_data, time_coord_name, time_range, level_op, level_dim_name, "Annual", contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Warm Spell Duration Index (WSDI)", cbar_prefix="Mean ", projection=projection )
[docs] def plot_cold_spell_duration_mean(self, variable='tasmin', threshold=0, min_consecutive_days=6, latitude=None, longitude=None, level=None, time_range=None, contour=False, figsize=(16, 10), cmap='Blues', land_only=False, levels=30, save_plot_path=None, projection='PlateCarree'): """ Plot the mean Cold Spell Duration Index (CSDI). CSDI is the total number of days per year that are part of a "cold spell". A cold spell is defined as a period of at least `min_consecutive_days` where the temperature is below a certain threshold. Parameters ---------- variable : str, optional Name of the temperature variable. Defaults to 'tasmin'. threshold : float, optional Temperature threshold. Defaults to 0°C. min_consecutive_days : int, optional Minimum number of consecutive days to qualify as a cold spell. Defaults to 6. latitude : float, slice, or list, optional Latitude range for selection. longitude : float, slice, or list, optional Longitude range for selection. level : float or slice, optional Vertical level for selection. If a slice is given, data is averaged over the level range. time_range : slice, optional Time range for selection. contour : bool, optional Use contour lines if True. Defaults to False. figsize : tuple, optional Figure size. Defaults to (16, 10). cmap : str, optional Colormap. Defaults to 'Blues'. land_only : bool, optional If True, mask out ocean areas. Defaults to False. levels : int, optional Number of contour levels. Defaults to 30. save_plot_path : str or None, optional If provided, the path to save the plot figure to. projection : str, optional The name of the cartopy projection to use. Defaults to 'PlateCarree'. Returns ------- cartopy.mpl.geoaxes.GeoAxes The Axes object of the plot. See Also -------- plot_warm_spell_duration_mean : Plot the mean Warm Spell Duration Index (WSDI). """ get_or_create_dask_client() # Step 1: Select the data based on user parameters selected_data, level_dim_name, level_op = self._select_data( variable, latitude, longitude, level, time_range ) current_data_for_ops = selected_data if level_op == 'range_selected' and level_dim_name and level_dim_name in current_data_for_ops.dims: current_data_for_ops = current_data_for_ops.mean(dim=level_dim_name, skipna=True) print(f"Averaging CSDI across selected levels for '{variable}'.") # Step 2: Calculate the mean annual number of days in cold spells time_coord_name = get_coord_name(current_data_for_ops, ['time', 't']) if not time_coord_name or time_coord_name not in current_data_for_ops.dims: raise ValueError(f"CSDI calculation requires time dimension for '{variable}'.") # Identify cold spells cold_spell_condition = (current_data_for_ops < threshold) if cold_spell_condition.chunks: print(f"Computing annual cold spell count for Dask...") with ProgressBar(): annual_cold_spells = self._apply_yearly_op_then_mean(cold_spell_condition.astype(int), time_coord_name, 'sum', dask_op_name="cold spells").compute() else: annual_cold_spells = self._apply_yearly_op_then_mean(cold_spell_condition.astype(int), time_coord_name, 'sum', dask_op_name="cold spells") # Step 3: Pass to the generic spatial plotting function return self._plot_spatial_data( annual_cold_spells, variable, selected_data.attrs, self._obj[variable], selected_data, time_coord_name, time_range, level_op, level_dim_name, "Annual", contour, figsize, cmap, land_only, levels, save_plot_path, plot_operation_name="Cold Spell Duration Index (CSDI)", cbar_prefix="Mean ", projection=projection )