def compute_peaks(
    self,
    beads_observations,
    beads_mef_values,
):
    """Compute the coordinates of each bead's peak in all channels.

    Pipeline:
      1. Validate inputs and (optionally) subsample observations down to
         ``self.resample_observations_to`` to keep the OT problem tractable.
      2. Transform observations / MEF values / saturation thresholds with
         ``self._tr``.
      3. Assign each observation to a bead via per-channel optimal-transport
         "votes" (Sinkhorn), averaged over non-saturated channels only.
      4. Estimate a weighted KDE per (bead, channel) and take the
         density-weighted average position as the peak location.
      5. Compute a bead-vs-bead overlap matrix from the weighted densities.

    Parameters
    ----------
    beads_observations : array of shape (n_observations, n_channels)
        Raw bead event intensities per channel.
    beads_mef_values : array of shape (n_beads, n_channels)
        Reference MEF value of each bead in each channel.

    Raises
    ------
    ValueError
        If either input is empty, or the two channel counts differ.

    Side effects: sets ``self._obs_tr``, ``self._mef_tr``,
    ``self._saturation_thresholds_tr``, ``self._vmat``,
    ``self._peaks_min_x`` / ``self._peaks_max_x``, ``self._beads_densities``,
    ``self._weighted_densities``, ``self._bead_peak_locations`` and
    ``self._overlap``.
    """
    self._log.debug(
        f'Computing peaks for beads observation of shape {beads_observations.shape}'
    )
    self._log.debug(f'MEF values: {beads_mef_values}')
    if beads_observations.shape[0] == 0:
        raise ValueError('No beads observations found')
    if beads_mef_values.shape[0] == 0:
        raise ValueError('No beads MEF values found')
    if beads_observations.shape[1] != beads_mef_values.shape[1]:
        raise ValueError(
            f'Number of channels in beads observations ({beads_observations.shape[1]}) '
            f'does not match number of channels in beads MEF values ({beads_mef_values.shape[1]})'
        )
    # Subsample (without replacement) when there are more observations than
    # we want to feed into the Sinkhorn solver.
    if beads_observations.shape[0] > self.resample_observations_to:
        reorder = np.random.choice(
            beads_observations.shape[0], self.resample_observations_to, replace=False
        )
        beads_observations = beads_observations[reorder]
    # (CHANNELS, OBSERVATIONS) boolean mask of in-range values, computed in
    # raw units against the raw saturation thresholds.
    not_saturated = (
        (beads_observations > self._saturation_thresholds[:, 0])
        & (beads_observations < self._saturation_thresholds[:, 1])
    ).T
    self._obs_tr = self._tr(beads_observations)
    self._mef_tr = self._tr(beads_mef_values)
    self._saturation_thresholds_tr = self._tr(self._saturation_thresholds)
    self._log.debug('Computing votes')
    # Observation weights for the OT source marginal.
    # When sparse_peak_boost > 0, observations in sparse regions get
    # upweighted via inverse density (peak-driven rather than mass-driven).
    n_obs, n_chan = self._obs_tr.shape
    uniform = jnp.ones((n_obs, n_chan)) / n_obs
    if self.sparse_peak_boost > 0:
        # Per-channel 1D KDE of the observations, evaluated at the
        # observations themselves; shape (n_obs, n_chan) after transpose.
        obs_density = jnp.array([
            gaussian_kde(self._obs_tr[:, c], bw_method=0.1)(self._obs_tr[:, c])
            for c in range(n_chan)
        ]).T
        inv_density = 1.0 / jnp.maximum(obs_density, 1e-10)  # guard against /0
        inv_density = inv_density / inv_density.sum(axis=0, keepdims=True)
        # Convex blend keeps each channel's weights summing to 1.
        obs_weights = (
            (1 - self.sparse_peak_boost) * uniform
            + self.sparse_peak_boost * inv_density
        )
    else:
        obs_weights = uniform

    @jit
    @partial(vmap, in_axes=(1, 1, 1))
    def vote(chan_observations, chan_mef, chan_weights):
        """Compute the vote matrix.

        It's a (CHANNEL, OBSERVATIONS, BEAD) matrix where each channel tells
        what affinity each observation has to each bead, from the channel's
        perspective. This is computed using optimal transport (it's the OT
        matrix). High values mean it's obvious in this channel that the
        observation should be paired with a certain bead. Low values mean
        it's not so obvious, usually because the observation is not in the
        valid range so there's a bunch of points around this one that could
        be paired with any of the remaining beads. This is much more robust
        than just computing OT for all channels at once. When chan_weights
        diverges from uniform, sparse regions get more influence
        (peak-driven vs mass-driven).
        """
        return Sinkhorn()(
            LinearProblem(
                PointCloud(chan_observations[:, None], chan_mef[:, None]),
                a=chan_weights,
            )
        ).matrix

    votes = vote(self._obs_tr, self._mef_tr, obs_weights)  # (CHANNELS, OBSERVATIONS, BEADS)
    # Zero out saturated channels' votes; the epsilon keeps downstream
    # normalizations away from exact zeros.
    valid_votes = votes * not_saturated[:, :, None] + 1e-12  # (CHANNELS, OBSERVATIONS, BEADS)
    # Average the votes over the channels where the observation is in range.
    # NOTE(review): an observation saturated in every channel makes this
    # denominator 0 (-> inf/nan row); presumably such rows never win the
    # argmax below — worth confirming.
    vmat = (
        np.sum(valid_votes, axis=0) / np.sum(not_saturated, axis=0)[:, None]
    )  # weighted average
    # Use these votes to decide which beads are the most likely for each
    # observation (hard assignment). Tried with a softer version of this just
    # in case, but I couldn't see any improvement.
    self._vmat = np.argmax(vmat, axis=1)[:, None] == np.arange(vmat.shape[1])[None, :]
    # We add some tiny random normal noise to avoid singular matrix errors
    # when computing the KDE on a bead that would have only the exact same
    # value (which can happen when out of range).
    self._peaks_max_x = self._obs_tr.max() * 1.05
    self._peaks_min_x = self._obs_tr.min() * 0.95
    noise_std = (self._peaks_max_x - self._peaks_min_x) / (self.density_resolution * 5)
    obs = self._obs_tr + np.random.normal(0, noise_std, self._obs_tr.shape)
    # Now we can compute the densities for each bead in each channel in order
    # to locate the peaks.
    self._log.debug('Computing densities')
    x = np.linspace(self._peaks_min_x, self._peaks_max_x, self.density_resolution)
    w_kde = lambda s, w: gaussian_kde(s, weights=w, bw_method=self.density_bw_method)(x)
    # Outer vmap over channels, inner vmap over beads (assignment columns).
    densities = jit(vmap(vmap(w_kde, in_axes=(None, 1)), in_axes=(1, None)))(obs, self._vmat)
    densities = densities.transpose(1, 0, 2)  # (BEADS, CHANNELS, RESOLUTION)
    # Normalize each (bead, channel) density to a max of 1.
    self._beads_densities = densities / np.max(densities, axis=2)[:, :, None]
    # Find the most likely peak positions for each bead using the average
    # intensity weighted by density, keeping only the part of each density
    # above the relative threshold.
    rel_mask = self._beads_densities > self.relative_density_threshold
    w = self._beads_densities * rel_mask  # (BEADS, CHANNELS, RESOLUTION)
    # If thresholding wiped out a (bead, channel) entirely, fall back to
    # uniform weights so np.average below doesn't divide by zero.
    is_zero = np.sum(w, axis=2) == 0
    w = np.where(is_zero[:, :, None], 1, w)
    self._weighted_densities = w
    self._log.debug(f'w sum: {np.sum(w)}')
    xx = np.tile(x, (self._beads_densities.shape[0], self._beads_densities.shape[1], 1))
    self._bead_peak_locations = np.average(
        xx, axis=2, weights=w
    )  # peaks.shape is (BEADS, CHANNELS)

    # Compute the overlap matrix: pairwise, per-channel overlap of the
    # weighted densities, normalized by the smaller of the two masses.
    @jit
    def overlap(X):
        def overlap_1v1(x, y):
            s = jnp.sum(jnp.minimum(x, y), axis=1)
            return jnp.clip(s / jnp.minimum(jnp.sum(x, axis=1), jnp.sum(y, axis=1)), 0, 1)

        return vmap(vmap(overlap_1v1, (0, None)), (None, 0))(X, X)

    self._log.debug('Computing overlap matrix')
    self._overlap = overlap(self._weighted_densities)