Introduction

Like most good things, this started with a serendipitous conversation on a ski trip. I was asking my biostats PhD friend about their epidemiology lab, and they explained the challenges of working with the volume of data captured by health systems and the modeling techniques used to extract conclusions from that data. That's how I was introduced to distributed lag non-linear models (DLNM). I was curious whether there was low-hanging fruit on the engineering side, and excited by a potential opportunity to contribute to open source software. The rest of this post breaks down what this library does, potential optimizations, and benchmarks for evaluation.

DLNM

This model is used in epidemiology studies. It takes a dataset like the example below and tries to answer questions such as: given that temperatures have been hot over the last five days, does the number of deaths go up? The challenge is that we need to consider every temperature within a certain look-back window, and the relationship can be non-linear in both the temperature and the lag. For example, the temperature rising 20 degrees from 100 to 120 is a lot more impactful than rising from 50 to 70, and the effect of a hot day may take several days to show up. DLNM tries to capture both kinds of non-linearity.
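To make the "non-linear in temperature" idea concrete, here is a minimal Rust sketch of a basis expansion. The quadratic basis below is my own stand-in for illustration; dlnm actually uses user-configured bases such as natural splines and polynomials:

```rust
// Sketch: expand one temperature value into two basis columns.
// A quadratic basis is a made-up stand-in for dlnm's spline bases,
// but it shows how a single value becomes several basis features,
// and how equal temperature jumps can have unequal effects.
fn poly_basis(temp: f64) -> [f64; 2] {
    // Two basis functions: a scaled linear term and a quadratic term.
    [temp / 100.0, (temp / 100.0).powi(2)]
}

fn main() {
    for &t in &[50.0, 70.0, 100.0, 120.0] {
        let b = poly_basis(t);
        println!("temp {:>5}: B1 = {:.2}, B2 = {:.2}", t, b[0], b[1]);
    }
}
```

With this basis, the 100 to 120 jump moves B2 by more than the 50 to 70 jump does, which is exactly the kind of non-linearity in the exposure that a spline basis lets the model capture.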

date        time  year  month  doy  dow       death  temp   pm10
1987-01-01  1     1987  1      1    Thursday  130    -0.28  26.96
1987-01-02  2     1987  1      2    Friday    150     0.56  NA
1987-01-03  3     1987  1      3    Saturday  101     0.56  32.84

The DLNM library has two main steps: transforming the data into a "crossbasis matrix," and fitting a model to that matrix with R's glm(). Generating the crossbasis matrix involves first decomposing the exposure variable (e.g. temp) and the lag dimension into several basis functions each, then, for a given row, taking the tensor product of the corresponding basis vectors. For example, take 2 temp basis functions, 2 lag basis functions, and a 2-day look-back window. Every day's row then has 3 lagged temps (lags 0 through 2); each of those temps is expanded into the 2 temp basis functions, and each lag position is expanded into the 2 lag basis functions. For every (temp basis, lag basis) pair you take a dot product across the lags, which yields one column of the crossbasis matrix. The generalized linear model is then fit on that matrix.

# from dlnm: build the crossbasis by materializing a lag matrix per column
crossbasis <- matrix(0, nrow=dim[1], ncol=ncol(basisvar)*ncol(basislag))
for(v in seq(length=ncol(basisvar))) {
  # n x (lag+1) matrix of lagged values of basis column v
  mat <- as.matrix(Lag(basisvar[, v], seqlag(lag)))
  for(l in seq(length=ncol(basislag))) {
    crossbasis[, ncol(basislag)*(v-1)+l] <- mat %*% basislag[, l]
  }
}
day  temp  death
1    -2°   130
2          142
3    12°   118
4    20°   105
5    28°   110
6    15°   115
7          135
8          148

Each temp is fed through 2 functions → 2 columns.

day  temp  B1(temp)  B2(temp)
1    -2°   0.05      0.90
2          0.22      0.70
3    12°   0.48      0.42
4    20°   0.72      0.18
5    28°   0.91      0.04
6    15°   0.55      0.38
7          0.30      0.62
8          0.15      0.78

Shift each basis column back by 0, 1, and 2 days. Days 1–2 lose rows (not enough history).

B1 lagged → mat_B1 (6 rows x 3 lags)

       lag0  lag1  lag2
day 3  0.48  0.22  0.05
day 4  0.72  0.48  0.22
day 5  0.91  0.72  0.48
day 6  0.55  0.91  0.72
day 7  0.30  0.55  0.91
day 8  0.15  0.30  0.55

B2 lagged → mat_B2 (6 rows x 3 lags)

       lag0  lag1  lag2
day 3  0.42  0.70  0.90
day 4  0.18  0.42  0.70
day 5  0.04  0.18  0.42
day 6  0.38  0.04  0.18
day 7  0.62  0.38  0.04
day 8  0.78  0.62  0.38

Finally, the crossbasis: the matrix fed to the GLM.

2 temp basis functions x 2 lag basis functions = 4 columns. Each cell is the dot product of a lagged row with a lag basis column.

       death  v1.l1  v1.l2  v2.l1  v2.l2
day 3  118    0.464  0.203  0.626  1.202
day 4  105    0.755  0.510  0.326  0.858
day 5  110    1.004  0.883  0.116  0.472
day 6  115    0.795  1.158  0.327  0.220
day 7  135    0.478  1.124  0.631  0.288
day 8  148    0.253  0.660  0.860  0.730
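The dot-product step can be sketched in a few lines of Rust. The lag basis column below is made up for illustration (the tables above don't show dlnm's actual lag basis values), so the result won't match the crossbasis table:

```rust
// Dot product of one lagged row with one lag basis column,
// i.e. how a single crossbasis cell like v1.l1 is computed.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    // day-5 row of mat_B1 from the table above (lags 0, 1, 2)
    let mat_b1_day5 = [0.91, 0.72, 0.48];
    // hypothetical lag basis column 1; the real values are not shown above
    let lag_basis_l1 = [1.0, 0.6, 0.2];

    let v1_l1 = dot(&mat_b1_day5, &lag_basis_l1);
    println!("v1.l1 (day 5) = {:.3}", v1_l1); // 1.438 with these made-up weights
}
```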

Optimizations

The main bottleneck is memory: the intermediate data grows as N x lag days x variable basis functions x lag basis functions, when all you actually want is N x variable basis functions x lag basis functions. With a 1 GB dataset, that blow-up already maxes out most consumer hardware. To avoid it, we use index arithmetic to pull previous days' values directly from the original columns, and we use the rayon library to split the N rows across all cores and process them in parallel.

idx <- 1
for (v in 1:ncol(basisvar)) {
  lag_matrix <- Lag(basisvar[, v], lag_range)  # materializes an (n x lag_range) matrix
  for (l in 1:ncol(basislag)) {
    cb[, idx] <- lag_matrix %*% basislag[, l]
    idx <- idx + 1
  }
}
// pseudocode for the optimized kernel
// parallelized across rows (days)
for each row i (in parallel):
    for v in 0..nv:
        for l in 0..nl:
            sum = 0
            for j in 0..=lag:
                // gather the lag+1 lagged values (22 for C2) directly from
                // the basisvar column: no lag matrix, just index arithmetic
                sum += basisvar[i - j, v] * basislag[j, l]
            crossbasis[i, nl*v + l] = sum
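A runnable, sequential Rust sketch of that kernel (in the actual implementation the outer row loop is split across cores with rayon; the data layout here is a simple vector-of-rows chosen for readability):

```rust
// basisvar: n rows x nv basis columns; basislag: (lag+1) x nl.
// Output: n x (nv*nl) crossbasis, computed without materializing
// any n x (lag+1) lag matrix.
fn crossbasis(basisvar: &[Vec<f64>], basislag: &[Vec<f64>], lag: usize) -> Vec<Vec<f64>> {
    let n = basisvar.len();
    let nv = basisvar[0].len();
    let nl = basislag[0].len();
    let mut cb = vec![vec![f64::NAN; nv * nl]; n];
    // rows before `lag` have too little history and stay NaN
    for i in lag..n {
        for v in 0..nv {
            for l in 0..nl {
                let mut sum = 0.0;
                for j in 0..=lag {
                    // pull the lagged value by index arithmetic
                    sum += basisvar[i - j][v] * basislag[j][l];
                }
                cb[i][nl * v + l] = sum;
            }
        }
    }
    cb
}

fn main() {
    // tiny example: 1 variable basis column, 1 lag basis column, lag = 2,
    // with an all-ones lag basis so each cell is a 3-day moving sum
    let basisvar = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
    let basislag = vec![vec![1.0], vec![1.0], vec![1.0]];
    let cb = crossbasis(&basisvar, &basislag, 2);
    println!("{:?}", cb[3]); // [9.0]: the moving sum 4 + 3 + 2
}
```

In the real library the row loop would be the rayon `par_iter` boundary, since each output row depends only on read-only inputs.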

Benchmarks

Config  Var Basis   Lag Basis              Lag Range  CB Columns  Purpose
C1      lin (df=1)  poly(degree=4) (df=5)  0-15       5           Minimal/DLM
C2      ns(df=5)    ns(df=4)               0-21       20          Typical epidemiology

Scale   Cities  Rows
10 MB   24      ~123K
100 MB  235     ~1.2M
1 GB    2,350   ~12M
10 GB   23,500  ~120M
Scale  Config  R Baseline (sec)  Rust Optimized (sec)
10MB   C1      0.032             0.034
10MB   C2      0.215             0.061
100MB  C1      0.382             0.165
100MB  C2      2.147             0.630
1GB    C1      4.528             0.957
1GB    C2      41.852            7.399
10GB   C1                        18.349
Config  Stage        R Baseline (sec)  Rust Optimized (sec)
C1      crossbasis   0.382             0.165
C1      glm          14.598            21.094
C1      crosspred    0.103             0.139
C1      crossreduce  0.111             0.118

Config  Stage        R Baseline (sec)  Rust Optimized (sec)
C2      crossbasis   2.147             0.630
C2      glm          17.667            24.312
C2      crosspred    0.105             0.178
C2      crossreduce  0.107             0.157

Conclusions

From the benchmarks, it seems the Rust implementation accomplished what it was meant to do: when the dataset gets large, the crossbasis matrix can still be formed within my 32 GB memory limit, and it is computed a lot faster than in the R implementation. I chose Rust somewhat arbitrarily, which in hindsight was awkward given that R's FFI natively supports only C and C++; I ended up using the extendr crate to bridge that interface.

For future work, the GLM fit, which is native to R, now takes up the majority of the run time. The fastglm package could be tried as a drop-in replacement to see whether it runs faster as well.

There are also cache-level optimizations to explore: since each output cell does an indexed lookup of lag-many values, laying the data out so those reads hit L1 cache could be another place to optimize.

Reflections

This was more of an exercise in seeing whether I could add value in a domain that is new to me. After cleaning up the repo, I'll forward it to the original author and see if we can push an update to this library. I'm interested in hearing any suggestions on better ways to optimize it, and I'm also on the lookout for similar libraries in the academic open source software space that could be looked at.