import warnings
import strax
import numba
import numpy as np
import pema
import logging
import straxen
from .matching import INT_NAN
# `export` is applied below as a class decorator; `__all__` is the list of
# public names that strax.exporter() returns alongside it.
export, __all__ = strax.exporter()
# Module-level logger used by the plugins in this file.
log = logging.getLogger('Pema matching')
@export
class MatchPeaks(strax.OverlapWindowPlugin):
    """
    Pair the WFSim truth with the reconstructed peaks.

    The pairing is delegated to ``pema.match_peaks``. For every truth
    entry the result stores its id, the id of the peak it was matched
    to and the outcome of the matching (see pema.matching for the
    possible outcomes).
    """
    __version__ = '0.5.0'
    depends_on = ('truth', 'truth_id', 'peak_basics', 'peak_id')
    provides = 'truth_matched'
    data_kind = 'truth'

    truth_lookup_window = straxen.URLConfig(
        default=int(1e9),
        help='Look back and forth this many ns in the truth info',
    )

    def infer_dtype(self):
        """Return the time fields extended with id, outcome and matched_to."""
        matching_fields = [
            (('Id of element in truth', 'id'), np.int64),
            (('Outcome of matching to peaks', 'outcome'), pema.matching.OUTCOME_DTYPE),
            (('Id of matching element in peaks', 'matched_to'), np.int64),
        ]
        return strax.dtypes.time_fields + matching_fields

    def get_window_size(self):
        """Return the overlap window, taken from ``truth_lookup_window``."""
        return self.config['truth_lookup_window']

    def compute(self, truth, peaks):
        """
        Match ``truth`` to ``peaks`` and return one row per truth entry.

        :param truth: truth entries (with the fields of ``truth_id``)
        :param peaks: reconstructed peaks (with the fields of ``peak_id``)
        :returns: array of ``self.dtype`` with the same length as ``truth``
        """
        log.debug(f'Starting {self.__class__.__name__}')

        # The matching works on an 'area' field, so expose raw_area under
        # that name.
        truth = pema.append_fields(truth, 'area', truth['raw_area'])

        # hack endtime: use the time of the last photon instead
        log.warning('Patching endtime in the truth')
        truth['endtime'] = truth['t_last_photon'].copy()

        log.info('Starting matching')
        # Only the truth -> peak direction is stored by this plugin.
        truth_vs_peak, _ = pema.match_peaks(truth, peaks)

        # Copy just the fields this plugin provides into the result buffer.
        result = np.zeros(len(truth), dtype=self.dtype)
        for field in self.dtype.names:
            result[field] = truth_vs_peak[field]
        return result
@export
class AcceptanceComputer(strax.Plugin):
    """
    Compute the acceptance of the matched peaks. This is done on the
    basis of arbitrary settings to allow better to disentangle
    possible scenarios that might be undesirable (like splitting
    an S2 into small S1 signals that could affect event
    reconstruction).

    For truth entries of type 1 (S1) the acceptance is 1 if the outcome
    is 'found' and 0 otherwise. For type 2 (S2) the acceptance is 1 if
    the reconstruction bias exceeds ``min_s2_bias_rec`` and 0 otherwise,
    after which the outcomes listed in ``penalty_s2_by`` are overwritten
    with their penalty value.
    """
    __version__ = '2.0.2'
    depends_on = ('truth', 'truth_matched', 'peak_basics', 'peak_id')
    provides = 'match_acceptance'
    data_kind = 'truth'

    keep_peak_fields = straxen.URLConfig(
        default=('area', 'range_50p_area', 'area_fraction_top', 'rise_time', 'tight_coincidence'),
        help='Add the reconstructed value of these variables',
    )
    penalty_s2_by = straxen.URLConfig(
        default=(('misid_as_s1', -1.), ('split_and_misid', -1.),),
        help='Add a penalty to the acceptance fraction if the peak has the '
             'outcome. Should be tuple of tuples where each tuple should '
             'have the format of (outcome, penalty_factor)',
    )
    # NOTE(review): the help text says "greater or equal", but compute()
    # uses a strict ">" comparison - confirm which one is intended.
    min_s2_bias_rec = straxen.URLConfig(
        default=0.85,
        help='If the S2 fraction is greater or equal than this, consider a '
             'peak successfully found even if it is split or chopped.',
    )

    def setup(self):
        """Check that the configuration allows computing ``rec_bias``."""
        assert 'area' in self.keep_peak_fields, (
            "'area' must be in keep_peak_fields, rec_bias is computed from rec_area")

    def infer_dtype(self):
        """
        Build the output dtype.

        :returns: the time fields, ``is_found``, ``acceptance_fraction``,
            ``rec_bias`` and one ``rec_<field>`` per field of peak_basics
            that is listed in ``keep_peak_fields``.
        """
        dtype = strax.dtypes.time_fields + [
            (('Is the peak tagged "found" in the reconstructed data',
              'is_found'), np.bool_),
            (('Acceptance of the peak can be negative for penalized reconstruction',
              'acceptance_fraction'),
             np.float64),
            (('Reconstruction bias 1 is perfect, 0.1 means incorrect',
              'rec_bias'),
             np.float64),
        ]
        for descr in self.deps['peak_basics'].dtype_for('peak_basics').descr:
            # descr is ((title, name), type); copy the requested peak
            # fields over with a 'rec_' prefix and the same title/type.
            field = descr[0][1]
            if field in self.keep_peak_fields:
                dtype += [((descr[0][0], f'rec_{field}'), descr[1])]
        return dtype

    def compute(self, truth, peaks):
        """
        Compute the acceptance for every truth entry.

        :param truth: truth entries including the ``truth_matched`` fields
        :param peaks: reconstructed peaks including their ``id``
        :returns: array of ``self.dtype`` with the same length as ``truth``
        :raises ValueError: if the peaks selected through ``matched_to`` do
            not carry the ids that were asked for
        """
        res = np.zeros(len(truth), self.dtype)
        res['time'] = truth['time']
        res['endtime'] = strax.endtime(truth)
        res['is_found'] = truth['outcome'] == 'found'

        peak_idx = truth['matched_to']
        mask = peak_idx != INT_NAN
        if np.sum(mask):
            # need to get at least one peak for each, even if we are going to remove those later
            sel_from_peaks = peak_idx[mask]
            sel_peaks = peaks[get_idx(sel_from_peaks, peaks['id'], INT_NAN)]
            if len(sel_peaks) != len(sel_from_peaks):
                raise ValueError(f'Got {len(sel_peaks)} != {len(sel_from_peaks)}')

            # Sanity check: every selected peak must have the requested id.
            not_match = sel_from_peaks != sel_peaks['id']
            if np.any(not_match):
                bad = np.flatnonzero(not_match)
                examples = list(zip(sel_from_peaks[bad][:5].tolist(),
                                    sel_peaks['id'][bad][:5].tolist()))
                raise ValueError(
                    f'{len(bad)} of {len(sel_from_peaks)} matched ids do not agree '
                    f'with the selected peak ids, first (truth, peak) pairs: {examples}')

            for k in self.keep_peak_fields:
                res[f'rec_{k}'][mask] = sel_peaks[k]
            # Unmatched entries keep rec_area == 0 and hence a zero bias.
            res['rec_bias'] = res['rec_area'] / truth['raw_area']

        # S1 acceptance is simply is the peak found or not
        s1_mask = truth['type'] == 1
        res['acceptance_fraction'][s1_mask] = res['is_found'][s1_mask].astype(np.float64)

        # For the S2 acceptance we calculate an arbitrary acceptance
        # that takes into account penalty factors and that S2s may be
        # split (as long as their bias fraction is not too small).
        s2_mask = truth['type'] == 2
        s2_outcomes = truth['outcome'][s2_mask].copy()
        s2_acceptance = (res[s2_mask]['rec_bias'] > self.config['min_s2_bias_rec']).astype(
            np.float64)
        for outcome, penalty in self.config['penalty_s2_by']:
            # The penalty replaces (not scales) the acceptance for this outcome
            s2_out_mask = s2_outcomes == outcome
            s2_acceptance[s2_out_mask] = penalty

        # now update the acceptance fraction in the results
        res['acceptance_fraction'][s2_mask] = s2_acceptance
        return res
class AcceptanceExtended(strax.MergeOnlyPlugin):
    """
    Deprecated merge of the match acceptance with the truth data types.

    Provides 'match_acceptance_extended'; a DeprecationWarning pointing
    to 'truth_extended' is issued when the plugin is set up.
    """
    __version__ = '0.1.0'
    provides = 'match_acceptance_extended'
    data_kind = 'truth'
    depends_on = ('match_acceptance', 'truth', 'truth_id', 'truth_matched')
    save_when = strax.SaveWhen.TARGET

    def setup(self):
        """Warn about the deprecation, then run the parent setup."""
        message = 'match_acceptance_extended is deprecated use truth_extended'
        warnings.warn(message, DeprecationWarning)
        super().setup()
@export
class TruthExtended(strax.MergeOnlyPlugin):
    """
    Merge 'match_acceptance', 'truth', 'truth_id' and 'truth_matched'
    into the single data type 'truth_extended'.
    """
    __version__ = '0.1.0'
    provides = 'truth_extended'
    data_kind = 'truth'
    depends_on = ('match_acceptance', 'truth', 'truth_id', 'truth_matched')
    save_when = strax.SaveWhen.TARGET
class MatchEvents(strax.OverlapWindowPlugin):
    """
    Match WFSim truth to the reconstructed events.

    The truth is grouped by ``sim_id_field``; every group spans from its
    earliest 'time' to its latest 'endtime'. For each such truth event
    the first and last (inclusive) event number returned by
    ``strax.touching_windows`` for that span is stored, together with an
    outcome: 'missed' (no event), 'found' (exactly one event) or 'split'
    (more than one event).
    """
    __version__ = '0.1.0'
    depends_on = ('truth', 'events')
    provides = 'truth_events'
    data_kind = 'truth_events'

    truth_lookup_window = straxen.URLConfig(
        default=int(1e9),
        help='Look back and forth this many ns in the truth info',
    )
    check_event_endtime = straxen.URLConfig(
        default=True,
        help='Check that all events have a non-zero duration.',
    )
    sim_id_field = straxen.URLConfig(
        default='event_number',
        help='Group the truth info by this field. Options: ["event_number", "g4id"]',
    )

    dtype = strax.dtypes.time_fields + [
        (('First event number in event datatype within the truth event', 'start_match'), np.int64),
        (('Last (inclusive!) event number in event datatype within the truth event', 'end_match'),
         np.int64),
        (('Outcome of matching to events', 'outcome'), pema.matching.OUTCOME_DTYPE),
        (('Truth event number', 'truth_number'), np.int64),
    ]

    def compute(self, truth, events):
        """
        Build one row per unique truth id and match it to the events.

        :param truth: truth entries, must contain ``sim_id_field``
        :param events: reconstructed events with an 'event_number'
        :returns: array of ``self.dtype``, one row per unique truth id
        """
        unique_numbers = np.unique(truth[self.sim_id_field])
        res = np.zeros(len(unique_numbers), self.dtype)
        res['truth_number'] = unique_numbers

        # Group on the very field the unique ids were taken from. Grouping
        # on a hard-coded 'event_number' would mismatch the ids whenever
        # sim_id_field is configured to something else (e.g. 'g4id').
        _fill_start_end(truth[self.sim_id_field], truth['endtime'], truth['time'], res)

        if self.check_event_endtime:
            assert np.all(res['endtime'] > res['time'])
        # Truth events must be strictly ordered in time
        # (presumably required by strax.touching_windows - confirm).
        assert np.all(np.diff(res['time']) > 0)

        tw = strax.touching_windows(events, res)
        tw_start = tw[:, 0]
        tw_end = tw[:, 1] - 1  # NB! This is now INCLUSIVE

        # Number of events per truth event
        diff = np.diff(tw, axis=1)[:, 0]
        found = diff > 0

        # None unless found
        res['start_match'][~found] = pema.matching.INT_NAN
        res['end_match'][~found] = pema.matching.INT_NAN
        res['start_match'][found] = events[tw_start[found]]['event_number']
        res['end_match'][found] = events[tw_end[found]]['event_number']

        res['outcome'] = self.outcomes(diff)
        return res

    def get_window_size(self):
        """Return the overlap window, taken from ``truth_lookup_window``."""
        return self.config['truth_lookup_window']

    @staticmethod
    def outcomes(diff):
        """
        Classify truth events by the number of events matched to them.

        :param diff: number of matched events per truth event
        :returns: array of pema.matching.OUTCOME_DTYPE holding 'missed'
            (fewer than one), 'found' (exactly one) or 'split' (more
            than one) for every entry of ``diff``
        """
        outcome = np.empty(len(diff), dtype=pema.matching.OUTCOME_DTYPE)
        not_found_mask = diff < 1
        one_found_mask = diff == 1
        many_found_mask = diff > 1
        outcome[not_found_mask] = 'missed'
        outcome[one_found_mask] = 'found'
        outcome[many_found_mask] = 'split'
        return outcome
class PeakId(strax.Plugin):
    """Attach a running integer 'id' to every element of the data kind."""
    __version__ = '0.0.0'
    depends_on = 'peak_basics'
    provides = 'peak_id'
    data_kind = 'peaks'
    save_when = strax.SaveWhen.TARGET

    # Number of elements handled by earlier compute() calls; the next id
    # handed out starts from this value.
    peaks_seen = 0

    def infer_dtype(self):
        """Return the time fields plus a single int64 'id' field."""
        id_title = f'Id of element in {self.data_kind}'
        return strax.time_fields + [((id_title, 'id'), np.int64), ]

    def compute(self, peaks):
        """
        Number the elements of this chunk, continuing from earlier chunks.

        :param peaks: elements with 'time' and 'endtime' fields
        :returns: array of ``self.dtype`` with consecutive ids
        """
        n_new = len(peaks)
        first_id = self.peaks_seen

        res = np.zeros(n_new, dtype=self.dtype)
        res['time'] = peaks['time']
        res['endtime'] = peaks['endtime']
        res['id'] = np.arange(first_id, first_id + n_new)

        # Remember where to continue for the next chunk
        self.peaks_seen = first_id + n_new
        return res
class TruthId(PeakId):
    """Attach a running integer 'id' to the truth, which must be time-sorted."""
    __version__ = '0.0.0'
    depends_on = 'truth'
    provides = 'truth_id'
    data_kind = 'truth'

    def compute(self, truth):
        """Check that ``truth`` is ordered in time, then number it."""
        assert_ordered_truth(truth)
        numbered = super().compute(truth)
        return numbered
def fill_start_end(truth, truth_event, end_field='endtime', number_field='event_number'):
    """
    Set the 'time' and 'endtime' fields of ``truth_event`` based on the truth.

    :param truth: truth entries with 'time', ``end_field`` and
        ``number_field``
    :param truth_event: array with 'truth_number', 'time' and 'endtime'
        fields; 'time' and 'endtime' are filled in place
    :param end_field: field of ``truth`` to use as the end of an entry
    :param number_field: field of ``truth`` that is compared against
        ``truth_event['truth_number']`` to group the entries
    """
    _fill_start_end(truth[number_field], truth[end_field], truth['time'], truth_event)
@numba.njit()
def _fill_start_end(truth_number, stops, starts, truth_event):
    """
    For every row of ``truth_event`` store the earliest start and the
    latest stop of the truth entries whose number equals the row's
    'truth_number'. The result is written in place.
    """
    event_numbers = truth_event['truth_number']
    for row in range(len(event_numbers)):
        belongs_to_event = truth_number == event_numbers[row]
        first_start = starts[belongs_to_event].min()
        last_stop = stops[belongs_to_event].max()
        truth_event['time'][row] = first_start
        truth_event['endtime'][row] = last_stop
def assert_ordered_truth(truth):
    """Raise an AssertionError if the 'time' field of ``truth`` ever decreases."""
    time_steps = np.diff(truth['time'])
    is_sorted = np.all(time_steps >= 0)
    assert is_sorted, "truth is not sorted!"
@numba.njit
def get_idx(search_item, in_list, not_found=-99999):
    """Get the index in <in_list> of every value in <search_item>

    Assumes that <in_list> is sorted! The search for each value resumes
    from the position where the previous value was found.

    :param search_item: values to look up
    :param in_list: values to search in
    :param not_found: placeholder marking values that were not located
    :raises ValueError: if any value of <search_item> was not located
    :returns:
        a list of length (search item) where each value refers to the item in <in_list>
    """
    n_search = len(search_item)
    n_list = len(in_list)
    result = np.ones(n_search, dtype=np.int64) * not_found

    cursor = 0
    for i in range(n_search):
        wanted = search_item[i]
        for j in range(cursor, n_list):
            if in_list[j] == wanted:
                result[i] = j
                # Resume the next search from this hit
                cursor = j
                break

    # Every requested value must have been located
    for i in range(n_search):
        if result[i] == not_found:
            raise ValueError()
    return result