import matplotlib.pyplot as plt
from fastddm.weights import sector_average_weight

shape = (128, 128)

# One entry per panel, in row-major order (top left, top right, bottom left,
# bottom right): keyword arguments forwarded to ``sector_average_weight``
# together with the title shown above that panel. Keeping the parameters and
# the title in the same tuple stops the two from drifting apart.
panels = [
    ({}, "default parameters"),
    (
        dict(theta_0=45, delta_theta=45, rep=4),
        r"$\theta_0=45,\ \Delta \theta=45,\ \mathrm{rep}=4$",
    ),
    (
        dict(theta_0=90, delta_theta=140, rep=2),
        r"$\theta_0=90,\ \Delta \theta=140,\ \mathrm{rep}=2$",
    ),
    (
        dict(theta_0=90, delta_theta=140, rep=2, kind="gauss"),
        # Both halves are raw strings so "\m" in "\mathrm" is not treated
        # as an (invalid) escape sequence; "\n" splits the title in two.
        r"$\theta_0=90,\ \Delta \theta=140,$" "\n"
        r"$\mathrm{rep}=2,\ \mathrm{kind}=\mathrm{gauss}$",
    ),
]

# plot setup
fig = plt.figure(figsize=(5, 5))
gs = fig.add_gridspec(ncols=2, nrows=2)
axs = gs.subplots(sharex=True, sharey=True)

for ax, (kwargs, title) in zip(axs.flat, panels):
    im = ax.imshow(sector_average_weight(shape, **kwargs))
    # Attach the colorbar to this panel explicitly; without ``ax=`` some
    # matplotlib versions take space from the current Axes instead.
    fig.colorbar(im, ax=ax)
    ax.set_axis_off()
    ax.set_title(title)

# displaying
fig.tight_layout()
plt.show()