# Source code for pema.compare_plots

import typing as ty
import matplotlib.pyplot as plt
import numpy as np
import pema
import strax
import straxen
from strax.utils import tqdm
from straxen.analyses.waveform_plot import time_and_samples


@straxen.mini_analysis(
    requires=('peaks', 'peak_basics'),
    default_time_selection='touching',
    warn_beyond_sec=60)
def plot_peaks(peaks,
               seconds_range,
               t_reference,
               include_info=None,
               show_largest=100,
               single_figure=True,
               figsize=(10, 4),
               xaxis=True,
               ):
    """
    Draw the waveforms of the largest peaks, colored by peak type.

    NOTE(review): peaks, seconds_range and t_reference are presumably
    supplied by the straxen.mini_analysis wrapper (callers in this file
    only pass a run_id and keyword arguments) -- confirm against straxen.

    :param peaks: peaks to draw; needs the 'area' and 'type' fields plus
        whatever plot_peak reads
    :param seconds_range: (start, stop) limits used for the x-axis
    :param t_reference: time passed to plot_peak as t0
    :param include_info: optional field names printed next to each peak
    :param show_largest: only the show_largest peaks with the biggest
        area are drawn
    :param single_figure: if True, open a new figure of size figsize and
        apply a tight layout at the end
    :param figsize: figure size, only used when single_figure is True
    :param xaxis: 'since_start' for an axis relative to seconds_range[0],
        any other truthy value for the regular time axis, falsy to leave
        the x-axis untouched
    """
    if single_figure:
        plt.figure(figsize=figsize)
    plt.axhline(0, c='k', alpha=0.2)

    # Select on area first, then restore time ordering for drawing
    by_decreasing_area = np.argsort(-peaks['area'])
    peaks = strax.sort_by_time(peaks[by_decreasing_area[:show_largest]])

    type_colors = {0: 'gray', 1: 'b', 2: 'g'}
    for peak in peaks:
        plot_peak(peak,
                  t0=t_reference,
                  include_info=include_info,
                  color=type_colors[peak['type']])

    if xaxis == 'since_start':
        seconds_range_xaxis(seconds_range, t0=seconds_range[0])
    elif xaxis:
        seconds_range_xaxis(seconds_range)
        plt.xlim(*seconds_range)

    plt.ylabel("Intensity [PE/ns]")
    if single_figure:
        plt.tight_layout()


def plot_peak(p, t0=None, center_time=True, include_info=None, **kwargs):
    """
    Draw a single peak as a stepped, lightly filled waveform.

    :param p: the peak to draw; 'center_time' (and 'time' if t0 is None)
        are read here, the waveform itself comes from time_and_samples
    :param t0: reference time handed to time_and_samples and subtracted
        from the center time; defaults to the start time of the peak
    :param center_time: if True, mark the center time with a dashed line
    :param include_info: optional iterable of field names of p; their
        values are printed (one decimal) in a box at the end of the peak
    :param kwargs: forwarded to plt.plot and plt.fill_between
    """
    x, y = time_and_samples(p, t0=t0)
    peak_height = y.max()

    # Plot waveform
    line_kwargs = {'linewidth': 1, **kwargs}
    plt.plot(x, y, drawstyle='steps-pre', **line_kwargs)

    # The fill gets no outline and only a fifth of the requested opacity
    fill_kwargs = {key: value
                   for key, value in kwargs.items()
                   if key != 'linewidth'}
    fill_kwargs['alpha'] = kwargs.get('alpha', 1) * 0.2
    plt.fill_between(x, 0, y, step='pre', linewidth=0, **fill_kwargs)

    # Mark extent with thin black line
    plt.plot([x[0], x[-1]], [peak_height, peak_height],
             c='k', alpha=0.3, linewidth=1)

    # Mark center time with thin black line
    if center_time:
        reference = p['time'] if t0 is None else t0
        ct = (p['center_time'] - reference) / int(1e9)
        plt.axvline(ct, c='k', alpha=0.4, linewidth=1, linestyle='--')

    if include_info:
        info_str = '\n'.join([f'{inf}: {p[inf]:.1f}' for inf in include_info])
        plt.text(x[-1], peak_height, info_str,
                 fontsize='xx-small',
                 ha='left',
                 va='top',
                 alpha=0.8,
                 bbox=dict(boxstyle="round", fc="w", alpha=0.5))
def _plot_truth(data, start_end, t_range):
    """
    Draw the truth ("Instructions") panel on the current axes.

    Every entry of data that touches the start_end window is drawn as a
    hatched triangle running from its 'time' to its 'endtime' and peaking
    at 't_mean_photon' with twice the mean intensity.

    :param data: truth data with the fields 'time', 'endtime', 'n_pe',
        't_mean_photon' and 'type'
    :param start_end: length-one array with 'time' and 'endtime', the
        window handed to strax.touching_windows
    :param t_range: times (ns) that are marked with vertical lines
    """
    hatch_cycle = ('/', '*', '+', '|')
    type_colors = {1: 'blue',
                   2: 'green',
                   0: 'gray',
                   6: 'orange',
                   4: 'purple',
                   }

    plt.title('Instructions')
    first, last = strax.touching_windows(data, start_end)[0]
    for pk, pi in enumerate(range(first, last)):
        tpeak = data[pi]
        x = np.array(list(tpeak[['time', 'endtime']]))
        # Mean intensity over the full extent of the peak
        y = tpeak['n_pe'] / np.diff(x)
        ct = tpeak['t_mean_photon']
        stype = tpeak['type']
        plt.fill_between(
            [x[0] / 1e9, ct / 1e9, x[-1] / 1e9],
            [0, 0, 0],
            [0, 2 * y[0], 0],
            # Types without a dedicated color are drawn in gray instead of
            # aborting the plot with a KeyError
            color=type_colors.get(stype, 'gray'),
            label=f'Peak S{stype}. {tpeak["n_pe"]} PE',
            alpha=0.4,
            # Wrap around: more than four peaks may touch the window
            hatch=hatch_cycle[pk % len(hatch_cycle)],
        )
    plt.ylabel('Intensity [PE/ns]')
    for t in t_range:
        axvline(t / 1e9, label=f't = {t}')
    plt.legend(loc='lower left', fontsize='x-small')


def _plot_peak(st_default,
               truth_vs_default,
               default_label,
               peak_i,
               t_range,
               xlim,
               run_id,
               label_x_axis=False,
               ):
    """
    Draw the reconstructed peaks of one context on the current axes.

    :param st_default: context whose plot_peaks method is used
    :param truth_vs_default: matched data; the fields 'outcome',
        'rec_bias' and 'acceptance_fraction' of entry peak_i are printed
        on the panel
    :param default_label: title of the panel
    :param peak_i: index into truth_vs_default
    :param t_range: (start, stop) in ns, loaded and marked with lines
    :param xlim: x-axis limits in seconds
    :param run_id: run to load; if None it is read from the 'run_id'
        field of truth_vs_default[peak_i]
    :param label_x_axis: whether to format the time axis of this panel
    """
    plt.title(default_label)
    if run_id is None:
        run_id = truth_vs_default[peak_i]['run_id']

    st_default.plot_peaks(run_id,
                          single_figure=False,
                          include_info=['area', 'rise_time', 'tight_coincidence'],
                          time_range=t_range,
                          xaxis=label_x_axis,
                          )
    for t in t_range:
        axvline(t / 1e9, label=t)
    if label_x_axis:
        seconds_range_xaxis(xlim)
    plt.xlim(*xlim)

    # Outcome in the top-left corner of the panel
    plt.text(0.05, 0.95,
             truth_vs_default[peak_i]['outcome'],
             transform=plt.gca().transAxes,
             ha='left',
             va='top',
             bbox=dict(boxstyle="round", fc="w"),
             )
    # Matching properties in the bottom-left corner of the panel
    plt.text(0.05, 0.1,
             '\n'.join(f'{prop[:10]}: {truth_vs_default[peak_i][prop]:.1f}'
                       for prop in ['rec_bias', 'acceptance_fraction']),
             transform=plt.gca().transAxes,
             fontsize='small',
             ha='left',
             va='bottom',
             bbox=dict(boxstyle="round", fc="w"),
             alpha=0.8,
             )
def compare_truth_and_outcome(
        st: strax.Context,
        data: np.ndarray,
        **kwargs
) -> None:
    """
    Compare the outcomes of the truth and the reconstructed peaks

    Thin wrapper around compare_outcomes without an alternative context.

    :param st: the context of the current master, to compare with
    :param data: the data consistent with the default context, can be
        cut to select certain data
    :param match_fuzz: Extend loading peaks this many ns to allow for
        small shifts in reconstruction. Will extend the time range left
        and right
    :param plot_fuzz: Make the plot slightly larger with this many ns
        for readability
    :param max_peaks: max number of peaks to be shown. Set to 1 for
        plotting a single peak.
    :param default_label: How to label the default reconstruction
    :param label: alias of default_label; ignored if default_label is
        given as well
    :param fig_dir: Where to save figures (if None, don't save)
    :param show: show the figures or not.
    :param randomize: randomly order peaks to get a random sample of
        <max_peaks> every time
    :param run_id: Optional argument in case run_id is not a field in
        the data.
    :param raw: include raw-records-trace
    :param pulse: plot raw-record traces.
    :return: None
    """
    # This docstring used to advertise `label`, but compare_outcomes only
    # knows `default_label`, so forwarding `label` raised a TypeError.
    if 'label' in kwargs:
        label = kwargs.pop('label')
        kwargs.setdefault('default_label', label)
    # Without a second data set there is nothing to differ by
    kwargs['different_by'] = None
    compare_outcomes(st=st,
                     data=data,
                     st_alt=None,
                     data_alt=None,
                     **kwargs,
                     )
def compare_outcomes(st: strax.Context,
                     data: np.ndarray,
                     st_alt: ty.Optional[strax.Context] = None,
                     data_alt: ty.Optional[np.ndarray] = None,
                     match_fuzz: int = 500,
                     plot_fuzz: int = 500,
                     max_peaks: int = 10,
                     default_label: str = 'default',
                     custom_label: str = 'custom',
                     fig_dir: ty.Union[None, str] = None,
                     show: bool = True,
                     randomize: bool = True,
                     different_by: ty.Optional[ty.Union[bool, str]] = 'acceptance_fraction',
                     run_id: ty.Union[None, str] = None,
                     raw: bool = False,
                     pulse: bool = True,
                     ) -> None:
    """
    Compare the outcomes of two contexts with one another. In order to
    allow for selections, the data has to be passed along with each
    context.

    One figure is made per selected peak, with stacked panels: the truth,
    optionally the raw-records matrix, optionally the raw-record pulses,
    the peaks of st and (if given) the peaks of st_alt.

    :param st: the context of the current master, to compare with st_alt
    :param data: the data consistent with st, can be cut to select
        certain data
    :param st_alt: context wherewith to compare st
    :param data_alt: the data made with st_alt, should have the same
        length as data
    :param match_fuzz: Extend loading peaks this many ns to allow for
        small shifts in reconstruction. Will extend the time range left
        and right
    :param plot_fuzz: Make the plot slightly larger with this many ns
        for readability
    :param max_peaks: max number of peaks to be shown. Set to 1 for
        plotting a single peak.
    :param default_label: How to label the default reconstruction
    :param custom_label: How to label the custom reconstruction
    :param fig_dir: Where to save figures (if None, don't save)
    :param show: show the figures or not.
    :param randomize: randomly order peaks to get a random sample of
        <max_peaks> every time
    :param different_by: Field to filter waveforms by. Only show
        waveforms where this field is different in data and data_alt. If
        False, plot any waveforms from the two data sets. Only used when
        st_alt and data_alt are given.
    :param run_id: Optional argument in case run_id is not a field in
        the data.
    :param raw: include raw-records-trace
    :param pulse: plot raw-record traces.
    :return: None
    :raises RuntimeError: if only one of st_alt and data_alt is given
    """
    if (st_alt is None) != (data_alt is None):
        raise RuntimeError('Both st_alt and data_alt should be specified simultaneously')

    compare_to_alt = st_alt is not None
    if compare_to_alt:
        _check_args(data, data_alt, run_id)
        peaks_idx = _get_peak_idxs_from_args(data,
                                             randomize,
                                             data_alt,
                                             different_by)
    else:
        _check_args(data, None, run_id)
        peaks_idx = _get_peak_idxs_from_args(data, randomize)

    # Both are the same for every peak
    n_panels = 2 + int(compare_to_alt) + int(raw) + int(pulse)
    data_has_run_id = 'run_id' in data.dtype.names

    for peak_i in tqdm(peaks_idx[:max_peaks]):
        try:
            if data_has_run_id:
                # The run of this peak takes precedence over the argument
                run_id = data[peak_i]['run_id']
                run_mask = data['run_id'] == run_id
            else:
                run_mask = np.ones(len(data), dtype=np.bool_)

            t_range, start_end, xlim = _get_time_ranges(data,
                                                        peak_i,
                                                        match_fuzz,
                                                        plot_fuzz)
            panels = iter(_get_axes_for_compare_plot(n_panels))

            # Truth
            plt.sca(next(panels))
            _plot_truth(data[run_mask], start_end, t_range)

            # Optional raw-record panels
            if raw:
                plt.sca(next(panels))
                st.plot_records_matrix(run_id,
                                       raw=True,
                                       single_figure=False,
                                       time_range=t_range,
                                       time_selection='touching',
                                       )
                for t in t_range:
                    axvline(t / 1e9)
            if pulse:
                plt.sca(next(panels))
                rr_simple_plot(st, run_id, t_range)

            # Default reconstruction; it only carries the time axis when
            # it is the bottom panel
            plt.sca(next(panels))
            _plot_peak(st,
                       data,
                       default_label,
                       peak_i,
                       t_range,
                       xlim,
                       run_id,
                       label_x_axis=not compare_to_alt,
                       )

            # Alternative reconstruction
            if compare_to_alt:
                plt.sca(next(panels))
                _plot_peak(st_alt,
                           data_alt,
                           custom_label,
                           peak_i,
                           t_range,
                           xlim,
                           run_id,
                           label_x_axis=True,
                           )
            _save_and_show('example_wf_diff', fig_dir, show, peak_i)
        except (ValueError, RuntimeError) as e:
            # Report the failure and carry on with the next peak
            print(f'Error making {peak_i}: {type(e)}, {e}')
            plt.show()
def rr_simple_plot(st, run_id, t_range):
    """
    Plot the raw-record pulses within (touching) the t_range

    Records are drawn in order of channel, each in the color that the
    twilight colormap assigns to its channel number.

    :param st: context used to load 'raw_records'
    :param run_id: run to load the raw records from
    :param t_range: (start, stop) in ns; raw records touching this range
        are loaded and both times are marked with a vertical line
    :return: None
    """
    # One color per channel number below straxen.n_tpc_pmts
    channel_colors = plt.cm.twilight(np.arange(straxen.n_tpc_pmts))

    records = st.get_array(run_id,
                           'raw_records',
                           progress_bar=False,
                           time_range=t_range,
                           time_selection='touching',
                           )
    records = np.sort(records, order='channel')

    plt.ylabel('ADC counts')
    for record in records:
        samples = record['data'][:record['length']]
        sample_times = np.arange(len(samples)) * record['dt'] + record['time']
        ch = record['channel']
        idx = record['record_i']
        plt.plot(sample_times / 1e9,
                 samples,
                 label=f'ch{ch:03}: rec_{idx}',
                 c=channel_colors[ch],
                 )
    for t in t_range:
        axvline(t / 1e9)
def axvline(v, **kwargs):
    """
    Draw a vertical line at x=v on the current axes, using (and
    consuming) the next color of the axes' property cycle.

    :param v: x-position of the line
    :param kwargs: forwarded to plt.axvline
    """
    # get_next_color() advances the same cycle that plt.plot uses. It
    # replaces direct access to the private `prop_cycler` attribute, which
    # is not present in newer Matplotlib releases.
    vline_color = plt.gca()._get_lines.get_next_color()
    plt.axvline(v, color=vline_color, **kwargs)
def _get_peak_idxs_from_args(data_1, randomize, data_2=None, different_by=None):
    """
    Select the indices of the entries that should be plotted.

    :param data_1: first data set
    :param randomize: if True, shuffle the selected indices in place with
        np.random.shuffle
    :param data_2: second data set, required when different_by is set
    :param different_by: field name; keep only the indices where this
        field differs between data_1 and data_2. If None or False, all
        indices of data_1 are kept.
    :return: numpy array of indices into data_1
    """
    if different_by:
        assert data_2 is not None
        peaks_idx = np.where(data_1[different_by] != data_2[different_by])[0]
    else:
        peaks_idx = np.arange(len(data_1))
    if randomize:
        np.random.shuffle(peaks_idx)
    return peaks_idx


def _check_args(truth_vs_default, truth_vs_custom=None, run_id=None):
    """
    Validate the data sets handed to the comparison functions.

    :raises ValueError: if there is neither a run_id argument nor a
        'run_id' field in truth_vs_default, or if the two data sets
        differ in length
    """
    if 'run_id' not in truth_vs_default.dtype.names and run_id is None:
        raise ValueError('Either need a run_id or data with a run_id field!')
    if truth_vs_custom is not None and len(truth_vs_custom) != len(truth_vs_default):
        raise ValueError('Got different lengths for truth_vs_custom and truth_vs_default')


def _compare_plot_height_ratios(n_axis):
    """Return one height ratio per panel: a half-height first panel
    followed by full-height panels."""
    return [0.5] + [1] * (n_axis - 1)


def _get_axes_for_compare_plot(n_axis):
    """
    Create a figure with n_axis vertically stacked panels sharing the
    x-axis and return its axes.

    :param n_axis: number of panels, between 2 and 5
    """
    assert n_axis in [2, 3, 4, 5]
    _, axes = plt.subplots(
        n_axis, 1,
        figsize=(10 * (n_axis / 3), 10),
        sharex=True,
        # Matplotlib needs exactly one ratio per row. The previous
        # fixed four-entry list made five panels fail with a ValueError.
        gridspec_kw={'height_ratios': _compare_plot_height_ratios(n_axis)})
    return axes


def _save_and_show(name, fig_dir, show, peak_i):
    """Save the current figure as '<name>_<peak_i>' in fig_dir (if
    fig_dir is set) and show it (if show is set)."""
    if fig_dir:
        pema.save_canvas(f'{name}_{peak_i}', save_dir=fig_dir)
    if show:
        plt.show()


def _get_time_ranges(truth_vs_custom, peak_i, matching_fuzz, plot_fuzz):
    """
    Compute the time windows used for loading and plotting one entry.

    :param truth_vs_custom: data with 'time' and 'endtime' fields
    :param peak_i: index of the entry of interest
    :param matching_fuzz: ns to extend the loaded range on both sides
    :param plot_fuzz: additional ns added on both sides for the x-limits
    :return: tuple of
        - t_range: (start, stop) in ns, including matching_fuzz
        - start_end: length-one array of dtype strax.time_fields holding
          the same range
        - xlim: (left, right) in seconds, t_range widened by plot_fuzz
    """
    t_range = (truth_vs_custom[peak_i]['time'] - matching_fuzz,
               truth_vs_custom[peak_i]['endtime'] + matching_fuzz)

    start_end = np.zeros(1, dtype=strax.time_fields)
    start_end['time'] = t_range[0]
    start_end['endtime'] = t_range[1]

    xlim = (t_range[0] - plot_fuzz) / 1e9, (t_range[1] + plot_fuzz) / 1e9
    return t_range, start_end, xlim
def seconds_range_xaxis(seconds_range, t0=None):
    """
    Make a pretty time axis given seconds_range

    :param seconds_range: (left, right) x-limits in seconds
    :param t0: if None, ticks are labeled as seconds with ms / us / ns
        groups appended as far as needed to tell the ticks apart. If not
        None, ticks are labeled in ns relative to the first tick; only
        whether t0 is None is used, not its value.
    """
    plt.xlim(*seconds_range)
    # disable for now
    plt.gca().ticklabel_format(useOffset=False)
    xticks = plt.xticks()[0]
    if not len(xticks):
        return

    # Format the labels
    # I am not very proud of this code...
    def chop(x):
        return np.floor(x).astype(np.int64)

    absolute = t0 is None
    tick_values = xticks if absolute else xticks - xticks[0]
    xticks_ns = np.round(tick_values * int(1e9)).astype(np.int64)

    # Split every tick into sec / ms / us / 10 ns sample groups
    sec = chop(xticks_ns // int(1e9))
    ms = chop((xticks_ns % int(1e9)) // int(1e6))
    us = chop((xticks_ns % int(1e6)) // int(1e3))
    samples = chop((xticks_ns % int(1e3)) // 10)

    # A group is printed as soon as it, or any finer group, differs
    # between the ticks
    print_ns = np.any(samples != samples[0])
    print_us = print_ns | np.any(us != us[0])
    print_ms = print_us | np.any(ms != ms[0])

    if print_ms and absolute:
        labels = [f'{second}.{milli:03}' for second, milli in zip(sec, ms)]
        if print_us:
            labels = [label + r' $\bf{' + f'{micro:03}' + '}$'
                      for label, micro in zip(labels, us)]
            if print_ns:
                labels = [label + f' {sample:02}0'
                          for label, sample in zip(labels, samples)]
        plt.xticks(ticks=xticks, labels=labels, rotation=90)
    else:
        labels = list(chop((xticks_ns // 10) * 10))
        # The last tick label is left empty
        labels[-1] = ""
        plt.xticks(ticks=xticks, labels=labels, rotation=0)

    if absolute:
        plt.xlabel("Time since run start [sec]")
    else:
        plt.xlabel("Time [ns]")