diff --git a/changelog/134.bugfix.rst b/changelog/134.bugfix.rst new file mode 100644 index 0000000..39c3e88 --- /dev/null +++ b/changelog/134.bugfix.rst @@ -0,0 +1 @@ +Fix some inconsistencies on the data units in plots, also return handle to the image from `~stixpy.product.sources.science.SpectrogramPlotMixin.plot_spectrogram`. diff --git a/stixpy/product/sources/science.py b/stixpy/product/sources/science.py index 9a17fc1..1f535f0 100644 --- a/stixpy/product/sources/science.py +++ b/stixpy/product/sources/science.py @@ -210,8 +210,11 @@ def plot_spectrogram( counts, errors, times, timedeltas, energies = self.get_data( detector_indices=did, pixel_indices=pid, time_indices=time_indices, energy_indices=energy_indices ) + counts = counts.to(u.ct / u.s / u.keV) + errors = errors.to(u.ct / u.s / u.keV) + timedeltas = timedeltas.to(u.s) - e_edges = np.hstack([energies["e_low"], energies["e_high"][-1]]) + e_edges = np.hstack([energies["e_low"], energies["e_high"][-1]]).value t_edges = Time( np.concatenate([times - timedeltas.reshape(-1) / 2, times[-1] + timedeltas.reshape(-1)[-1:] / 2]) ) @@ -234,8 +237,12 @@ def plot_spectrogram( axes.xaxis.set_major_formatter(DateFormatter("%d %H:%M")) # fig.autofmt_xdate() # fig.tight_layout() + for i in plt.get_fignums(): + if axes in plt.figure(i).axes: + plt.sca(axes) + plt.sci(im) - return axes + return im class TimesSeriesPlotMixin: @@ -303,8 +310,11 @@ def plot_timeseries( time_indices=time_indices, energy_indices=energy_indices, ) + counts = counts.to(u.ct / u.s / u.keV) + errors = errors.to(u.ct / u.s / u.keV) + timedeltas = timedeltas.to(u.s) - labels = [f"{el.value} - {eh.value}" for el, eh in energies["e_low", "e_high"]] + labels = [f"{el.value} - {eh.value} keV" for el, eh in energies["e_low", "e_high"]] n_time, n_det, n_pix, n_energy = counts.shape @@ -363,6 +373,10 @@ def plot_pixels(self, time_indices=None, energy_indices=None, fig=None): counts, count_err, times, dt, energies = self.get_data(time_indices=time_indices, energy_indices=energy_indices) + counts = counts.to(u.ct / u.s / u.keV) + count_err = count_err.to(u.ct / u.s / u.keV) + dt = dt.to(u.s) + def timeval(val): return times[val].isot