Differentiation of an improper integral using JAX and SciPy
I provide a simple code example of a failed attempt to use JAX to automatically differentiate through an improper integral computed with SciPy's `quad()` method.
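The function I consider is

$$F(c_1, c_2) = \int_{c_2}^{\infty} \frac{c_1}{1 + x^2}\,dx$$

with gradient given by

$$\nabla F(c_1, c_2) = \left( \frac{\pi}{2} - \arctan(c_2),\; -\frac{c_1}{1 + c_2^2} \right)$$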
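Both gradient components follow from the Leibniz rule for differentiating under the integral sign. The integrand is linear in $c_1$, and only the lower limit depends on $c_2$:

$$
F(c_1, c_2) = c_1 \left[\arctan x\right]_{c_2}^{\infty} = c_1\left(\frac{\pi}{2} - \arctan(c_2)\right),
\qquad
\frac{\partial F}{\partial c_1} = \frac{\pi}{2} - \arctan(c_2),
\qquad
\frac{\partial F}{\partial c_2} = -\left.\frac{c_1}{1 + x^2}\right|_{x = c_2} = -\frac{c_1}{1 + c_2^2}.
$$

At $c_1 = -1$, $c_2 = 0$ this gives $F = -\pi/2$ and $\nabla F = (\pi/2,\ 1)$.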
The following code evaluates the function correctly, but JAX throws a `ConcretizationTypeError` when I try to compute the gradient. What is the problem here, and how can it be fixed?
```python
import jax
from scipy.integrate import quad

## Function
def F(c1, c2):
    val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
    return val

## Gradient
grad_F = jax.grad(F)  # differentiates with respect to c1 only (default argnums=0)

## Parameters
c1 = -1.0
c2 = 0.0

## Evaluates function
F(c1, c2)
# -1.5707963267948966 (which is -pi/2 btw)

## Evaluates gradient
grad_F(c1, c2)
```
Throws:

```
---------------------------------------------------------------------------
ConcretizationTypeError Traceback (most recent call last)
/tmp/ipykernel_446012/1229440296.py in <module>
----> 1 grad_F(c1, c2)
[... skipping hidden 9 frame]
/tmp/ipykernel_446012/2999885932.py in F(c1, c2)
5 def F(c1, c2):
6 #val, err = jax.numpy.array(quad(lambda y: b/(1.0+y**2), a=a, b=jax.numpy.inf), float)
----> 7 val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
8 return val
9
~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in quad(func, a, b, args, full_output, epsabs, epsrel, limit, points, weight, wvar, wopts, maxp1, limlst)
349
350 if weight is None:
--> 351 retval = _quad(func, a, b, args, full_output, epsabs, epsrel, limit,
352 points)
353 else:
~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in _quad(func, a, b, args, full_output, epsabs, epsrel, limit, points)
463 return _quadpack._qagse(func,a,b,args,full_output,epsabs,epsrel,limit)
464 else:
--> 465 return _quadpack._qagie(func,bound,infbounds,args,full_output,epsabs,epsrel,limit)
466 else:
467 if infbounds != 0:
[... skipping hidden 1 frame]
~/anaconda3/lib/python3.8/site-packages/jax/core.py in error(self, arg)
998 f"or `jnp.array(x, {fun.__name__})` instead.")
999 def error(self, arg):
-> 1000 raise ConcretizationTypeError(arg, fname_context)
1001 return error
1002
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ConcreteArray(-0.5, dtype=float32)>with<JVPTrace(level=2/0)> with
primal = DeviceArray(-0.5, dtype=float32, weak_type=True)
tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[], weak_type=True), *)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fbf4f643b90>, invars=(Traced<ConcreteArray(2.0, dtype=float32):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True):JaxprTrace(level=1/0)>), outvars=[<weakref at 0x7fbf4c402c20; to 'JaxprTracer' at 0x7fbf4eca5090>], primitive=xla_call, params={'device': None, 'backend': None, 'name': 'jvp(true_divide)', 'donated_invars': (False, False), 'inline': True, 'call_jaxpr': { lambda ; a:f32[] b:f32[]. let c:f32[] = div b a in (c,) }}, source_info=<jaxlib.xla_extension.Traceback object at 0x7fbf4eca1bb0>)
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
Solution 1:[1]
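The issue is that JAX's `grad` transform can only operate on functions made up entirely of JAX operations, and `scipy.integrate.quad` is not a JAX operation. It calls into compiled QUADPACK code (`_quadpack._qagie` in the traceback) that needs concrete floating-point numbers, while `grad` passes traced values through your integrand; this is why the error mentions the `float` function. If you'd like to do this kind of computation, you'll have to find a JAX implementation of `quad`.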
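If you want to keep SciPy's adaptive `quad` for the function value, another option (a sketch, not taken from the answers here) is to call it through `jax.pure_callback`, which runs a Python function on concrete host values, and to supply the derivative yourself with `jax.custom_jvp`. The derivative rule below is the analytic gradient from the question: $F$ is linear in $c_1$, so $\partial F/\partial c_1 = F(1, c_2)$, and $\partial F/\partial c_2 = -c_1/(1 + c_2^2)$. This assumes a recent JAX release that provides `jax.pure_callback`; the names `F_quad` and `_quad_host` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.integrate import quad

jax.config.update("jax_enable_x64", True)  # match SciPy's float64 results

def _quad_host(c1, c2):
    # Runs on the host with concrete values, so calling SciPy's quad is safe here.
    c1f, c2f = float(c1), float(c2)
    val, _err = quad(lambda x: c1f / (1.0 + x**2), c2f, np.inf)
    return np.asarray(val, dtype=np.float64)

@jax.custom_jvp
def F_quad(c1, c2):
    # Declare the shape/dtype of the callback result so JAX can trace around it.
    out_type = jax.ShapeDtypeStruct((), jnp.float64)
    return jax.pure_callback(_quad_host, out_type, c1, c2)

@F_quad.defjvp
def F_quad_jvp(primals, tangents):
    c1, c2 = primals
    dc1, dc2 = tangents
    y = F_quad(c1, c2)
    # The integrand is linear in c1, so dF/dc1 = F(1, c2) = pi/2 - arctan(c2).
    dF_dc1 = F_quad(1.0, c2)
    # Leibniz rule: only the lower limit depends on c2.
    dF_dc2 = -c1 / (1.0 + c2**2)
    return y, dF_dc1 * dc1 + dF_dc2 * dc2

print(F_quad(-1.0, 0.0))
print(jax.grad(F_quad, argnums=(0, 1))(-1.0, 0.0))
```

The trade-off is that you must know (or separately compute) the derivative, and every evaluation round-trips through Python on the host, so this will generally be slower than a pure-JAX rule under `jit`.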
Solution 2:[2]
A simple, albeit less accurate, solution is to compute the integral with the trapezoidal rule over a truncated interval. You could also use a quadrature rule, which, loosely speaking, is what `scipy.integrate.quad` does.
After setting the environment variable `JAX_ENABLE_X64=True` on your local system (float64 is needed here: `jnp.exp` of the largest Laguerre nodes overflows in float32) and running the following script

```python
import jax.numpy as jnp
from jax import grad
from jax.scipy.integrate import trapezoid  # jnp.trapz in older JAX releases
from scipy.special import roots_laguerre

def F(c1, c2, tmax, N):
    """
    Inputs:
        c1, c2 are the coefficients in the integral
        tmax is the finite upper limit that replaces infinity
        N is the number of points. Can be low with a higher-order integration scheme
    Output: scalar integral value
    """
    t0 = jnp.linspace(c2, tmax, N)
    return trapezoid(c1/(1+t0**2), t0)

c1 = -1.0
c2 = 0.
gradF = grad(F, (0,1))
dFc1, dFc2 = gradF(c1, c2, 15000, 30000)
print("error in grad1 = %.8f"%(dFc1.item() - (jnp.pi/2-jnp.arctan(c2))))
print("error in grad2 = %.8f"%(dFc2.item() + (c1/(1+c2**2))))

def GLF(c1, c2, xi, wi):
    """
    Calculating the integral using Gauss-Laguerre quadrature
    """
    # Shift x = xi + c2 onto [0, inf) and multiply by exp(xi) to cancel the Laguerre weight exp(-xi)
    return jnp.sum(wi*jnp.exp(xi)*c1/(1+(xi+c2)**2))

points, weights = roots_laguerre(100)
gradF_GL = grad(GLF, (0,1))
dF_GLc1, dF_GLc2 = gradF_GL(c1, c2, points, weights)
print("GL error in grad1 = %.8f"%(dF_GLc1.item() - (jnp.pi/2-jnp.arctan(c2))))
print("GL error in grad2 = %.8f"%(dF_GLc2.item() + (c1/(1+c2**2))))
```

you should be able to get

```
error in grad1 = -0.00005571
error in grad2 = -0.04435208
GL error in grad1 = -0.00258033
GL error in grad2 = -0.00000665
```
You can use the trapezoidal rule for the first gradient and the Gauss-Laguerre quadrature rule for the second one, since each is the more accurate of the two for that component.
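As a further sketch (not part of the original answer), a single quadrature rule can handle both gradient components if the infinite range is first mapped onto a finite one. The snippet assumes the substitution $x = c_2 + t/(1-t)$, which maps $t \in (0, 1)$ onto $x \in (c_2, \infty)$, and a fixed 200-point Gauss-Legendre rule from `numpy.polynomial.legendre.leggauss`; the name `F_gl` and the node count are illustrative choices. Because the integrand is evaluated entirely with `jax.numpy`, `jax.grad` can differentiate through it.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, as in the script above

# Gauss-Legendre nodes/weights on [-1, 1], computed once as constants (not traced).
_u, _w = np.polynomial.legendre.leggauss(200)
_T = jnp.asarray(0.5 * (_u + 1.0))  # nodes mapped onto (0, 1)
_W = jnp.asarray(0.5 * _w)          # weights rescaled for (0, 1)

def F_gl(c1, c2):
    """Approximate the integral of c1 / (1 + x^2) from c2 to infinity."""
    # Substitution x = c2 + t / (1 - t), with Jacobian dx/dt = 1 / (1 - t)^2.
    x = c2 + _T / (1.0 - _T)
    dx_dt = 1.0 / (1.0 - _T) ** 2
    return jnp.sum(_W * c1 / (1.0 + x**2) * dx_dt)

grad_F_gl = jax.grad(F_gl, argnums=(0, 1))

print(F_gl(-1.0, 0.0))       # compare with -pi/2
print(grad_F_gl(-1.0, 0.0))  # compare with (pi/2, 1)
```

For this smooth, transformed integrand the fixed rule should agree with the analytic values to near machine precision. Unlike `quad`, however, it is not adaptive and gives no error estimate, so integrands with sharp features may need more nodes or a different mapping.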
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jakevdp |
| Solution 2 | |