diff --git a/src/mintpy/cli/plot_coherence_matrix.py b/src/mintpy/cli/plot_coherence_matrix.py index b7e6f9fda..9cbfc1741 100755 --- a/src/mintpy/cli/plot_coherence_matrix.py +++ b/src/mintpy/cli/plot_coherence_matrix.py @@ -61,6 +61,9 @@ def create_parser(subparsers=None): parser.add_argument('-t','--template', dest='template_file', help='temporal file.') + parser.add_argument('--time-axis', dest='time_axis', action='store_true', + help='Use continuous time axis instead of date indices for coherence matrix') + parser.add_argument('--save', dest='save_fig', action='store_true', help='save the figure') parser.add_argument('--nodisplay', dest='disp_fig', diff --git a/src/mintpy/ifgram_inversion.py b/src/mintpy/ifgram_inversion.py index 87aef8409..a05429dc4 100644 --- a/src/mintpy/ifgram_inversion.py +++ b/src/mintpy/ifgram_inversion.py @@ -841,7 +841,7 @@ def run_ifgram_inversion_patch(ifgram_file, box=None, ref_phase=None, obs_ds_nam # save result to output matrices ts[:, idx] = tsi.flatten() - inv_quality[idx] = inv_quali + inv_quality[idx] = np.atleast_1d(inv_quali)[0] num_inv_obs[idx] = num_obsi prog_bar.update(i+1, every=200, suffix=f'{i+1}/{num_pixel2inv_part} pixels') @@ -862,7 +862,7 @@ def run_ifgram_inversion_patch(ifgram_file, box=None, ref_phase=None, obs_ds_nam # save result to output matrices ts[:, idx] = tsi.flatten() - inv_quality[idx] = inv_quali + inv_quality[idx] = np.atleast_1d(inv_quali)[0] num_inv_obs[idx] = num_obsi prog_bar.update(i+1, every=200, suffix=f'{i+1}/{num_pixel2inv} pixels') diff --git a/src/mintpy/plot_coherence_matrix.py b/src/mintpy/plot_coherence_matrix.py index 4e2386269..d427c9cff 100644 --- a/src/mintpy/plot_coherence_matrix.py +++ b/src/mintpy/plot_coherence_matrix.py @@ -62,12 +62,20 @@ class coherenceMatrixViewer(): def __init__(self, inps): # figure variables - self.figname = 'Coherence matrix' - self.fig_size = None - self.fig = None + self.figname_img = 'Image' + self.figsize_img = None + self.fig_img = None self.ax_img = None + self.cbar_img = None + self.img = None + + self.figname_mat = 'Coherence Matrix' + self.figsize_mat = None + self.fig_mat = None self.ax_mat = None + self.time_axis = getattr(inps, 'time_axis', False) + # copy inps to self object for key, value in inps.__dict__.items(): setattr(self, key, value) @@ -89,11 +97,31 @@ def open(self): self = read_network_info(self) # auto figure size - if not self.fig_size: + if not self.figsize_img: ds_shape = readfile.read(self.img_file)[0].shape - fig_size = pp.auto_figure_size(ds_shape, disp_cbar=True, scale=0.7) - self.fig_size = [fig_size[0]+fig_size[1], fig_size[1]] - vprint(f'create figure in size of {self.fig_size} inches') + self.figsize_img = pp.auto_figure_size(ds_shape, disp_cbar=True, scale=0.7) + vprint(f'create image figure in size of {self.figsize_img} inches') + + if not self.figsize_mat: + num_ifg = len(self.date12_list) + if num_ifg <= 50: + self.figsize_mat = [6, 5] + elif num_ifg <= 100: + self.figsize_mat = [8, 6] + else: + self.figsize_mat = [10, 8] + vprint(f'create matrix figure in size of {self.figsize_mat} inches') + + if not hasattr(self, 'cmap_name'): + # Default colormap: use 'RdBu_truncate' for both timeaxis and normal mode (from CLI default) + # This matches the CLI default value + if self.time_axis: + self.cmap_name = 'RdBu_truncate' + else: + self.cmap_name = 'viridis' + if not hasattr(self, 'cmap_vlist'): + self.cmap_vlist = [0.0, 1.0] + self.colormap = pp.ColormapExt(self.cmap_name, vlist=self.cmap_vlist).colormap # read aux data # 1. temporal coherence value @@ -111,11 +139,27 @@ def open(self): def plot(self): - # Figure 1 - self.fig = plt.figure(self.figname, figsize=self.fig_size) - # Axes 1 - Image - self.ax_img = self.fig.add_axes([0.05, 0.1, 0.4, 0.8]) + # Figure 1 - Image + self.fig_img, self.ax_img = plt.subplots(num=self.figname_img, figsize=self.figsize_img) + self.plot_init_image() + + # Figure 2 - Coherence Matrix + self.fig_mat, self.ax_mat = plt.subplots(num=self.figname_mat, figsize=self.figsize_mat) + if all(i is not None for i in self.yx): + self.plot_coherence_matrix4pixel(self.yx) + + # Link the canvas to the plots. + self.cid_img = self.fig_img.canvas.mpl_connect('button_press_event', self.update_coherence_matrix) + self.cid_mat = self.fig_mat.canvas.mpl_connect('button_press_event', self.update_coherence_matrix) + + if self.disp_fig: + plt.show() + return + + def plot_init_image(self): + """Plot the initial image""" + view_cmd = self.view_cmd.format(self.img_file) d_img, atr, view_inps = view.prep_slice(view_cmd) self.coord = ut.coordinate(atr) @@ -137,22 +181,78 @@ def plot(self): self.ax_img = view.plot_slice(self.ax_img, d_img, atr, view_inps)[0] self.fig_coord = view_inps.fig_coord - # Axes 2 - coherence matrix - self.ax_mat = self.fig.add_axes([0.55, 0.125, 0.40, 0.75]) - self.colormap = pp.ColormapExt(self.cmap_name, vlist=self.cmap_vlist).colormap - if all(i is not None for i in self.yx): - self.plot_coherence_matrix4pixel(self.yx) - # Link the canvas to the plots. - self.cid = self.fig.canvas.mpl_connect('button_press_event', self.update_coherence_matrix) - if self.disp_fig: - plt.show() + self.fig_img.canvas.manager.set_window_title(self.figname_img) + self.fig_img.tight_layout() + + def plot_coherence_matrix4pixel_time_axis(self, yx): + """Plot coherence matrix with continuous time axis for one pixel + Parameters: yx : list of 2 int + """ + self.ax_mat.cla() + + # read coherence + box = (yx[1], yx[0], yx[1]+1, yx[0]+1) + coh = readfile.read(self.ifgram_file, datasetName='coherence', box=box)[0] + + # ex_date for pixel-wise masking during network inversion + ex_date12_list = self.ex_date12_list[:] #local copy + if self.min_coh_used > 0.: + ex_date12_list += np.array(self.date12_list)[coh < self.min_coh_used].tolist() + ex_date12_list = sorted(list(set(ex_date12_list))) + + # prep metadata + plotDict = {} + plotDict['fig_title'] = f'Y = {yx[0]}, X = {yx[1]}' + # display temporal coherence value of the pixel + if self.tcoh_file: + tcoh = self.tcoh[yx[0], yx[1]] + plotDict['fig_title'] += f', tcoh = {tcoh:.2f}' + plotDict['colormap'] = self.colormap + # cmap_vlist is [start, jump, end] for truncated colormap, but vlim needs [vmin, vmax] + if len(self.cmap_vlist) >= 2: + plotDict['vlim'] = [self.cmap_vlist[0], self.cmap_vlist[-1]] + else: + plotDict['vlim'] = [0.0, 1.0] + plotDict['cbar_label'] = 'Coherence' + plotDict['disp_legend'] = False + + # plot using the utility function + _, _ = pp.plot_coherence_matrix_time_axis( + self.ax_mat, + date12List=self.date12_list, + cohList=coh.tolist(), + date12List_drop=ex_date12_list, + p_dict=plotDict, + )[1:3] + + # Info + msg = f'pixel in yx = {tuple(yx)}, ' + msg += f'min/max spatial coherence: {np.nanmin(coh):.2f} / {np.nanmax(coh):.2f}, ' + if self.tcoh_file: + tcoh = self.tcoh[yx[0], yx[1]] + msg += f'temporal coherence: {tcoh:.2f}' + vprint(msg) + + self.ax_mat.annotate('ifgrams\navailable', xy=(0.05, 0.05), xycoords='axes fraction', fontsize=12) + self.ax_mat.annotate('ifgrams\nused', ha='right', xy=(0.95, 0.85), xycoords='axes fraction', fontsize=12) + + self.fig_mat.canvas.manager.set_window_title(self.figname_mat) + self.fig_mat.tight_layout() + + # Update figure + self.fig_mat.canvas.draw_idle() + self.fig_mat.canvas.flush_events() return def plot_coherence_matrix4pixel(self, yx): """Plot coherence matrix for one pixel Parameters: yx : list of 2 int """ + # Use time axis mode if enabled + if self.time_axis: + return self.plot_coherence_matrix4pixel_time_axis(yx) + self.ax_mat.cla() # read coherence @@ -203,12 +303,16 @@ def format_coord(x, y): msg += f'temporal coherence: {tcoh:.2f}' vprint(msg) + self.fig_mat.canvas.manager.set_window_title(self.figname_mat) + self.fig_mat.tight_layout() + # update figure - self.fig.canvas.draw_idle() - self.fig.canvas.flush_events() + self.fig_mat.canvas.draw_idle() + self.fig_mat.canvas.flush_events() return def update_coherence_matrix(self, event): + """Update coherence matrix when clicking on either window""" if event.inaxes == self.ax_img: if self.fig_coord == 'geo': yx = self.coord.lalo2yx(event.ydata, event.xdata) @@ -216,3 +320,17 @@ def update_coherence_matrix(self, event): yx = [int(event.ydata+0.5), int(event.xdata+0.5)] self.plot_coherence_matrix4pixel(yx) + + self.update_image_marker(yx) + elif event.inaxes == self.ax_mat: + pass + + def update_image_marker(self, yx): + """Update the marker point in the image window""" + if hasattr(self, 'pts_yx'): + for artist in self.ax_img.get_children(): + if hasattr(artist, 'get_marker') and artist.get_marker() == '^': + artist.remove() + + self.ax_img.plot(yx[1], yx[0], 'r^', markersize=10, markeredgecolor='black') + self.fig_img.canvas.draw_idle() diff --git a/src/mintpy/plot_network.py b/src/mintpy/plot_network.py index 31911f7bb..9076f7c4b 100644 --- a/src/mintpy/plot_network.py +++ b/src/mintpy/plot_network.py @@ -173,7 +173,7 @@ def plot_network(inps): # figure names ext = 'Ion.pdf' if os.path.basename(inps.file).startswith('ion') else '.pdf' fig_names = { - 'coherence' : [i+ext for i in ['pbaseHistory', 'coherenceHistory', 'coherenceMatrix', 'network']], + 'coherence' : [i+ext for i in ['pbaseHistory', 'coherenceHistory', 'coherenceMatrix', 'coherenceMatrixTimeAxis', 'network']], 'offsetSNR' : [i+ext for i in ['pbaseHistory', 'SNRHistory', 'SNRMatrix', 'network']], 'tbase' : [i+ext for i in ['pbaseHistory', 'tbaseHistory', 'tbaseMatrix', 'network']], 'pbase' : [i+ext for i in ['pbaseHistory', 'pbaseRangeHistory', 'pbaseMatrix', 'network']], @@ -204,7 +204,7 @@ def plot_network(inps): ) if inps.save_fig: fig.savefig(fig_names[1], **kwargs) - print(f'save figure to {fig_names[2]}') + print(f'save figure to {fig_names[1]}') # Fig 3 - Coherence Matrix fig_size3 = np.mean(inps.fig_size) @@ -218,9 +218,24 @@ def plot_network(inps): )[0] if inps.save_fig: fig.savefig(fig_names[2], **kwargs) - print(f'save figure to {fig_names[1]}') + print(f'save figure to {fig_names[2]}') + + # Fig 4 - Coherence Matrix with Time Axis + fig_size4 = np.mean(inps.fig_size) + fig, ax = plt.subplots(figsize=[fig_size4, fig_size4]) + ax = pp.plot_coherence_matrix_time_axis( + ax, + inps.date12List, + inps.cohList, + inps.date12List_drop, + p_dict=vars(inps), + )[0] + if inps.save_fig: + fig.savefig(fig_names[3], **kwargs) + print(f'save figure to {fig_names[3]}') - # Fig 4 - Interferogram Network + # Fig 5 - Interferogram Network (or Fig 4 if cohList is None) + fig_idx = 4 if inps.cohList is not None else 3 fig, ax = plt.subplots(figsize=inps.fig_size) ax = pp.plot_network( ax, @@ -231,8 +246,8 @@ def plot_network(inps): inps.date12List_drop, ) if inps.save_fig: - fig.savefig(fig_names[3], **kwargs) - print(f'save figure to {fig_names[3]}') + fig.savefig(fig_names[fig_idx], **kwargs) + print(f'save figure to {fig_names[fig_idx]}') if inps.disp_fig: print('showing ...') diff --git a/src/mintpy/utils/plot.py b/src/mintpy/utils/plot.py index 4c647482c..6a5637371 100644 --- a/src/mintpy/utils/plot.py +++ b/src/mintpy/utils/plot.py @@ -11,6 +11,7 @@ import datetime as dt import os import warnings +from datetime import datetime, timedelta import h5py import matplotlib as mpl @@ -959,6 +960,376 @@ def plot_coherence_matrix(ax, date12List, cohList, date12List_drop=[], p_dict={} return ax, coh_mat, im +def plot_coherence_matrix_time_axis(ax, date12List, cohList, date12List_drop=[], p_dict={}): + """Plot Coherence Matrix with continuous time axis + Parameters: ax : matplotlib.pyplot.Axes, + date12List : list of date12 in YYYYMMDD_YYYYMMDD format + cohList : list of float, coherence value + date12List_drop : list of date12 for date12 marked as dropped + p_dict : dict of plot setting + Returns: ax : matplotlib.pyplot.Axes + Z : 2D np.array, coherence value matrix in time grid + mesh : matplotlib.collections.QuadMesh object + """ + # Figure Setting + if 'ds_name' not in p_dict.keys(): p_dict['ds_name'] = 'Coherence' + if 'fontsize' not in p_dict.keys(): p_dict['fontsize'] = 12 + if 'disp_title' not in p_dict.keys(): p_dict['disp_title'] = True + if 'fig_title' not in p_dict.keys(): p_dict['fig_title'] = '{} Matrix'.format(p_dict['ds_name']) + if 'colormap' not in p_dict.keys(): p_dict['colormap'] = 'RdBu_truncate' + if 'cbar_label' not in p_dict.keys(): p_dict['cbar_label'] = p_dict['ds_name'] + if 'vlim' not in p_dict.keys(): p_dict['vlim'] = (0.2, 1.0) + if 'disp_cbar' not in p_dict.keys(): p_dict['disp_cbar'] = True + if 'legend_loc' not in p_dict.keys(): p_dict['legend_loc'] = 'best' + if 'disp_legend' not in p_dict.keys(): p_dict['disp_legend'] = True + + # support input colormap: string for colormap name, or colormap object directly + if isinstance(p_dict['colormap'], str): + cmap = ColormapExt(p_dict['colormap']).colormap + elif isinstance(p_dict['colormap'], mpl.colors.LinearSegmentedColormap): + cmap = p_dict['colormap'] + else: + raise ValueError('unrecognized colormap input: {}'.format(p_dict['colormap'])) + + # Normalize date12 format + date12List = ptime.yyyymmdd_date12(date12List) + date12List_drop = ptime.yyyymmdd_date12(date12List_drop) if date12List_drop else [] + + # Convert date strings to datetime objects + date_list_normalized = [] + for date12 in date12List: + date1_str, date2_str = date12.split('_') + date_list_normalized.extend([date1_str, date2_str]) + date_list_normalized = sorted(list(set(date_list_normalized))) + date_list_normalized = ptime.yyyymmdd(date_list_normalized) + + date_objs = {} + for date_str in date_list_normalized: + try: + date_objs[date_str] = datetime.strptime(date_str, '%Y%m%d') + except ValueError: + # Fallback: try with YYMMDD format + if len(date_str) == 6: + date_objs[date_str] = datetime.strptime('20' + date_str, '%Y%m%d') + else: + raise + + # Create coherence dictionary + # Store both the normalized pair (for lookup) and the original order (for upper/lower triangle) + coh_dict = {} + coh_dict_ordered = {} # Store with original date order to determine upper/lower triangle + excluded_pairs = set() + for date12, coh_val in zip(date12List, cohList): + date1_str, date2_str = date12.split('_') + # Ensure we have datetime objects + if date1_str not in date_objs: + date1_str = ptime.yyyymmdd([date1_str])[0] + if date2_str not in date_objs: + date2_str = ptime.yyyymmdd([date2_str])[0] + + date1 = date_objs.get(date1_str) + date2 = date_objs.get(date2_str) + + if date1 is None: + date1 = datetime.strptime(date1_str, '%Y%m%d') + date_objs[date1_str] = date1 + if date2 is None: + date2 = datetime.strptime(date2_str, '%Y%m%d') + date_objs[date2_str] = date2 + + # Store as tuple (date1, date2) where date1 <= date2 for consistency + pair = (min(date1, date2), max(date1, date2)) + coh_dict[pair] = float(coh_val) + + # Store with original order to determine upper/lower triangle + # In date12 format, date1_str is master (earlier) and date2_str is slave (later) + # So if date1 < date2, it's upper triangle (idx1 < idx2) + coh_dict_ordered[(date1, date2)] = float(coh_val) + + # Mark excluded pairs + if date12 in date12List_drop: + excluded_pairs.add(pair) + + # Get all unique dates + all_dates = set() + for d1, d2 in coh_dict.keys(): + all_dates.add(d1) + all_dates.add(d2) + date_list = sorted(list(all_dates)) + + # Create continuous time grid (based on actual data points) + # First, calculate internal cell widths to determine expansion width + internal_widths = [] + for i in range(len(date_list)-1): + width = (date_list[i+1] - date_list[i]).days + internal_widths.append(width) + + # Calculate average internal cell width for expansion + if len(internal_widths) > 0: + avg_width = sum(internal_widths) / len(internal_widths) + else: + avg_width = 30 # fallback to 30 days if no internal cells + + # Expand first and last cells outward by half the average width + first_expansion = timedelta(days=avg_width / 2) + last_expansion = timedelta(days=avg_width / 2) + + grid_points = [date_list[0] - first_expansion] # starting point (expanded outward) + for i in range(len(date_list)-1): + mid_point = date_list[i] + (date_list[i+1] - date_list[i])/2 + grid_points.append(mid_point) + grid_points.append(date_list[-1] + last_expansion) # ending point (expanded outward) + + # Convert to days for plotting + base_date = min(date_list) + days_grid = [(d - base_date).days for d in grid_points] + + # Create meshgrid + X, Y = np.meshgrid(days_grid, days_grid) + + # Create value matrix, initialized with NaN (will display as white) + Z = np.full((len(grid_points)-1, len(grid_points)-1), np.nan) + + # Create a mapping from date to grid index for fast lookup + date_to_grid_idx = {} + for date in date_list: + # Find the grid index where this date falls + for grid_idx in range(len(grid_points)-1): + if grid_idx == 0: + # First cell: from grid_points[0] to grid_points[1] + if grid_points[0] <= date <= grid_points[1]: + date_to_grid_idx[date] = grid_idx + break + elif grid_idx == len(grid_points) - 2: + # Last cell: from grid_points[-2] to grid_points[-1] + if grid_points[grid_idx] < date <= grid_points[grid_idx+1]: + date_to_grid_idx[date] = grid_idx + break + else: + # Middle cells: from grid_points[i] to grid_points[i+1] + if grid_points[grid_idx] < date <= grid_points[grid_idx+1]: + date_to_grid_idx[date] = grid_idx + break + + # Fill value matrix directly from coherence dictionary + # Upper triangle (idx1 < idx2): only kept pairs (not in excluded_pairs), same as normal mode + # Lower triangle (idx1 > idx2): all pairs (including dropped ones), same as normal mode + # In date12 format, date1_str is master (earlier) and date2_str is slave (later) + # So typically d1 < d2, which means idx1 < idx2 (upper triangle) + for (d1, d2), cor in coh_dict_ordered.items(): + # Find grid indices for both dates + idx1 = date_to_grid_idx.get(d1) + idx2 = date_to_grid_idx.get(d2) + + # Only fill if both dates are in valid grid cells + if idx1 is not None and idx2 is not None: + # Check if this pair is excluded (using normalized pair) + pair_normalized = (min(d1, d2), max(d1, d2)) + is_excluded = pair_normalized in excluded_pairs + + if idx1 < idx2: + # Upper triangle: only fill if not excluded (i.e., kept pairs) + # This matches normal mode: coh_mat[idx1, idx2] = np.nan for dropped pairs + if not is_excluded: + Z[idx1, idx2] = cor + # Lower triangle: fill all pairs (including excluded ones) + Z[idx2, idx1] = cor + elif idx1 > idx2: + # Lower triangle: fill all pairs (including excluded ones) + Z[idx1, idx2] = cor + # else: diagonal is already handled by diag_Z + + # Create diagonal matrix for black diagonal cells + diag_Z = np.full_like(Z, np.nan) + num_cells = len(grid_points) - 1 + for i in range(min(num_cells, Z.shape[0], Z.shape[1])): + diag_Z[i, i] = 1.0 + + # Plot diagonal as black first (using gray_r colormap, where 1.0 = black) + if np.any(~np.isnan(diag_Z)): + ax.pcolormesh(X, Y, diag_Z, + cmap='gray_r', + vmin=0.0, + vmax=1.0, + shading='auto', + zorder=1) + + # Plot using pcolormesh for coherence values + cmap_plot = cmap.copy() + cmap_plot.set_bad('white') # NaN values will be white + + mesh = ax.pcolormesh(X, Y, Z, + cmap=cmap_plot, + vmin=p_dict['vlim'][0], + vmax=p_dict['vlim'][1], + shading='auto', + zorder=0) + + # Generate month ticks + min_date = min(date_list) + max_date = max(date_list) + + # If min_date is not the first day of month, start from next month 1st + if min_date.day > 1: + if min_date.month == 12: + current_date = min_date.replace(year=min_date.year+1, month=1, day=1) + else: + current_date = min_date.replace(month=min_date.month+1, day=1) + else: + current_date = min_date.replace(day=1) + + tick_dates = [] + while current_date <= max_date: + tick_dates.append(current_date) + # Get next month + if current_date.month == 12: + current_date = current_date.replace(year=current_date.year+1, month=1) + else: + current_date = current_date.replace(month=current_date.month+1) + + # Calculate tick positions (at month start) + tick_positions = [(d - base_date).days for d in tick_dates] + + # Calculate label positions (at middle of adjacent ticks) + label_positions = [] + month_labels = [] + is_january = [] + + for i in range(len(tick_dates)-1): + # Calculate middle point of adjacent ticks + mid_point = (tick_positions[i] + tick_positions[i+1]) / 2 + + # Only add label for odd months + if tick_dates[i].month % 2 == 1: + label_positions.append(mid_point) + month_labels.append(tick_dates[i].strftime('%-m')) + is_january.append(tick_dates[i].month == 1) + + # Separate January and other month tick positions + major_ticks = [pos for pos, date in zip(tick_positions, tick_dates) if date.month == 1] + minor_ticks = [pos for pos, date in zip(tick_positions, tick_dates) if date.month != 1] + + # Separate January and other odd month label positions + major_label_pos = [pos for pos, is_jan in zip(label_positions, is_january) if is_jan] + minor_label_pos = [pos for pos, is_jan in zip(label_positions, is_january) if not is_jan] + + # Separate labels + major_labels = [label for label, is_jan in zip(month_labels, is_january) if is_jan] + minor_labels = [label for label, is_jan in zip(month_labels, is_january) if not is_jan] + + # Set tick positions (no labels) + ax.set_xticks(major_ticks) # major ticks (January) + ax.set_xticks(minor_ticks, minor=True) # minor ticks (other months) + ax.set_yticks(major_ticks) + ax.set_yticks(minor_ticks, minor=True) + + # Set empty labels (we'll add labels separately with text) + ax.set_xticklabels([''] * len(major_ticks)) + ax.set_xticklabels([''] * len(minor_ticks), minor=True) + ax.set_yticklabels([''] * len(major_ticks)) + ax.set_yticklabels([''] * len(minor_ticks), minor=True) + + # Normal label positions + offset = (ax.get_ylim()[1] - ax.get_ylim()[0]) * 0.03 + year_offset = offset * 2.2 # Year labels below month labels + + # Add month labels + for pos, label in zip(major_label_pos, major_labels): + ax.text(pos, ax.get_ylim()[1] + offset, label, + horizontalalignment='center', verticalalignment='top', fontsize=10) + ax.text(ax.get_xlim()[0] - offset, pos, label, + horizontalalignment='right', verticalalignment='center', fontsize=10) + + for pos, label in zip(minor_label_pos, minor_labels): + ax.text(pos, ax.get_ylim()[1] + offset, label, + horizontalalignment='center', verticalalignment='top', fontsize=10) + ax.text(ax.get_xlim()[0] - offset, pos, label, + horizontalalignment='right', verticalalignment='center', fontsize=10) + + # Set tick line style + ax.tick_params(which='major', direction='out', length=6, width=1.1, + bottom=True, top=True, left=True, right=True) + ax.tick_params(which='minor', direction='out', length=3, width=1, + bottom=True, top=True, left=True, right=True) + + # Add year labels (at middle month of each year) + # Group tick_dates by year + from collections import defaultdict + year_groups = defaultdict(list) + for i, d in enumerate(tick_dates): + year_groups[d.year].append((i, tick_positions[i])) + + years = [] + year_positions = [] + for year in sorted(year_groups.keys()): + year_indices = year_groups[year] + if len(year_indices) > 0: + # Calculate middle position: average of first and last month positions + first_pos = year_indices[0][1] + last_pos = year_indices[-1][1] + middle_pos = (first_pos + last_pos) / 2 + years.append(str(year)) + year_positions.append(middle_pos) + + # Display year labels + for pos, year in zip(year_positions, years): + # X-axis: year labels at bottom (below month labels) + ax.text(pos, ax.get_ylim()[1] + year_offset, year, + horizontalalignment='center', verticalalignment='top', fontsize=10) + # Y-axis: year labels at left (below month labels), rotated 90 degrees counterclockwise + ax.text(ax.get_xlim()[0] - year_offset, pos, year, + horizontalalignment='right', verticalalignment='center', fontsize=10, + rotation=90) + # Invert Y axis + ax.invert_yaxis() + + # Colorbar + if p_dict['disp_cbar']: + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", "3%", pad="3%") + cbar = plt.colorbar(mesh, cax=cax) + cbar.set_label(p_dict['cbar_label'], fontsize=p_dict['fontsize']) + + if p_dict['disp_title']: + ax.set_title(p_dict['fig_title'], fontsize=p_dict['fontsize']) + + # Legend + if date12List_drop and p_dict['disp_legend']: + ax.plot([], [], label='Upper: Ifgrams used') + ax.plot([], [], label='Lower: Ifgrams all') + ax.legend(loc=p_dict['legend_loc'], handlelength=0) + + # Status bar - format coordinate display + def format_coord(x, y): + x_idx = np.argmin(np.abs(np.array(days_grid) - x)) + y_idx = np.argmin(np.abs(np.array(days_grid) - y)) + + # Clamp indices to valid range + x_idx = min(max(0, x_idx), len(grid_points) - 1) + y_idx = min(max(0, y_idx), len(grid_points) - 1) + + if x_idx < len(grid_points) and y_idx < len(grid_points): + date1 = grid_points[x_idx] + date2 = grid_points[y_idx] + date1_str = date1.strftime('%Y-%m-%d') + date2_str = date2.strftime('%Y-%m-%d') + + # Find coherence value (Z has shape (len(grid_points)-1, len(grid_points)-1)) + coh_val = np.nan + if x_idx < len(grid_points) - 1 and y_idx < len(grid_points) - 1: + coh_val = Z[y_idx, x_idx] + + if not np.isnan(coh_val): + return f'x={date1_str}, y={date2_str}, v={coh_val:.3f}' + else: + return f'x={date1_str}, y={date2_str}, v=NaN' + return '' + + ax.format_coord = format_coord + + return ax, Z, mesh + + def plot_num_triplet_with_nonzero_integer_ambiguity(fname, disp_fig=False, font_size=12, fig_size=[9,3]): """Plot the histogram for the number of triplets with non-zero integer ambiguity. diff --git a/tests/dem_error.py b/tests/dem_error.py index 6bbfe5ed4..a99724067 100755 --- a/tests/dem_error.py +++ b/tests/dem_error.py @@ -5,7 +5,6 @@ import argparse import datetime -import math import sys import numpy as np @@ -136,7 +135,7 @@ def test_dem_error_with_linear_defo(date_list, tbase, rel_tol=0.05, plot=False): # validate print(f'Specified DEM error: {delta_z_sim:.2f} m') print(f'Estimated DEM error: {delta_z_est[0]:.2f} m') - assert math.isclose(delta_z_sim, delta_z_est, rel_tol=rel_tol) + assert np.isclose(delta_z_sim, delta_z_est, rtol=rel_tol) print('Pass.') @@ -194,7 +193,7 @@ def test_dem_error_with_complex_defo(date_list, tbase, rel_tol=0.05, plot=False) # validate print(f'Specified DEM error: {delta_z_sim:.2f} m') print(f'Estimated DEM error: {delta_z_est[0]:.2f} m') - assert math.isclose(delta_z_sim, delta_z_est, rel_tol=rel_tol) + assert np.isclose(delta_z_sim, delta_z_est, rtol=rel_tol) print('Pass.')