Skip to content

Colinearization

{{ task_disclaimer }}

Colinearization

Bases: Task

Task for colinearizing multi-channel fluorescence measurements.

Functions

initialize

initialize(ctx: Context) -> Context

Initialize the colinearization task with control data.

Source code in calibrie/colinearization.py
def initialize(self, ctx: Context) -> Context:
    """Fit the colinearization task from the control samples carried by ``ctx``.

    Reads ``controls_values``, ``controls_masks``, ``channel_names``,
    ``protein_names`` and ``reference_channels`` from ``ctx`` and stores them
    on the instance. It then subtracts the estimated autofluorescence,
    estimates saturation thresholds on the corrected controls, and resamples
    the single-protein control observations. Unless parameters were already
    provided, it also finds a linearization path (when none is set) and fits
    the linearization parameters.

    Args:
        ctx: Pipeline context holding the control data listed above and the
            current ``cell_data_loader``.

    Returns:
        A new ``Context`` whose ``cell_data_loader`` wraps the previous one
        with the colinearization transform, and whose
        ``saturation_thresholds`` are the autofluorescence-corrected ones.
    """
    started_at = time.time()

    # Keep the control data on the instance; the later fitting steps read it there.
    self._controls_values = ctx.controls_values
    self._controls_masks = ctx.controls_masks
    self._channel_names = ctx.channel_names
    self._protein_names = ctx.protein_names
    self._reference_channels = ctx.reference_channels

    # Remove the autofluorescence baseline before estimating saturation bounds.
    baseline = utils.estimate_autofluorescence(
        self._controls_values, self._controls_masks
    )
    corrected = self._controls_values - baseline
    self._corrected_saturation_thresholds = utils.estimate_saturation_thresholds_qtl(
        corrected
    )

    # Control rows whose mask is entirely zero.
    all_zero_mask = np.all(self._controls_masks == 0, axis=1)
    self._null_observations = corrected[all_zero_mask]

    self._log.debug("Resampling observations for linearization")
    self._singleptr_observations = resampled_singleprt_ctrls(
        corrected,
        self._protein_names,
        self._controls_masks,
        self._reference_channels,
        self.resample_N,
        npartitions=0,
    )

    # An observation counts as unsaturated when it lies strictly between the
    # lower (column 0) and upper (column 1) threshold of its channel.
    lower = self._corrected_saturation_thresholds[:, 0]
    upper = self._corrected_saturation_thresholds[:, 1]
    observations = self._singleptr_observations
    self._not_saturated = (observations > lower) & (observations < upper)

    self._compute_ranges_and_metrics()

    # Only search for a path when none was supplied and params still need fitting.
    if not self._linearization_path and self._params is None:
        self._find_linearization_path()

    if self._params is None:
        self._log.debug("Computing linearization params")
        self._params = self.find_linearization_params(self._channel_names)

    wrapped_loader = self.make_loader(ctx.cell_data_loader)
    self._log.debug(f"Linearization initialization done in {time.time() - started_at:.1f}s")

    return Context(
        cell_data_loader=wrapped_loader,
        saturation_thresholds=self._corrected_saturation_thresholds,
    )

find_linearization_params

find_linearization_params(channel_names, compute_coef_at_percentile=80, progress=False)

Find the per-channel parameters of the linearization transformations by walking the linearization path.

Source code in calibrie/colinearization.py
def find_linearization_params(
    self, channel_names, compute_coef_at_percentile=80, progress=False
):
    """Fit the per-channel linearization transforms along the linearization path.

    Every channel starts with an identity transform built from its corrected
    saturation thresholds. The method then walks ``self._linearization_path``
    (groups of steps; each step is iterated as ``(src, dest, _)`` rows).

    For each row with ``src != dest``, it takes the control protein recorded in
    ``self._best_range_prot_per_channel[src]``. It rescales that protein's
    ``dest`` observations into ``src`` units with a rough scalar coefficient.
    It then fits ``self._regression`` from the ``src`` values to the rescaled
    ``dest`` values, masked to points unsaturated in both channels.

    The fitted transform is written into a *working copy* of the observations.
    Later steps that read the same protein's ``src`` column as a target
    therefore see its linearized values. ``self._singleptr_observations``
    itself is left untouched, so repeated calls give the same result.

    Args:
        channel_names: Channel labels, indexed like the channel axis; only
            used to build log / progress messages.
        compute_coef_at_percentile: Percentile of the unsaturated ``src``
            values around which (+/- 1 percentile) the rough scale
            coefficient is averaged.
        progress: If true, display a ``tqdm`` progress bar.

    Returns:
        A list with one parameter set per channel. Channels never used as a
        source keep their identity transform.
    """
    assert self._linearization_path is not None

    params = [
        self._generate_single_identity_transform(yb)
        for yb in self._corrected_saturation_thresholds
    ]

    # Working copy that accumulates linearized values. It must be a copy:
    # writing into self._singleptr_observations would corrupt the raw source
    # values read below and make repeated calls non-reproducible.
    Ysingle_prime = np.array(self._singleptr_observations, copy=True)

    total_needed_linearizations = sum(
        len(extract_linearized(p)) for p in self._linearization_path
    )

    pbar = tqdm(total=total_needed_linearizations) if progress else None
    msg = 'Linearization'

    try:
        for group in self._linearization_path:
            for step in group:
                for src_chan, dest_chan, _ in step.astype(int):
                    if src_chan != dest_chan:
                        prot = self._best_range_prot_per_channel[src_chan]
                        x = self._singleptr_observations[prot, :, src_chan]
                        y = Ysingle_prime[prot, :, dest_chan]
                        w = (
                            self._not_saturated[prot, :, src_chan]
                            * self._not_saturated[prot, :, dest_chan]
                        )
                        # Rough scale factor bringing y into x's units
                        # (y * coef ~= x), averaged over unsaturated points in
                        # a narrow percentile window of x.
                        xw = x[w]
                        yw = y[w]
                        percentile_bounds = np.percentile(
                            xw,
                            [compute_coef_at_percentile - 1, compute_coef_at_percentile + 1],
                        )
                        inbounds_x = (xw > percentile_bounds[0]) & (xw < percentile_bounds[1])
                        coef = np.nanmean(xw[inbounds_x] / yw[inbounds_x])
                        scaled_y = y * coef

                        xbounds = self._corrected_saturation_thresholds[src_chan]
                        params[src_chan] = self._regression(x, scaled_y, w, xbounds)

                        # Measure both RMSEs against the same scaled target so
                        # the logged before -> after change is comparable.
                        rmse_before = np.sqrt(
                            np.nanmean((Ysingle_prime[prot, :, src_chan] - scaled_y) ** 2)
                        )
                        Ysingle_prime[prot, :, src_chan] = self._transform_chan(
                            x, params[src_chan]
                        )
                        rmse_after = np.sqrt(
                            np.nanmean((Ysingle_prime[prot, :, src_chan] - scaled_y) ** 2)
                        )
                        msg = f'Linearizing {channel_names[src_chan]} -> {channel_names[dest_chan]}: {rmse_before:.2f} -> {rmse_after:.2f}'
                        self._log.debug(msg)
                    if pbar is not None:
                        pbar.set_description(msg)
                        pbar.update(1)
    finally:
        # Release the progress bar even if a fit raises.
        if pbar is not None:
            pbar.close()
    return params

make_loader

make_loader(current_loader: Callable)

Create a new loader that applies the colinearization transformation to loaded cell data.

Source code in calibrie/colinearization.py
def make_loader(self, current_loader: Callable):
    """Wrap ``current_loader`` so that the cells it loads are colinearized.

    The returned loader has the signature ``loader(data, column_order=None)``:

    1. It loads *all* columns via ``current_loader(data, None)``.
    2. Once ``self._params`` is set, it replaces the columns named in
       ``self._channel_names`` with ``self.compute_yprime_from_list`` applied
       to them.
    3. If ``column_order`` is given, it restricts the result to those columns.

    Args:
        current_loader: Loader called as ``current_loader(data, None)``. A
            ``copy.deepcopy`` of it is captured. Note that ``deepcopy``
            returns plain functions unchanged; only callable objects are
            actually copied.

    Returns:
        The wrapping loader function. It returns ``None`` when the wrapped
        loader does, and raises ``ValueError`` if a channel in
        ``self._channel_names`` or a requested column is missing.
    """
    old_loader = deepcopy(current_loader)

    def _colinearize(df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with the channel columns transformed."""
        # compute_yprime_from_list applies params[i] to row i, and params are
        # ordered like self._channel_names. The columns must therefore be
        # gathered in that order, not in the DataFrame's own column order.
        channels = list(self._channel_names)
        missing = [c for c in channels if c not in df.columns]
        if missing:
            raise ValueError(
                f"Cannot apply colinearization: loaded data is missing channels {missing}. "
                f"Available: {list(df.columns)}"
            )

        values_to_transform = df[channels].to_numpy()
        self._log.debug(f"Values shape before transformation: {values_to_transform.shape}")
        transformed_values = np.asarray(
            self.compute_yprime_from_list(values_to_transform.T).T
        )
        self._log.debug(f"Values shape after transformation: {transformed_values.shape}")

        # Keep every non-channel column; overwrite the channel columns by name.
        transformed_df = df.copy()
        transformed_df[channels] = transformed_values
        return transformed_df

    def loader(data: Any, column_order: Optional[List[str]] = None) -> pd.DataFrame:
        # Load every column: the transform needs all channels even when the
        # caller only asked for a subset.
        df = old_loader(data, None)
        if df is None:
            return df

        if self._params is not None:
            self._log.debug(f"Applying colinearization to data of shape {df.shape}")
            df = _colinearize(df)

        if column_order is None:
            return df
        try:
            return df[column_order]
        except KeyError as e:
            raise ValueError(
                f"When loading cells, asked for unknown column {e}. Available: {list(df.columns)}"
            ) from e

    return loader

diagnostics

diagnostics(ctx: Context, **kwargs) -> Optional[List[DiagnosticFigure]]

Generate diagnostic figures showing each channel's fitted transformation curve against the identity line, plus the linearization path when it is available.

Source code in calibrie/colinearization.py
def diagnostics(self, ctx: Context, **kwargs) -> Optional[List[DiagnosticFigure]]:
    """Plot every channel's fitted colinearization transform.

    For each channel in ``ctx.channel_names``, draws the transform evaluated
    on 500 evenly spaced points between that channel's lower and upper
    ``ctx.saturation_thresholds``. Each panel also shows a dashed identity
    line and vertical markers at the corrected saturation thresholds.

    When ``_absolute_ranges`` and ``_relative_ranges`` exist on the instance,
    a second figure from ``self._visualize_linearization_path()`` is appended
    (if it returns a truthy figure).

    Args:
        ctx: Context providing ``saturation_thresholds`` and ``channel_names``.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        A list of ``DiagnosticFigure`` objects.

    Raises:
        AssertionError: if ``initialize()`` has not produced ``self._params``.
    """
    saturation_thresholds = ctx.saturation_thresholds
    channel_names = ctx.channel_names

    assert self._params is not None, "You need to run initialize() first"

    # One linspace per channel between its two thresholds; after the transpose,
    # column c holds the sample points of channel c.
    per_channel_linspace = vmap(jnp.linspace, (0, 0, None))
    grid = per_channel_linspace(
        saturation_thresholds[:, 0], saturation_thresholds[:, 1], 500
    ).T
    grid_transformed = self.compute_yprime_from_list(grid.T, self._params).T

    n_channels = len(channel_names)
    fig = plt.figure(figsize=(n_channels * 8, 12))
    layout = plt.GridSpec(1, n_channels, figure=fig)

    for col in range(n_channels):
        name = channel_names[col]
        ax = fig.add_subplot(layout[0, col])
        xs = grid[:, col]

        ax.plot(xs, grid_transformed[:, col], c='r', lw=2, label='Transformation')
        lo, hi = xs.min(), xs.max()
        ax.plot([lo, hi], [lo, hi], 'k--', alpha=0.5, label='Identity')

        # Mark the corrected saturation bounds of this channel.
        for threshold in self._corrected_saturation_thresholds[col]:
            ax.axvline(threshold, color='orange', linestyle='--', alpha=0.5)

        ax.set_xlabel(name)
        if self.logspace:
            ax.set_xscale('symlog', linthresh=50)
        plots.remove_topright_spines(ax)
        ax.legend(fontsize='small')
        ax.set_ylabel(f"{name}' (transformed)")

    fig.suptitle("Channel Transformations and Distributions", y=1.02)
    fig.tight_layout()

    figures = [DiagnosticFigure(fig=fig, name="Channel Transformations and Distributions")]

    ranges_available = hasattr(self, '_absolute_ranges') and hasattr(self, '_relative_ranges')
    if ranges_available:
        path_fig = self._visualize_linearization_path()
        if path_fig:
            figures.append(DiagnosticFigure(fig=path_fig, name="Linearization Path"))

    return figures