def diagnostics(self, ctx: Context, **kwargs) -> Optional[List[DiagnosticFigure]]:
    """Build diagnostic figures for the fitted per-channel transformation.

    One subplot is drawn per entry of ``ctx.channel_names``. Each subplot shows:

    * the transformation curve, i.e. the output of
      ``self.compute_yprime_from_list`` on 500 evenly spaced points running
      from ``ctx.saturation_thresholds[:, 0]`` to
      ``ctx.saturation_thresholds[:, 1]`` for that channel;
    * a dashed identity line over the same x range, for reference;
    * a dashed vertical line at every value in
      ``self._corrected_saturation_thresholds[i]``.

    When the instance has both ``_absolute_ranges`` and ``_relative_ranges``
    attributes, the figure returned by ``self._visualize_linearization_path()``
    is appended as a second entry, provided it is truthy.

    Args:
        ctx: Supplies ``saturation_thresholds`` (indexed here as ``[:, 0]``
            and ``[:, 1]`` -- presumably per-channel lower/upper bounds,
            TODO confirm) and ``channel_names``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A list of ``DiagnosticFigure`` objects: the transformation figure,
        optionally followed by the linearization-path figure.

    Raises:
        AssertionError: If ``self._params`` is ``None`` (``initialize()`` has
            not been run).
    """
    # Check the precondition before touching ctx or allocating a figure.
    assert self._params is not None, "You need to run initialize() first"

    saturation_thresholds = ctx.saturation_thresholds
    channel_names = ctx.channel_names
    n_channels = len(channel_names)

    # One row of 500 sample points per channel (row i spans channel i's
    # thresholds). The row-major grid is fed to the transform as-is; the
    # transposed views put channels on axis 1 for column-wise plotting.
    grid = vmap(jnp.linspace, (0, 0, None))(
        saturation_thresholds[:, 0], saturation_thresholds[:, 1], 500
    )
    intervals = grid.T
    interval_reg = self.compute_yprime_from_list(grid, self._params).T

    fig = plt.figure(figsize=(n_channels * 8, 12))
    gs = plt.GridSpec(1, n_channels, figure=fig)

    for i, channel_name in enumerate(channel_names):
        ax = fig.add_subplot(gs[0, i])
        xs = intervals[:, i]

        # Transformation curve.
        ax.plot(xs, interval_reg[:, i], c='r', lw=2, label='Transformation')

        # Identity reference: same extent on both axes.
        x_min, x_max = xs.min(), xs.max()
        ax.plot(
            [x_min, x_max],
            [x_min, x_max],
            'k--',
            alpha=0.5,
            label='Identity',
        )

        # Corrected saturation thresholds as vertical markers.
        for threshold in self._corrected_saturation_thresholds[i]:
            ax.axvline(threshold, color='orange', linestyle='--', alpha=0.5)

        # Axis styling.
        ax.set_xlabel(channel_name)
        if self.logspace:
            ax.set_xscale('symlog', linthresh=50)
        plots.remove_topright_spines(ax)
        ax.legend(fontsize='small')
        ax.set_ylabel(f"{channel_name}' (transformed)")

    # NOTE(review): the title/name mention "Distributions" although only
    # transformation curves are drawn. Left untouched because downstream code
    # may look the figure up by this name -- confirm before renaming.
    fig.suptitle("Channel Transformations and Distributions", y=1.02)
    fig.tight_layout()
    figs = [DiagnosticFigure(fig=fig, name="Channel Transformations and Distributions")]

    # Optional linearization path visualization; only attempted when both
    # range attributes exist on the instance.
    if hasattr(self, '_absolute_ranges') and hasattr(self, '_relative_ranges'):
        path_fig = self._visualize_linearization_path()
        if path_fig:
            figs.append(DiagnosticFigure(fig=path_fig, name="Linearization Path"))

    return figs