Introduction
In this post, I’ll show how JAX’s automatic differentiation makes it straightforward to train ‘non-standard’ regression models. We’ll implement a Zero-Inflated Generalised Poisson (ZIGP) regression model. This would be tricky to implement from scratch but becomes manageable when we can just formulate the model we are interested in, write down the likelihood and let JAX handle the gradients.
The example comes from the paper “Zero-Inflated Generalized Poisson Regression Model with an Application to Domestic Violence Data” [1] (apologies for the unpleasant subject matter; it just so happened that such a model had been used to describe this data, and the data had been made public, which made for a nice reproduction). The model is interesting because we have to deal with both over-dispersion and a zero-inflation component, and the inclusion of both, as we will see, is important.
A jupyter notebook version of the code in this post is available [5].
The Problem: Count Data with Structural Zeros
The dataset contains incident counts, $y_i$, from surveyed households, along with covariates, $x_i$, which are answers to a series of survey questions that the surveyors presumably thought were likely to be correlated with the rate of incidence.
This type of data exhibits three key features that make standard Poisson regression inadequate:
- Structural zeros: Many survey responses have $y_i = 0$, more than would be expected from a standard count distribution. Some of these zeros are ‘structural’ (households where incidents would never occur) but some are ‘sampling’ zeros ($y_i > 0$ was possible in principle but was not observed).
- Over-dispersion: The variance in incident counts substantially exceeds the mean, violating the Poisson assumption that mean equals variance.
- Regression structure: We have covariates for both the zero-inflation process and the ‘base’ count process (Poisson, Generalised Poisson, …).
We can inspect the interactions of each feature and the dependent variable to get a rough idea of what the relationships might look like:
*[Figure: each covariate plotted against the incident counts.]*
The ZIGP Model
The Zero-Inflated Generalised Poisson (ZIGP) model meets the demands of this dataset exactly. It models count data that may or may not exhibit zero-inflation (ZI), using a Generalised Poisson (GP) distribution as the base count process.
The Generalised Poisson Distribution
The standard Poisson distribution assumes $\mathbb{E}[Y] = \mathrm{Var}[Y]$ but this is often violated in practice. The Generalised Poisson relaxes this constraint by introducing a dispersion parameter $\alpha$, giving a PMF of:

$$f(y; \mu, \alpha) = \left(\frac{\mu}{1 + \alpha\mu}\right)^{y} \frac{(1 + \alpha y)^{y-1}}{y!} \exp\!\left(-\frac{\mu(1 + \alpha y)}{1 + \alpha\mu}\right), \qquad y = 0, 1, 2, \ldots$$

with mean $\mu$ and variance $\mu(1 + \alpha\mu)^2$. When $\alpha > 0$, the distribution exhibits over-dispersion ($\mathrm{Var}[Y] > \mathbb{E}[Y]$), while $\alpha < 0$ gives under-dispersion. As $\alpha \to 0$, we recover the standard Poisson distribution.
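To make this concrete, here is a quick JAX sketch of the log-PMF (the helper name `gp_log_pmf` is mine, not from [1] or the notebook), together with a check that it collapses to the standard Poisson as $\alpha \to 0$:

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import poisson

def gp_log_pmf(y, mu, alpha):
    """Log-PMF of the Generalised Poisson distribution."""
    return (
        y * (jnp.log(mu) - jnp.log1p(alpha * mu))
        + (y - 1) * jnp.log1p(alpha * y)
        - gammaln(y + 1)
        - mu * (1 + alpha * y) / (1 + alpha * mu)
    )

y = jnp.arange(10)
# As alpha -> 0 the Generalised Poisson reduces to the standard Poisson.
print(jnp.allclose(gp_log_pmf(y, 3.0, 1e-6), poisson.logpmf(y, 3.0), atol=1e-4))
```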
Zero-Inflation Component
To handle structural zeros, we add a mixture component using a Bernoulli model for the probability $\varphi_i$ that observation $i$ is a structural zero, with $\varphi_i = \sigma(z_i) = 1 / (1 + e^{-z_i})$ for a linear predictor $z_i$. Although there are other links that could be experimented with, the logistic link is natural here and provides a smooth, interpretable mapping from covariates to probabilities.
Full ZIGP Specification
Combining these ideas we have, for observation $i$, the following PMF:

$$P(Y_i = y_i) = \begin{cases} \varphi_i + (1 - \varphi_i)\, f(0; \mu_i, \alpha) & y_i = 0 \\ (1 - \varphi_i)\, f(y_i; \mu_i, \alpha) & y_i > 0 \end{cases}$$

where $\varphi_i$ is the probability of a structural zero, $\mu_i$ is the Poisson rate parameter and $\alpha$ is the dispersion parameter.
Now we follow the regression structure from [1]:

$$\mu_i = \exp(x_i^\top \beta)$$

where $x_i$ is the covariate vector for observation $i$ and $\beta$ is the vector of regression coefficients.
The authors make the choice to couple the Generalised Poisson mean, $\mu_i$, and the zero-inflation parameter, $\varphi_i$, by introducing a parameter $\tau$, and taking $\operatorname{logit}(\varphi_i) = \tau\, x_i^\top \beta$. Intuitively this makes some sense: the zero-inflation increases as the mean decreases (for negative $\tau$). However, it is not entirely clear to me that this is optimal, because we do not know that structural zeros are being ‘generated’ by the same process as the sampled zeros… Indeed, I thought that was sort of the crux of the model! Possibly the authors tried using uncoupled regression designs for the Generalised Poisson component and the zero-inflation component and found either that it didn’t converge well or that it wasn’t possible to derive the update steps required in that case. Either way, in what follows I will follow in their footsteps with the aim of reproducing their results.
Implementation
This is one of the reasons JAX is so useful. Instead of deriving the gradient by hand (which can be done, as in [1]), we simply implement the negative log-likelihood:
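A minimal sketch, reusing the `gp_log_pmf` helper from above (the parameter packing `theta = [beta, tau, gamma]` is my choice and may differ from the notebook [5]):

```python
import jax
import jax.numpy as jnp

def zigp_nll(theta, X, y):
    """Negative log-likelihood of the ZIGP model.

    theta packs [beta (p values), tau, gamma]; X is (n, p), y is (n,).
    """
    p = X.shape[1]
    beta, tau, gamma = theta[:p], theta[p], theta[p + 1]

    eta = X @ beta                   # linear predictor x_i^T beta
    mu = jnp.exp(eta)                # Generalised Poisson rate mu_i
    alpha = jax.nn.softplus(gamma)   # dispersion alpha > 0
    phi = jax.nn.sigmoid(tau * eta)  # structural-zero probability phi_i

    log_f = gp_log_pmf(y, mu, alpha)

    # Mixture PMF: phi + (1 - phi) * f(0)  when y == 0,
    #              (1 - phi) * f(y)        when y > 0.
    log_p = jnp.where(
        y == 0,
        jnp.logaddexp(jnp.log(phi), jnp.log1p(-phi) + log_f),
        jnp.log1p(-phi) + log_f,
    )
    return -jnp.sum(log_p)
```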
The implementation follows the mathematics closely:
- We unpack $\theta$ into $\beta$ (regression coefficients), $\tau$ (zero-inflation scaling), and $\gamma$ (transformed dispersion parameter)
- Compute the linear predictor $\eta_i = x_i^\top \beta$ and the zero-inflation probability $\varphi_i = \sigma(\tau \eta_i)$
- Transform $\gamma$ via softplus to ensure $\alpha > 0$
- Calculate the rate parameter $\mu_i = \exp(\eta_i)$
With that in hand, training the model is straightforward: we simply use scipy’s implementation of L-BFGS-B with JAX’s automatic differentiation providing the gradient.
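Something like the following (a sketch: `X` and `y` are the design matrix and counts as NumPy arrays, and the zero initialisation is an assumption of mine):

```python
import numpy as np
from scipy.optimize import minimize

# JIT-compile the objective and its autodiff gradient.
nll = jax.jit(zigp_nll)
grad_nll = jax.jit(jax.grad(zigp_nll))

theta0 = np.zeros(X.shape[1] + 2)  # [beta, tau, gamma], all zero
res = minimize(
    fun=lambda t: float(nll(t, X, y)),
    jac=lambda t: np.asarray(grad_nll(t, X, y), dtype=np.float64),
    x0=theta0,
    method="L-BFGS-B",
)
theta_hat = res.x
```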
We can also train two simpler (nested) models for comparison (reduced likelihoods are sketched after the list):
- a standard Poisson (by switching off both extensions, effectively setting $\alpha = 0$ and $\varphi_i = 0$) and
- a Generalised Poisson without zero-inflation by fixing only $\varphi_i = 0$.
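Both are restrictions of the same likelihood; a sketch of the reduced NLLs (the function names are mine, and each is fit with the same L-BFGS-B loop as above):

```python
from jax.scipy.special import gammaln

def gp_nll(params, X, y):
    """Generalised Poisson NLL: the ZIGP likelihood with phi_i = 0."""
    p = X.shape[1]
    beta, gamma = params[:p], params[p]
    mu = jnp.exp(X @ beta)
    alpha = jax.nn.softplus(gamma)
    return -jnp.sum(gp_log_pmf(y, mu, alpha))

def poisson_nll(beta, X, y):
    """Standard Poisson NLL: additionally fixes alpha = 0."""
    mu = jnp.exp(X @ beta)
    return -jnp.sum(y * jnp.log(mu) - mu - gammaln(y + 1))
```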
Results
At convergence the model has the following parameter estimates:
| Parameter | Estimate | Parameter | Estimate |
|---|---|---|---|
| Intercept | 5.42828 | Fam_v | 0.180303 |
| Edu_b | 0.589901 | Club_b | -1.98673 |
| Edu_v | -1.5003 | Club_v | 1.71529 |
| Emp_b | 1.24372 | Drug_b | 1.5428 |
| Emp_v | 0.342406 | Drug_v | -1.06352 |
| Inc_b | -0.416757 | Tau | -0.124315 |
| Inc_v | -0.481259 | Gamma | -1.03127 |
| Fam_b | -0.662489 | | |
These match the ZIGP coefficients from [1] well.
Approximate Parameter Uncertainty
To perform inference on the results of the study we need uncertainties on these point values. Drug_b looks to be quite strongly correlated with incidence, but if we were told that the error on this value were as large as the estimate itself, it would change the picture completely.
We can get these uncertainty bounds approximately, even for a reasonably fancy model like the ZIGP, for free with the code we’ve already written.
This is a fundamental result in asymptotic statistics [3]. Let $\ell(\theta)$ be the log-likelihood. The score function is:

$$s(\theta) = \nabla_\theta\, \ell(\theta)$$

Taylor expanding the score around the true parameter $\theta_0$ and evaluating at the MLE $\hat\theta$, up to the second-order term:

$$0 = s(\hat\theta) \approx s(\theta_0) + H(\theta_0)\,(\hat\theta - \theta_0)$$

where $H$ is the Hessian of the log-likelihood. The left-hand side is zero because at the MLE, $\hat\theta$, we have $s(\hat\theta) = 0$. Rearranging this:

$$\hat\theta - \theta_0 \approx -H(\theta_0)^{-1}\, s(\theta_0)$$
Under regularity conditions:
- The score satisfies a central limit theorem: $\tfrac{1}{\sqrt{n}}\, s(\theta_0) \xrightarrow{d} \mathcal{N}\big(0, I(\theta_0)\big)$, where $I(\theta_0)$ is the Fisher information
- The Hessian converges: $-\tfrac{1}{n} H(\theta_0) \xrightarrow{p} I(\theta_0)$
Therefore:

$$\sqrt{n}\,(\hat\theta - \theta_0) \xrightarrow{d} \mathcal{N}\big(0, I(\theta_0)^{-1}\big)$$

In practice, we replace the expected Fisher information with the observed information $J(\hat\theta) = -H(\hat\theta)$, giving us:

$$\operatorname{Cov}(\hat\theta) \approx J(\hat\theta)^{-1} = \big(-H(\hat\theta)\big)^{-1}$$

Since we’re minimizing the negative log-likelihood, our Hessian is already positive definite, so $\operatorname{Cov}(\hat\theta) \approx H_{\mathrm{NLL}}(\hat\theta)^{-1}$.
JAX’s automatic differentiation makes it straightforward to compute approximate standard errors using the Hessian:
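A sketch, reusing `zigp_nll` and `theta_hat` from above:

```python
import numpy as np
from scipy.stats import norm

# Observed information: Hessian of the NLL at the MLE.
H = jax.hessian(zigp_nll)(theta_hat, X, y)
cov = jnp.linalg.inv(H)             # approximate covariance of theta_hat
std_err = jnp.sqrt(jnp.diag(cov))   # standard errors

z = theta_hat / std_err                        # Wald z-scores
p_values = 2 * norm.sf(np.abs(np.asarray(z)))  # two-sided p-values
ci_lo = theta_hat - 1.96 * std_err
ci_hi = theta_hat + 1.96 * std_err
```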
| Parameter | Estimate | Std Error | 95% CI Lower | 95% CI Upper | p-value | Sig |
|---|---|---|---|---|---|---|
| Gamma (γ) | -1.031 | 0.211 | -1.445 | -0.617 | < 0.001 | *** |
| Intercept | 5.428 | 1.260 | 2.958 | 7.898 | < 0.001 | *** |
| Drug_b | 1.543 | 0.402 | 0.755 | 2.331 | < 0.001 | *** |
| Edu_v | -1.500 | 0.496 | -2.473 | -0.527 | 0.003 | ** |
| Club_b | -1.987 | 0.712 | -3.383 | -0.590 | 0.005 | ** |
| Drug_v | -1.169 | 0.481 | -2.112 | -0.226 | 0.015 | * |
| Inc_v | 0.874 | 0.377 | 0.135 | 1.613 | 0.020 | * |
Again, these match the results from [1] well.
Zero-inflation is significant: The parameter $\tau$ is highly significant (p < 0.001), which we interpret as the zero-inflation component contributing meaningfully to the model.
With $\hat\gamma \approx -1.03$, we get $\hat\alpha = \operatorname{softplus}(\hat\gamma) \approx 0.30$, so there is a reasonable degree of over-dispersion in the data. The lower bound on $\alpha$ is $\operatorname{softplus}(-1.445) \approx 0.21 > 0$, so it is also ‘significant’ for over-dispersion (checked numerically below).
Drug problems strongly predict incidence: Having drug problems at baseline (Drug_b) is the strongest predictor, associated with exp(1.54) ≈ 4.7× higher incident rates.
Education and social engagement matter: Higher education at visit (Edu_v) and club membership at baseline (Club_b) are both associated with significantly lower incidence rates.
Income shows positive association at visit: Contrary to expectations, income at visit (Inc_v) shows a positive association with incident counts, though this could reflect reporting bias or other confounding.
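The mapping from the unconstrained $\gamma$ back to $\alpha$ is just the softplus from the implementation; a quick check of the numbers quoted above:

```python
# Push gamma-hat and its 95% CI endpoints through softplus onto the alpha scale.
gamma_ci = jnp.array([-1.445, -1.031, -0.617])  # [lower, estimate, upper]
print(jax.nn.softplus(gamma_ci))                # -> approx [0.21, 0.30, 0.43]
```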
Model Comparison
To validate that the ZIGP model is truly necessary, we compare it against two simpler ‘nested’ models using information criteria and likelihood ratio tests:
| Model | Params | NLL | AIC | BIC | Δ_AIC | Δ_BIC |
|---|---|---|---|---|---|---|
| Poisson | 13 | 2212.43 | 4450.87 | 4494.63 | 3689.18 | 3682.45 |
| GenPoisson | 14 | 431.32 | 890.64 | 937.76 | 128.95 | 125.58 |
| ZIGP | 15 | 365.84 | 761.69 | 812.18 | 0.0 | 0.0 |
The ZIGP model outperforms both alternatives. The standard Poisson model looks misspecified (Δ_AIC > 3500), while even the Generalised Poisson—which handles over-dispersion—is substantially worse than ZIGP (Δ_AIC ≈ 129).
Likelihood ratio tests confirm these findings:
- Poisson vs Generalised Poisson: LR = 3562.23, p < 0.001 → Reject Poisson; over-dispersion is real
- Generalised Poisson vs ZIGP: LR = 130.95, p < 0.001 → Reject GenPoisson; zero-inflation is necessary
- Poisson vs ZIGP: LR = 3693.18, p < 0.001 → Strongly reject simplest model
All tests overwhelmingly reject the simpler models, confirming that both over-dispersion and zero-inflation are essential features of this data.
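These statistics fall straight out of the fitted NLLs; a quick check of one of them, using the values from the table above:

```python
from scipy.stats import chi2

# LR = 2 * (NLL_reduced - NLL_full), asymptotically chi^2 with
# df = number of extra parameters in the fuller model.
lr = 2 * (431.32 - 365.84)    # Generalised Poisson vs ZIGP
print(lr, chi2.sf(lr, df=1))  # ~130.96, p << 0.001 (one extra parameter, tau)
```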
Summary
In this post we implemented a Zero-Inflated Generalised Poisson regression model to analyse domestic violence incident data. We were able to reproduce the results found by the authors of [1], but without having to manually compute gradients of our arbitrarily complicated likelihood.
Similarly we saw that JAX can provide Hessians for our models in a single line of code, enabling us to provide approximate uncertainty bounds for our regression coefficients.
The combination of mathematical flexibility and computational efficiency makes custom model development accessible to anyone comfortable with numpy…sort of - JAX does have some self-confessed “sharp bits” to avoid [4]!
References
[1] Famoye, F., & Singh, K. P. (2006). Zero-Inflated Generalized Poisson Regression Model with an Application to Domestic Violence Data. Journal of Data Science, 4(1), 117-130.
[2] JAX: Composable transformations of Python+NumPy programs. Google Research.
[3] Casella, G., & Berger, R. L. (2002). Statistical Inference (2nd ed.). Duxbury Press. Chapter 10: Asymptotic Evaluations.
[4] JAX documentation: “JAX - The Sharp Bits”.
[5] Notebook