Skip to content

MEFBeadsTransform

{{ task_disclaimer }}

MEFBeadsTransform

Bases: Task

beads_data (np.ndarray): array of shape (n_events, n_channels) holding the bead observations (one row per event, one column per channel).

Functions

load_cells_with_MEF_channels

load_cells_with_MEF_channels(current_loader: Callable, data: Any, column_order: List[str]) -> pd.DataFrame

Load the cell data and convert its channels to MEF units using the bead calibration.

Source code in calibrie/mefbeadstransform.py
def load_cells_with_MEF_channels(
    self, current_loader: Callable, data: Any, column_order: List[str]
) -> pd.DataFrame:
    """
    Load the cells data and apply the calibration to MEF units.

    Args:
        current_loader: callable invoked as ``current_loader(data, column_order)``;
            it must return a DataFrame whose columns are the channels to calibrate.
        data: raw input forwarded unchanged to ``current_loader``.
        column_order: channel order forwarded unchanged to ``current_loader``.

    Returns:
        A DataFrame with the same columns as the loaded data and values converted
        by ``self.transform_channels_to_MEF``. The loader's index is preserved
        whenever the transform keeps the number of rows.

    Raises:
        ValueError: if any loaded column is not listed in ``self.use_channels``
            (all offending channels are reported in a single message).
    """
    cells = current_loader(data, column_order)
    actual_columns = list(cells.columns)

    # Validate every column up front so the user sees all uncalibratable channels at once.
    missing = [c for c in actual_columns if c not in self.use_channels]
    if missing:
        raise ValueError(
            f'Channel(s) {missing} not found in beads\' use_channels. Cannot calibrate'
        )

    new_values = self.transform_channels_to_MEF(cells.values, actual_columns)
    # Keep the loader's index (e.g. event ids) instead of silently resetting it to 0..n-1.
    # Fall back to a default index if the transform ever changes the row count.
    index = cells.index if len(new_values) == len(cells) else None
    cells = pd.DataFrame(new_values, columns=actual_columns, index=index)
    print(
        f"LOADER--mefbeadstransform: Loaded {len(cells)} cells with {len(cells.columns)} MEF units channels"
    )
    return cells

compute_peaks

compute_peaks(beads_observations, beads_mef_values)

Compute coordinates of each bead's peak in all channels

Source code in calibrie/mefbeadstransform.py
def compute_peaks(
    self,
    beads_observations,
    beads_mef_values,
):
    """Compute coordinates of each bead's peak in all channels.

    Every bead observation is assigned to one reference bead. The assignment uses
    per-channel optimal-transport "votes", averaged over the channels in which the
    observation is not saturated. For each bead, a weighted KDE of the observations
    assigned to it is then evaluated on a regular grid in every channel. The bead's
    peak is the density-weighted mean grid position.

    Args:
        beads_observations: observed bead events indexed ``[observation, channel]``.
            Randomly subsampled (without replacement) to
            ``self.resample_observations_to`` rows when it has more.
        beads_mef_values: reference MEF values indexed ``[bead, channel]``; must have
            the same number of channels as ``beads_observations``.

    Side effects (nothing is returned). Sets:
        ``self._obs_tr``, ``self._mef_tr``, ``self._saturation_thresholds_tr``
        (inputs passed through ``self._tr``); ``self._vmat`` (boolean one-hot
        OBSERVATIONS x BEADS assignment); ``self._peaks_min_x`` / ``self._peaks_max_x``
        (density grid range); ``self._beads_densities`` (BEADS, CHANNELS, RESOLUTION,
        each curve scaled to a max of 1); ``self._weighted_densities``;
        ``self._bead_peak_locations`` (BEADS, CHANNELS); ``self._overlap`` (pairwise
        bead overlap per channel).

    Raises:
        ValueError: if either input has no rows, or their channel counts differ.
    """

    self._log.debug(
        f'Computing peaks for beads observation of shape {beads_observations.shape}'
    )
    self._log.debug(f'MEF values: {beads_mef_values}')

    if beads_observations.shape[0] == 0:
        raise ValueError('No beads observations found')
    if beads_mef_values.shape[0] == 0:
        raise ValueError('No beads MEF values found')

    if beads_observations.shape[1] != beads_mef_values.shape[1]:
        raise ValueError(
            f'Number of channels in beads observations ({beads_observations.shape[1]}) '
            f'does not match number of channels in beads MEF values ({beads_mef_values.shape[1]})'
        )

    # Subsample to bound the cost of the OT and KDE steps below.
    if beads_observations.shape[0] > self.resample_observations_to:
        reorder = np.random.choice(
            beads_observations.shape[0], self.resample_observations_to, replace=False
        )
        beads_observations = beads_observations[reorder]

    # (CHANNELS, OBSERVATIONS): True where the raw value lies strictly between that
    # channel's low and high saturation thresholds.
    not_saturated = (
        (beads_observations > self._saturation_thresholds[:, 0])
        & (beads_observations < self._saturation_thresholds[:, 1])
    ).T

    self._obs_tr = self._tr(beads_observations)
    self._mef_tr = self._tr(beads_mef_values)
    self._saturation_thresholds_tr = self._tr(self._saturation_thresholds)
    self._log.debug('Computing votes')

    # Compute observation weights for OT source marginal
    # When sparse_peak_boost > 0, observations in sparse regions get upweighted via inverse density
    n_obs, n_chan = self._obs_tr.shape
    uniform = jnp.ones((n_obs, n_chan)) / n_obs
    if self.sparse_peak_boost > 0:
        obs_density = jnp.array([
            gaussian_kde(self._obs_tr[:, c], bw_method=0.1)(self._obs_tr[:, c])
            for c in range(n_chan)
        ]).T
        inv_density = 1.0 / jnp.maximum(obs_density, 1e-10)
        # Normalize per channel so each column sums to 1, like `uniform`.
        inv_density = inv_density / inv_density.sum(axis=0, keepdims=True)
        # NOTE(review): this mixture is a valid (non-negative) marginal only for
        # 0 <= sparse_peak_boost <= 1 -- confirm the field is validated elsewhere.
        obs_weights = (1 - self.sparse_peak_boost) * uniform + self.sparse_peak_boost * inv_density
    else:
        obs_weights = uniform

    @jit
    @partial(vmap, in_axes=(1, 1, 1))
    def vote(chan_observations, chan_mef, chan_weights):
        """Compute the vote matrix"""
        # it's a (CHANNEL, OBSERVATIONS, BEAD) matrix
        # where each channel tells what affinity each observation has to each bead, from the channel's perspective.
        # This is computed using optimal transport (it's the OT matrix)
        # High values mean it's obvious in this channel that the observation should be paired with a certain bead.
        # Low values mean it's not so obvious, usually because the observation is not in the valid range
        # so there's a bunch of points around this one that could be paired with any of the remaining beads
        # This is much more robust than just computing OT for all channels at once
        # When chan_weights diverges from uniform, sparse regions get more influence (peak-driven vs mass-driven)
        return Sinkhorn()(
            LinearProblem(PointCloud(chan_observations[:, None], chan_mef[:, None]), a=chan_weights)
        ).matrix

    votes = vote(self._obs_tr, self._mef_tr, obs_weights)  # (CHANNELS, OBSERVATIONS, BEADS)
    valid_votes = votes * not_saturated[:, :, None] + 1e-12  # (CHANNELS, OBSERVATIONS, BEADS)
    vmat = (
        np.sum(valid_votes, axis=0) / np.sum(not_saturated, axis=0)[:, None]
    )  # weighted average
    # NOTE: an observation saturated in every channel has a zero denominator, so its
    # row is non-finite and the argmax below assigns it to bead 0.

    # Use these votes to decide which beads are the most likely for each observation
    # Tried with a softer version of this just in case, but I couldn't see any improvement
    self._vmat = np.argmax(vmat, axis=1)[:, None] == np.arange(vmat.shape[1])[None, :]

    # Density grid range, padded outward by 5% of each bound's magnitude. A plain
    # `max * 1.05` / `min * 0.95` only widens the range for non-negative bounds; for
    # negative transformed values it shrinks it and cuts off the extreme observations.
    # For non-negative bounds the result is identical to that plain form.
    obs_max = self._obs_tr.max()
    obs_min = self._obs_tr.min()
    self._peaks_max_x = obs_max * (1.05 if obs_max >= 0 else 0.95)
    self._peaks_min_x = obs_min * (0.95 if obs_min >= 0 else 1.05)

    # We add some tiny random normal noise to avoid singular matrix errors when computing the KDE
    # on a bead that would have only the exact same value (which can happen when out of range)
    noise_std = (self._peaks_max_x - self._peaks_min_x) / (self.density_resolution * 5)
    obs = self._obs_tr + np.random.normal(0, noise_std, self._obs_tr.shape)

    # Now we can compute the densities for each bead in each channel in order to locate the peaks
    self._log.debug('Computing densities')
    x = np.linspace(self._peaks_min_x, self._peaks_max_x, self.density_resolution)

    def weighted_kde(samples, weights):
        # Weighted KDE of one channel's observations, evaluated on the grid `x`.
        return gaussian_kde(samples, weights=weights, bw_method=self.density_bw_method)(x)

    # Inner vmap: over beads (columns of the one-hot assignment).
    # Outer vmap: over channels (columns of `obs`).
    densities = jit(vmap(vmap(weighted_kde, in_axes=(None, 1)), in_axes=(1, None)))(obs, self._vmat)
    densities = densities.transpose(1, 0, 2)  # densities.shape is (BEADS, CHANNELS, RESOLUTION)
    self._beads_densities = densities / np.max(densities, axis=2)[:, :, None]

    # find the most likely peak positions for each bead using
    # the average intensity weighted by density.
    # The grid is deliberately not masked by the transformed saturation thresholds
    # (that mask was found not to be useful).
    rel_mask = self._beads_densities > self.relative_density_threshold
    w = self._beads_densities * rel_mask  # w.shape is (BEADS, CHANNELS, RESOLUTION)

    # A (bead, channel) pair with no density above the threshold falls back to
    # uniform weights, i.e. its peak becomes the midpoint of the grid.
    is_zero = np.sum(w, axis=2) == 0
    w = np.where(is_zero[:, :, None], 1, w)

    self._weighted_densities = w
    self._log.debug(f'w sum: {np.sum(w)}')

    xx = np.tile(x, (self._beads_densities.shape[0], self._beads_densities.shape[1], 1))
    self._bead_peak_locations = np.average(
        xx, axis=2, weights=w
    )  # peaks.shape is (BEADS, CHANNELS)

    # compute the overlap matrix: per channel, the intersection of two beads'
    # weighted density curves divided by the smaller of their totals, clipped to [0, 1]
    @jit
    def overlap(X):
        def overlap_1v1(x, y):
            s = jnp.sum(jnp.minimum(x, y), axis=1)
            return jnp.clip(s / jnp.minimum(jnp.sum(x, axis=1), jnp.sum(y, axis=1)), 0, 1)

        return vmap(vmap(overlap_1v1, (0, None)), (None, 0))(X, X)

    self._log.debug('Computing overlap matrix')

    self._overlap = overlap(self._weighted_densities)

plot_bead_peaks_diagnostics

plot_bead_peaks_diagnostics()

Plot diagnostics for the bead-peak computation on a single grid of subplots (rather than subfigures), so that the multiple axes are scaled correctly.

Source code in calibrie/mefbeadstransform.py
def plot_bead_peaks_diagnostics(self):
    """
    Plot diagnostics for the bead peaks computation on a single grid of subplots.

    Layout (one gridspec):
      * a narrow assignment column on the left, spanning every row;
      * one row per channel, holding that channel's density panel and, in the
        rightmost column, its bead overlap matrix;
      * one row of per-channel AU -> MEF regression plots;
      * one row of per-channel post-correction peak alignment plots.
    The last two rows are each preceded by a thin spacer row.

    Returns:
        The created matplotlib figure.
    """
    NCHAN = len(self.use_channels)
    # Columns between the assignment column and the overlap column, used by the
    # density panels. With a single channel the slice `1:NCHAN` would be empty
    # (matplotlib raises IndexError), so always reserve at least one column.
    n_density_cols = max(NCHAN - 1, 1)
    ncols = 1 + n_density_cols + 1  # equals NCHAN + 1 whenever NCHAN >= 2

    fig = plt.figure(figsize=(18, 25))
    gs = fig.add_gridspec(
        ncols=ncols,
        nrows=NCHAN + 4,
        width_ratios=[0.2] + [1] * (ncols - 1),
        height_ratios=[1] * NCHAN + [0.1, 1.2, 0.1, 2],
        hspace=0.4,
        wspace=0.3,
    )

    # Assignment plot
    ax_assignment = fig.add_subplot(gs[:, 0])
    self.plot_assignment(ax_assignment)

    # Density plots: one row per channel, all sharing the first panel's x axis
    axes_densities = []
    for i in range(NCHAN):
        sharex = axes_densities[0] if axes_densities else None
        axes_densities.append(
            fig.add_subplot(gs[i, 1 : 1 + n_density_cols], sharex=sharex)
        )
    self.plot_densities(axes_densities)

    # Overlap plot
    axes_overlap = [fig.add_subplot(gs[i, -1]) for i in range(NCHAN)]
    self.plot_overlap(axes_overlap)

    # Regressions plot (row NCHAN is a spacer)
    axes_regressions = [fig.add_subplot(gs[NCHAN + 1, 1 + i]) for i in range(NCHAN)]
    self.plot_regressions(axes_regressions)

    # After correction plot (row NCHAN + 2 is a spacer)
    axes_after = [fig.add_subplot(gs[-1, 1 + i]) for i in range(NCHAN)]
    self.plot_beads_after_correction(axes_after)

    def add_framed_title(axes, title, padding=0.0125):
        """Place a bold title centred above a group of axes, in figure coordinates.

        The horizontal centre runs from the first axis' left edge to the last axis'
        right edge. The title sits `padding` above the highest top edge in the group.
        No frame is drawn; only the title text is added.
        """
        first_ax = axes[0]
        last_ax = axes[-1]

        # Bounding box of the group of axes
        bbox = first_ax.get_position()
        bbox.y0 = min(ax.get_position().y0 for ax in axes)
        bbox.y1 = max(ax.get_position().y1 for ax in axes)
        bbox.x1 = last_ax.get_position().x1

        # Add padding
        bbox.x0 -= padding
        bbox.x1 += padding
        bbox.y0 -= padding
        bbox.y1 += padding

        fig.text(
            bbox.x0 + bbox.width / 2,
            bbox.y1,
            title,
            ha='center',
            va='bottom',
            fontsize=12,
            fontweight='bold',
            transform=fig.transFigure,
        )

    # Titles are placed after tight_layout so they use the final axes positions.
    fig.tight_layout()
    add_framed_title([ax_assignment], "Assignment")
    add_framed_title(
        axes_densities,
        "Density and assignment of observations to bead number\n(hatched = out of range)",
    )
    add_framed_title(axes_overlap, "Bead overlap matrices")
    add_framed_title(axes_regressions, "Mapping of Arbitrary Units to MEF")
    add_framed_title(axes_after, "Observed peaks alignment to targets after correction")

    return fig