Skip to content

Commit

Permalink
merging
Browse files Browse the repository at this point in the history
  • Loading branch information
EiffL committed Dec 21, 2022
2 parents 4fc32c2 + 288ead9 commit 6ff4b5e
Showing 1 changed file with 72 additions and 25 deletions.
97 changes: 72 additions & 25 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,36 +75,70 @@ def integrand_single(z_prime):
ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2
return constant_factor * ell_factor * radial_kernel


@jit
def mag_kernel(cosmo, pzs, z, ell, s):
"""
Returns a magnification kernel
Needs magnification bias function
s = "logarithmic derivative of the number of sources with magnitude limit", a function valid for all z in z_prime
"""
z = np.atleast_1d(z)
zmax = max([pz.zmax for pz in pzs])
# Retrieve comoving distance corresponding to z
chi = bkgrd.radial_comoving_distance(cosmo, z2a(z))

@vmap
def integrand(z_prime):
chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
# Stack the dndz of all redshift bins
dndz = np.stack([pz(z_prime) for pz in pzs], axis=0)

mag_lim = (2.0 - 5.0 * s(cosmo, z_prime)) / 2.0

return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) * mag_lim

# Computes the radial weak lensing kernel
radial_kernel = np.squeeze(simps(integrand, z, zmax, 256) * (1.0 + z) * chi)
# Constant term (maybe one too many 2.0?)
constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c / 2.0
# Ell dependent factor
ell_factor = ell * (ell + 1)
return constant_factor * ell_factor * radial_kernel


@jit
def mag_kernel(cosmo, pzs, z, ell, s):
"""
Returns a magnification kernel
Needs magnification bias function
Needs magnification bias function
s = "logarithmic derivative of the number of sources with magnitude limit", a function valid for all z in z_prime
"""
z = np.atleast_1d(z)
zmax = max([pz.zmax for pz in pzs])
# Retrieve comoving distance corresponding to z
chi = bkgrd.radial_comoving_distance(cosmo, z2a(z))

@vmap
def integrand(z_prime):
chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
# Stack the dndz of all redshift bins
dndz = np.stack([pz(z_prime) for pz in pzs], axis=0)
mag_lim = (2.0-5.0*s(cosmo, z_prime))/2.0
return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0)*mag_lim

mag_lim = (2.0 - 5.0 * s(cosmo, z_prime)) / 2.0

return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) * mag_lim

# Computes the radial weak lensing kernel
radial_kernel = np.squeeze(simps(integrand, z, zmax, 256) * (1.0 + z) * chi)
# Constant term (maybe one too many 2.0?)
constant_factor = 3.0 * const.H0 ** 2 * cosmo.Omega_m / 2.0 / const.c / 2.0
constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c / 2.0
# Ell dependent factor
ell_factor = ell*(ell+1)
ell_factor = ell * (ell + 1)
return constant_factor * ell_factor * radial_kernel


Expand Down Expand Up @@ -132,6 +166,7 @@ def density_kernel(cosmo, pzs, bias, z, ell):
ell_factor = 1.0
return constant_factor * ell_factor * radial_kernel


@jit
def nla_kernel(cosmo, pzs, bias, z, ell):
"""
Expand Down Expand Up @@ -167,25 +202,37 @@ def rsd_kernel(cosmo, pzs, z, ell, z1):
"""
Computes the RSD kernel
"""
print(z,z1)
print(z, z1)
# stack the dndz of all redshift bins
dndz = np.stack([pz(z) for pz in pzs], axis=0)

# Normalization,
constant_factor = 1.0
# Ell dependent factor
ell_factor1 = (1+8*ell)/((2*ell+1)**2.0)
ell_factor1 = (1 + 8 * ell) / ((2 * ell + 1) ** 2.0)
# stack the dndz of all redshift bins
dndz = np.stack([pz(z) for pz in pzs], axis=0)
radial_kernel1 = dndz * bkgrd.growth_rate(cosmo, z2a(z))/bkgrd.growth_factor(cosmo, z2a(z)) * bkgrd.H(cosmo, z2a(z))

radial_kernel1 = (
dndz
* bkgrd.growth_rate(cosmo, z2a(z))
/ bkgrd.growth_factor(cosmo, z2a(z))
* bkgrd.H(cosmo, z2a(z))
)

# Ell dependent factor
ell_factor2 = (4)/(2*ell+3) *np.sqrt((2*ell+1)/(2*ell+3))
ell_factor2 = (4) / (2 * ell + 3) * np.sqrt((2 * ell + 1) / (2 * ell + 3))
# stack the dndz of all redshift bins
dndz = np.stack([pz(z1) for pz in pzs], axis=0)
radial_kernel2 = dndz * bkgrd.growth_rate(cosmo, z2a(z1))/bkgrd.growth_factor(cosmo, z2a(z1)) * bkgrd.H(cosmo, z2a(z1))
radial_kernel2 = (
dndz
* bkgrd.growth_rate(cosmo, z2a(z1))
/ bkgrd.growth_factor(cosmo, z2a(z1))
* bkgrd.H(cosmo, z2a(z1))
)

return constant_factor*(ell_factor1 * radial_kernel1 - ell_factor2*radial_kernel2)
return constant_factor * (
ell_factor1 * radial_kernel1 - ell_factor2 * radial_kernel2
)


@register_pytree_node_class
Expand Down Expand Up @@ -294,11 +341,11 @@ class NumberCounts(container):
mag_bias....
"""

def __init__(self, redshift_bins, bias, has_rsd=False,mag_bias=False, **kwargs):
def __init__(self, redshift_bins, bias, has_rsd=False, mag_bias=False, **kwargs):
super(NumberCounts, self).__init__(
redshift_bins, bias, has_rsd=has_rsd,mag_bias=mag_bias, **kwargs
redshift_bins, bias, has_rsd=has_rsd, mag_bias=mag_bias, **kwargs
)
self.mag_bias =mag_bias
self.mag_bias = mag_bias
self.has_rsd = has_rsd

@property
Expand All @@ -317,7 +364,7 @@ def n_tracers(self):
pzs = self.params[0]
return len(pzs)

def kernel(self, cosmo, z, ell):
def kernel(self, cosmo, z, ell, z1):
"""Compute the radial kernel for all nz bins in this probe.
Returns:
Expand All @@ -329,13 +376,13 @@ def kernel(self, cosmo, z, ell):
pzs, bias = self.params
# Retrieve density kernel
kernel = density_kernel(cosmo, pzs, bias, z, ell)

if self.mag_bias:
kernel += mag_kernel(cosmo, pzs, z, ell, self.mag_bias)

if self.has_rsd:
kernel += rsd_kernel(cosmo, pzs, z, ell, z1)

return kernel

def noise(self):
Expand Down

0 comments on commit 6ff4b5e

Please sign in to comment.