Differentiation of an improper integral using JAX and SciPy

I provide a simple code example of a failed attempt to use JAX to automatically differentiate through an improper integral computed with SciPy's quad() function. The function I consider is

$$F(c_1, c_2) = \int_{c_2}^{\infty} \frac{c_1}{1 + x^2}\,dx$$

with gradient given by

$$\frac{\partial F}{\partial c_1} = \frac{\pi}{2} - \arctan(c_2), \qquad \frac{\partial F}{\partial c_2} = -\frac{c_1}{1 + c_2^2}$$

The following code is able to evaluate the function, but JAX throws a ConcretizationTypeError when I try to compute the gradient. What is the problem here, and how can it be fixed?

import jax
from scipy.integrate import quad

## Function
def F(c1, c2):
    val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
    return val

## Gradient
grad_F = jax.grad(F)

## Parameters
c1 = -1.0
c2 = 0.0

## Evaluates function
F(c1, c2)
# -1.5707963267948966   (which is -pi/2 btw)

## Evaluates gradient
grad_F(c1, c2)

Throws:

---------------------------------------------------------------------------
ConcretizationTypeError                   Traceback (most recent call last)
/tmp/ipykernel_446012/1229440296.py in <module>
----> 1 grad_F(c1, c2)

    [... skipping hidden 9 frame]

/tmp/ipykernel_446012/2999885932.py in F(c1, c2)
      5 def F(c1, c2):
      6     #val, err = jax.numpy.array(quad(lambda y: b/(1.0+y**2), a=a, b=jax.numpy.inf), float)
----> 7     val, err = quad(lambda x: c1/(1.0 + x**2), a=c2, b=jax.numpy.inf)
      8     return val
      9 

~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in quad(func, a, b, args, full_output, epsabs, epsrel, limit, points, weight, wvar, wopts, maxp1, limlst)
    349 
    350     if weight is None:
--> 351         retval = _quad(func, a, b, args, full_output, epsabs, epsrel, limit,
    352                        points)
    353     else:

~/anaconda3/lib/python3.8/site-packages/scipy/integrate/quadpack.py in _quad(func, a, b, args, full_output, epsabs, epsrel, limit, points)
    463             return _quadpack._qagse(func,a,b,args,full_output,epsabs,epsrel,limit)
    464         else:
--> 465             return _quadpack._qagie(func,bound,infbounds,args,full_output,epsabs,epsrel,limit)
    466     else:
    467         if infbounds != 0:

    [... skipping hidden 1 frame]

~/anaconda3/lib/python3.8/site-packages/jax/core.py in error(self, arg)
    998                       f"or `jnp.array(x, {fun.__name__})` instead.")
    999   def error(self, arg):
-> 1000     raise ConcretizationTypeError(arg, fname_context)
   1001   return error
   1002 

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ConcreteArray(-0.5, dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = DeviceArray(-0.5, dtype=float32, weak_type=True)
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[], weak_type=True), *)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fbf4f643b90>, invars=(Traced<ConcreteArray(2.0, dtype=float32):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True):JaxprTrace(level=1/0)>), outvars=[<weakref at 0x7fbf4c402c20; to 'JaxprTracer' at 0x7fbf4eca5090>], primitive=xla_call, params={'device': None, 'backend': None, 'name': 'jvp(true_divide)', 'donated_invars': (False, False), 'inline': True, 'call_jaxpr': { lambda ; a:f32[] b:f32[]. let c:f32[] = div b a in (c,) }}, source_info=<jaxlib.xla_extension.Traceback object at 0x7fbf4eca1bb0>)
The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError


Solution 1:[1]

The issue is that JAX's grad transform can only operate on functions made up entirely of JAX operations, and scipy.integrate.quad is not a JAX operation. If you'd like to do this kind of computation, you'll have to find a JAX implementation of quad.
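
For this particular integrand the improper integral has the closed form $F(c_1, c_2) = c_1\left(\frac{\pi}{2} - \arctan(c_2)\right)$, so a minimal sketch of a workaround (specific to this example, not a general substitute for quad) is to express the closed form directly in JAX operations:

import jax
import jax.numpy as jnp

def F_closed(c1, c2):
    # Closed form of the integral, built entirely from JAX
    # primitives so that jax.grad can trace through it.
    return c1 * (jnp.pi / 2 - jnp.arctan(c2))

grad_F = jax.grad(F_closed, argnums=(0, 1))

print(F_closed(-1.0, 0.0))  # -1.5707963... (= -pi/2)
print(grad_F(-1.0, 0.0))    # (~1.5707963, ~1.0), matching the analytic gradient

For integrals with no closed form, a fixed quadrature rule written in JAX (as in Solution 2 below) is one way to get a differentiable approximation.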

Solution 2:[2]

A simple, albeit less accurate, solution is to calculate the integral with the trapezoidal rule on a truncated domain. You could also use a quadrature rule, which is essentially what scipy.integrate.quad does under the hood.
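
To see why truncating at a finite $t_{\max}$ is acceptable here, note that the neglected tail of the integral is

$$\int_{t_{\max}}^{\infty} \frac{c_1}{1 + x^2}\,dx = c_1\left(\frac{\pi}{2} - \arctan(t_{\max})\right) \approx \frac{c_1}{t_{\max}},$$

so $t_{\max} = 15000$ keeps the truncation error near $1/15000 \approx 6.7 \times 10^{-5}$, which is roughly the size of the first error printed below.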

After setting JAX_ENABLE_X64=True on your local system (so that JAX uses 64-bit floats) and running the following script

import jax.numpy as jnp
from jax import grad

from scipy.special import roots_laguerre

def F(c1, c2, tmax, N):
    """
    Trapezoidal approximation of the integral over [c2, tmax].
    Inputs:
    c1   is the coefficient in the integrand
    c2   is the lower limit of the integral
    tmax is the (finite) upper limit truncating the improper integral
    N    is the number of points; can be low with a higher-order integration scheme
    Output: scalar integral value
    """
    t0 = jnp.linspace(c2, tmax, N)
    return jnp.trapz(c1/(1+t0**2), t0)

c1 = -1.0
c2 = 0.

gradF = grad(F, (0,1))
dFc1, dFc2   = gradF(c1, c2, 15000, 30000) 

print("error in grad1 = %.8f"%(dFc1.item() - (jnp.pi/2-jnp.arctan(c2))))
print("error in grad2 = %.8f"%(dFc2.item() + (c1/(1+c2**2))))



def GLF(c1, c2, xi, wi):
    """
    Calculates the integral using Gauss-Laguerre quadrature with
    nodes xi and weights wi. The factor jnp.exp(xi) cancels the
    exp(-x) weight implicit in the Gauss-Laguerre rule.
    """
    return jnp.sum(wi*jnp.exp(xi)*c1/(1+(xi+c2)**2))


points, weights  = roots_laguerre(100)  

gradF_GL         = grad(GLF, (0,1))
dF_GLc1, dF_GLc2 = gradF_GL(c1, c2, points, weights) 

print("GL error in grad1 = %.8f"%(dF_GLc1.item() - (jnp.pi/2-jnp.arctan(c2))))
print("GL error in grad2 = %.8f"%(dF_GLc2.item() + (c1/(1+c2**2))))

you should be able to get

error in grad1 = -0.00005571
error in grad2 = -0.04435208
GL error in grad1 = -0.00258033
GL error in grad2 = -0.00000665

Given these errors, you could use the trapezoidal result for the gradient with respect to c1 and the Gauss-Laguerre result for the gradient with respect to c2.
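
For instance, reusing the functions and variables defined above:

## Trapezoid for dF/dc1, Gauss-Laguerre for dF/dc2
dFc1, _ = grad(F, (0, 1))(c1, c2, 15000, 30000)
_, dFc2 = grad(GLF, (0, 1))(c1, c2, points, weights)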

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution     Source
Solution 1   jakevdp
Solution 2