Rif A. Saurous’s research while affiliated with Mountain View College and other places

What is this page?


This page lists works of an author who doesn't have a ResearchGate profile or hasn't added the works to their profile yet. It is automatically generated from public (personal) data to further our legitimate goal of comprehensive and accurate scientific recordkeeping. If you are this author and want this page removed, please let us know.

Publications (44)


MesaNet: Sequence Modeling by Locally Optimal Test-Time Training
  • Preprint
  • File available

June 2025 · 19 Reads

Sequence modeling is currently dominated by causal transformer architectures that use softmax self-attention. Although widely adopted, transformers require scaling memory and compute linearly during inference. A recent stream of work linearized the softmax operation, resulting in powerful recurrent neural network (RNN) models with constant memory and compute costs such as DeltaNet, Mamba or xLSTM. These models can be unified by noting that their recurrent layer dynamics can all be derived from an in-context regression objective, approximately optimized through an online learning rule. Here, we join this line of work and introduce a numerically stable, chunkwise parallelizable version of the recently proposed Mesa layer (von Oswald et al., 2024), and study it in language modeling at the billion-parameter scale. This layer again stems from an in-context loss, but which is now minimized to optimality at every time point using a fast conjugate gradient solver. Through an extensive suite of experiments, we show that optimal test-time training enables reaching lower language modeling perplexity and higher downstream benchmark performance than previous RNNs, especially on tasks requiring long context understanding. This performance gain comes at the cost of additional flops spent during inference time. Our results are therefore intriguingly related to recent trends of increasing test-time compute to improve performance -- here by spending compute to solve sequential optimization problems within the neural network itself.
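
The mechanics are easy to sketch: the layer keeps running key statistics and, at each time step, solves a small ridge-regression system to optimality with conjugate gradients instead of taking a single online gradient step. Below is a minimal numpy sketch of this idea, not the numerically stable, chunkwise-parallel layer from the paper; all names are illustrative.

```python
import numpy as np

def conjugate_gradient(A, b, n_steps=20, tol=1e-8):
    """Solve A x = b for symmetric positive-definite A."""
    x = np.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs = r @ r
    for _ in range(n_steps):
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

def mesa_layer(keys, values, queries, lam=1.0):
    """Causal 'mesa' read-out: at each step t the output evaluates the
    minimizer of the in-context loss
        min_W sum_{i<=t} ||W k_i - v_i||^2 + lam ||W||^2
    at the query, i.e. o_t = (V_t K_t^T) (K_t K_t^T + lam I)^{-1} q_t."""
    T, d = keys.shape
    G = lam * np.eye(d)          # running Gram matrix K K^T + lam I
    C = np.zeros((d, d))         # running cross-moment V K^T
    out = np.zeros_like(queries)
    for t in range(T):
        G += np.outer(keys[t], keys[t])
        C += np.outer(values[t], keys[t])
        h = conjugate_gradient(G, queries[t])   # optimal at every step
        out[t] = C @ h
    return out
```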

Probabilistic graphical model representation of the Bayesian Neural Field
a An example spatiotemporal domain comprising two spatial coordinates (latitude, longitude) and a daily time coordinate. b In the probabilistic graphical model, each node denotes a model variable and each edge denotes a direct relationship between a pair of variables. Gray nodes are observed variables and white nodes are local latent variables, which are both associated with an observation Y(s, t) at a spatiotemporal coordinate (s, t). Pink nodes are global latent variables (parameters), which are shared across all spatiotemporal coordinates. c Realizations of the spatiotemporal field generated from the BayesNF at four example time points. Satellite basemap source: Esri, DigitalGlobe, GeoEye, i-cubed, USDA FSA, USGS, AEX, Getmapping, Aerogrid, IGN, IGP, swisstopo, and the GIS User Community⁴⁸.
Spatial and temporal observations for evaluation datasets from Table 1
a Snapshots of spatial observations at fixed points in time. b Snapshots of temporal observations at fixed locations in space. Satellite basemap source: © Stadia Maps, © OpenMapTiles, © OpenStreetMap, © Stamen Design, © CNES, Distribution Airbus DS, © Airbus DS, © PlanetObserver (Contains Copernicus Data)⁴⁹.
Comparison of predictions using BayesNF and various baselines
Each row shows results for a given spatiotemporal benchmark dataset at one spatial location. Black dots are observed data, blue dots are test data, red lines are median forecasts, and gray regions are 95% prediction intervals. (BayesNF: Bayesian Neural Field. Svgp: Spatiotemporal Sparse Variational Gaussian Process. Gboost: Spatiotemporal Gradient Boosting Trees. StGLMM: Spatiotemporal Generalized Linear Mixed Effect Models. NBEATS: Neural Basis Expansion Analysis. TSReg: Trend-Surface Regression).
Spatiotemporal prediction of atmospheric particulate matter (PM10) in German air dataset
a Observed data at four time points: each shaded circle represents a measurement of PM10 at a given station. Higher values of PM10 correspond to lower air quality. The data is sparse: at any given time point, only 47% of stations (on average) are associated with a PM10 observation. b Median predictions of PM10 air quality at four time points across the whole spatial field. c Width of 95% prediction intervals of PM10 air quality at four time points across the whole spatial field. d Observed PM10 data (black) and median prediction (red) at four sparsely observed locations across time. Satellite basemap source: © Stadia Maps, © OpenMapTiles, © OpenStreetMap, © Stamen Design, © CNES, Distribution Airbus DS, © Airbus DS, © PlanetObserver (Contains Copernicus Data)⁴⁹.
Comparison of the empirical and inferred spatiotemporal semivariograms, which measure the variance of the difference between field values at a pair of locations, for the German PM10 air quality dataset
The empirical semivariogram is computed using the locations of the 70 stations in the observed dataset. The inferred semivariogram is computed on 70 novel spatial locations, sampled uniformly at random within the boundary of the field. a The agreement between the semivariogram surfaces indicates that BayesNF extrapolates the joint spatiotemporal dependence structure between locations in the observed data to novel locations. b For short time lags less than three days, the empirical variogram is higher than the inferred variogram at all distances, showing that BayesNF models high-frequency day-to-day variance as unpredictable observation noise.
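
For readers unfamiliar with the diagnostic in this figure, an empirical semivariogram can be computed directly from pairwise squared differences of field values. A purely spatial numpy sketch for a single time slice follows; the paper's version also bins over time lags, and all names are illustrative.

```python
import numpy as np

def empirical_semivariogram(coords, z, n_bins=15):
    """gamma(h) = 0.5 * E[(z_i - z_j)^2] over pairs binned by distance h.
    coords: (n, 2) spatial locations; z: (n,) field values at one time."""
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))
    sq = 0.5 * (z[:, None] - z[None, :]) ** 2
    iu = np.triu_indices(len(z), k=1)           # count each pair once
    d, s = dist[iu], sq[iu]
    edges = np.linspace(0.0, d.max(), n_bins + 1)
    which = np.digitize(d, edges[1:-1])
    gamma = np.array([s[which == b].mean() if np.any(which == b) else np.nan
                      for b in range(n_bins)])
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, gamma
```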

Scalable spatiotemporal prediction with Bayesian neural fields

September 2024 · 105 Reads · 7 Citations

Spatiotemporal datasets, which consist of spatially referenced time series, are ubiquitous in diverse applications, such as air pollution monitoring, disease tracking, and cloud-demand forecasting. As the scale of modern datasets increases, there is a growing need for statistical methods that are flexible enough to capture complex spatiotemporal dynamics and scalable enough to handle many observations. This article introduces the Bayesian Neural Field (BayesNF), a domain-general statistical model that infers rich spatiotemporal probability distributions for data-analysis tasks including forecasting, interpolation, and variography. BayesNF integrates a deep neural network architecture for high-capacity function estimation with hierarchical Bayesian inference for robust predictive uncertainty quantification. Evaluations against prominent baselines show that BayesNF delivers improvements on prediction problems from climate and public health data containing tens to hundreds of thousands of measurements. Accompanying the paper is an open-source software package (https://github.com/google/bayesnf) that runs on GPU and TPU accelerators through the JAX machine learning platform.
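
As a rough illustration of the model class, here is a minimal JAX sketch of a spatiotemporal neural field with a Gaussian likelihood and Gaussian weight priors. This is a schematic under simplifying assumptions, not the bayesnf package API; the paper's hierarchical priors and inference procedures are richer, and all names are illustrative.

```python
import jax
import jax.numpy as jnp

def features(s, t):
    """Spatiotemporal inputs -> feature vector: raw coordinates plus
    periodic (seasonal) encodings of time, as in typical neural fields."""
    return jnp.concatenate([
        s,                                      # latitude, longitude
        jnp.array([jnp.sin(2 * jnp.pi * t / 365.25),
                   jnp.cos(2 * jnp.pi * t / 365.25),
                   t / 365.25]),
    ])

def init_params(key, sizes=(5, 64, 64, 1)):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def field(params, s, t):
    h = features(s, t)
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[0]

def neg_log_joint(params, S, T, Y, noise=1.0, prior=1.0):
    """Gaussian likelihood plus a Gaussian prior over all weights."""
    pred = jax.vmap(field, (None, 0, 0))(params, S, T)
    nll = 0.5 * jnp.sum((Y - pred) ** 2) / noise ** 2
    nlp = sum(0.5 * (jnp.sum(W ** 2) + jnp.sum(b ** 2)) / prior ** 2
              for W, b in params)
    return nll + nlp
```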


Sequential Monte Carlo Learning for Time Series Structure Discovery

July 2023 · 15 Reads

This paper presents a new approach to automatically discovering accurate models of complex time series data. Working within a Bayesian nonparametric prior over a symbolic space of Gaussian process time series models, we present a novel structure learning algorithm that integrates sequential Monte Carlo (SMC) and involutive MCMC for highly effective posterior inference. Our method can be used both in "online" settings, where new data is incorporated sequentially in time, and in "offline" settings, by using nested subsets of historical data to anneal the posterior. Empirical measurements on real-world time series show that our method can deliver 10x--100x runtime speedups over previous MCMC and greedy-search structure learning algorithms targeting the same model family. We use our method to perform the first large-scale evaluation of Gaussian process time series structure learning on a prominent benchmark of 1,428 econometric datasets. The results show that our method discovers sensible models that deliver more accurate point forecasts and interval forecasts over multiple horizons as compared to widely used statistical and neural baselines that struggle on this challenging data.
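
A minimal Python sketch of the symbolic model space, a recursive prior over sums and products of base kernels, is below. The SMC and involutive MCMC machinery from the paper is not shown, and all names are illustrative.

```python
import math
import random

BASE = ["LIN", "PER", "SE"]   # linear, periodic, squared-exponential

def sample_structure(p_stop=0.5, depth=0):
    """Draw a kernel expression from a simple recursive prior over the
    symbolic space: a base kernel, or a sum/product of two subexpressions."""
    if depth >= 3 or random.random() < p_stop:
        return random.choice(BASE)
    op = random.choice(["+", "*"])
    return (op,
            sample_structure(p_stop, depth + 1),
            sample_structure(p_stop, depth + 1))

def kernel(expr, x1, x2, ell=1.0, period=1.0):
    """Evaluate a kernel expression at a pair of scalar inputs."""
    if expr == "LIN":
        return x1 * x2
    if expr == "PER":
        return math.exp(-2 * math.sin(math.pi * abs(x1 - x2) / period) ** 2
                        / ell ** 2)
    if expr == "SE":
        return math.exp(-0.5 * ((x1 - x2) / ell) ** 2)
    op, left, right = expr
    a, b = kernel(left, x1, x2), kernel(right, x1, x2)
    return a + b if op == "+" else a * b
```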


ProbNeRF: Uncertainty-Aware Inference of 3D Shapes from 2D Images

October 2022 · 52 Reads · 1 Citation

The problem of inferring object shape from a single 2D image is underconstrained. Prior knowledge about what objects are plausible can help, but even given such prior knowledge there may still be uncertainty about the shapes of occluded parts of objects. Recently, conditional neural radiance field (NeRF) models have been developed that can learn to infer good point estimates of 3D models from single 2D images. The problem of inferring uncertainty estimates for these models has received less attention. In this work, we propose probabilistic NeRF (ProbNeRF), a model and inference strategy for learning probabilistic generative models of 3D objects' shapes and appearances, and for doing posterior inference to recover those properties from 2D images. ProbNeRF is trained as a variational autoencoder, but at test time we use Hamiltonian Monte Carlo (HMC) for inference. Given one or a few 2D images of an object (which may be partially occluded), ProbNeRF is able not only to accurately model the parts it sees, but also to propose realistic and diverse hypotheses about the parts it does not see. We show that key to the success of ProbNeRF are (i) a deterministic rendering scheme, (ii) an annealed-HMC strategy, (iii) a hypernetwork-based decoder architecture, and (iv) doing inference over a full set of NeRF weights, rather than just a low-dimensional code.
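
For concreteness, the generic HMC kernel used at test time can be written in a few lines. This is a plain leapfrog-based HMC step over a weight vector, a sketch only; the paper's annealed variant tempers the target across rounds, and the target itself (a NeRF rendering likelihood) is not shown.

```python
import numpy as np

def hmc_step(log_prob, grad_log_prob, w, step_size, n_leapfrog, rng):
    """One Hamiltonian Monte Carlo step targeting exp(log_prob(w))."""
    p = rng.standard_normal(w.shape)                  # resample momentum
    w_new, p_new = w.copy(), p.copy()
    p_new += 0.5 * step_size * grad_log_prob(w_new)   # initial half step
    for _ in range(n_leapfrog - 1):
        w_new += step_size * p_new
        p_new += step_size * grad_log_prob(w_new)
    w_new += step_size * p_new
    p_new += 0.5 * step_size * grad_log_prob(w_new)   # final half step
    log_accept = (log_prob(w_new) - 0.5 * p_new @ p_new
                  - log_prob(w) + 0.5 * p @ p)
    return w_new if np.log(rng.uniform()) < log_accept else w
```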


Language Model Cascades

July 2022 · 136 Reads · 4 Citations

Prompted models have demonstrated impressive few-shot learning abilities. Repeated interactions at test-time with a single model, or the composition of multiple models together, further expands capabilities. These compositions are probabilistic models, and may be expressed in the language of graphical models with random variables whose values are complex data types such as strings. Cases with control flow and dynamic structure require techniques from probabilistic programming, which allow implementing disparate model structures and inference strategies in a unified language. We formalize several existing techniques from this perspective, including scratchpads / chain of thought, verifiers, STaR, selection-inference, and tool use. We refer to the resulting programs as language model cascades.
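
The framing suggests a compact sketch: a cascade is a program whose random variables are strings sampled from a language model. A hypothetical Python sketch follows, where lm stands in for any stochastic text sampler; the prompt formats and helper names are our assumptions.

```python
from collections import Counter

def chain_of_thought(lm, question):
    """Two string-valued random variables: a latent 'thought' sampled
    from the model, then an 'answer' sampled conditioned on it."""
    thought = lm(f"Q: {question}\nLet's think step by step.\n")
    answer = lm(f"Q: {question}\nReasoning: {thought}\nA: ")
    return thought, answer

def self_consistency(lm, question, n=10):
    """Approximately marginalize the latent thought: sample several
    chains and take the majority answer (a crude posterior mode)."""
    answers = [chain_of_thought(lm, question)[1] for _ in range(n)]
    return Counter(answers).most_common(1)[0][0]
```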


Beyond the Imitation Game: Quantifying and extrapolating the capabilities of language models

June 2022 · 1,018 Reads · 76 Citations

Language models demonstrate both quantitative improvement and new qualitative capabilities with increasing scale. Despite their potentially transformative impact, these new capabilities are as yet poorly characterized. In order to inform future research, prepare for disruptive new model capabilities, and ameliorate socially harmful effects, it is vital that we understand the present and near-future capabilities and limitations of language models. To address this challenge, we introduce the Beyond the Imitation Game benchmark (BIG-bench). BIG-bench currently consists of 204 tasks, contributed by 442 authors across 132 institutions. Task topics are diverse, drawing problems from linguistics, childhood development, math, common-sense reasoning, biology, physics, social bias, software development, and beyond. BIG-bench focuses on tasks that are believed to be beyond the capabilities of current language models. We evaluate the behavior of OpenAI's GPT models, Google-internal dense transformer architectures, and Switch-style sparse transformers on BIG-bench, across model sizes spanning millions to hundreds of billions of parameters. In addition, a team of human expert raters performed all tasks in order to provide a strong baseline. Findings include: model performance and calibration both improve with scale, but are poor in absolute terms (and when compared with rater performance); performance is remarkably similar across model classes, though with benefits from sparsity; tasks that improve gradually and predictably commonly involve a large knowledge or memorization component, whereas tasks that exhibit "breakthrough" behavior at a critical scale often involve multiple steps or components, or brittle metrics; social bias typically increases with scale in settings with ambiguous context, but this can be improved with prompting.


FIGURE 1. Infection estimation overview. Top: the underlying infection time series (new infections per day) is perturbed by a delay distribution (center) that is measured with other data or assumed known. Each infection date is stochastically delayed, resulting in the reporting curve (red ×'s) of new reported cases per day. Bottom: the estimation procedure aims to undo this stochastic delay. Given the observed report curve (left), we use a statistical estimator with the delay distribution to recover the underlying curve. (a) Assumed data-generating process. (b) Estimation procedure.
FIGURE 2. Synthetic experiments. (A) Six synthetic incidence curves and simulated observations with both correct and misspecified noise models. (B) Estimators on the full slow-decay data. (C) Estimators on censored slow-decay data. (D-E) Root mean squared error over all experimental settings and replicates for both correct and misspecified data for all data (D) and over the most recent 20 days of observations (E). Additional details are in the eAppendix; http://links.lww.com/EDE/B924. AR indicates autoregressive; BP, back projection; Re-conv, reconvolution; RIDE, robust incidence deconvolution estimator; RL, Richardson-Lucy; and RMSE, root mean squared error.
FIGURE 3. (A) Infection incidence (solid black line), 90% credible regions (gray shading) when available, and observed values (red plus signs) by data type across regions (rows) and by method (columns). (B) Estimated R_t with 90% credible regions, fit using the method of [1] from the incidence estimates in (A). Richardson-Lucy estimates are truncated in each panel for readability. RIDE indicates robust incidence deconvolution estimator.
FIGURE 5. R_t fitted on infection time series estimated by data type across regions. Solid lines are means and ribbons are 90% credible regions.
Statistical Deconvolution for Inference of Infection Time Series

May 2022 · 64 Reads · 22 Citations

Epidemiology

Accurate measurement of daily infection incidence is crucial to epidemic response. However, delays in symptom onset, testing, and reporting obscure the dynamics of transmission, necessitating methods to remove the effects of stochastic delays from observed data. Existing estimators can be sensitive to model misspecification and censored observations; many analysts have instead used methods that exhibit strong bias. We develop an estimator with a regularization scheme to cope with stochastic delays, which we term the robust incidence deconvolution estimator. We compare the method to existing estimators in a simulation study, measuring accuracy in a variety of experimental conditions. We then use the method to study COVID-19 records in the United States, highlighting its stability in the face of misspecification and right censoring. To implement the robust incidence deconvolution estimator, we release incidental, a ready-to-use R implementation of our estimator that can aid ongoing efforts to monitor the COVID-19 pandemic.
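
To make the deconvolution setting concrete, the sketch below recovers an incidence curve from delayed reports using generic Tikhonov-regularized non-negative least squares. This is a simple baseline in the same family, not the RIDE estimator itself, and all names are illustrative.

```python
import numpy as np
from scipy.optimize import nnls

def deconvolve_incidence(reports, delay_pmf, lam=10.0):
    """Recover an infection curve x from reports r = conv(x, delay) by
    solving min_x ||C x - r||^2 + lam ||D2 x||^2 subject to x >= 0,
    where C is the convolution matrix and D2 penalizes roughness."""
    T, K = len(reports), len(delay_pmf)
    C = np.zeros((T, T))
    for t in range(T):
        for k in range(min(K, t + 1)):
            C[t, t - k] = delay_pmf[k]              # convolution matrix
    D2 = (np.diag(np.ones(T)) * -2
          + np.diag(np.ones(T - 1), 1)
          + np.diag(np.ones(T - 1), -1))[1:-1]      # second differences
    A = np.vstack([C, np.sqrt(lam) * D2])
    b = np.concatenate([reports, np.zeros(T - 2)])
    x, _ = nnls(A, b)
    return x
```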


Statistical deconvolution for inference of infection time series

October 2020 · 62 Reads · 4 Citations

Accurate measurement of daily infection incidence is crucial to epidemic response. However, delays in symptom onset, testing, and reporting obscure the dynamics of transmission, necessitating methods to remove the effects of stochastic delays from observed data. Existing estimators can be sensitive to model misspecification and censored observations; many analysts have instead used methods that exhibit strong bias or do not account for delays. We develop an estimator with a regularization scheme to cope with these sources of noise, which we term the Robust Incidence Deconvolution Estimator (RIDE). We validate RIDE on synthetic data, comparing accuracy and stability to existing approaches. We then use RIDE to study COVID-19 records in the United States, and find evidence that infection estimates from reported cases can be more informative than estimates from mortality data. To implement these methods, we release incidental, a ready-to-use R implementation of our estimator that can aid ongoing efforts to monitor the COVID-19 pandemic.


Estimating the Changing Infection Rate of COVID-19 Using Bayesian Models of Mobility

August 2020 · 90 Reads · 1 Citation

In order to prepare for and control the continued spread of the COVID-19 pandemic while minimizing its economic impact, the world needs to be able to estimate and predict COVID-19's spread. Unfortunately, we cannot directly observe the prevalence or growth rate of COVID-19; these must be inferred using some kind of model. We propose a hierarchical Bayesian extension to the classic susceptible-exposed-infected-removed (SEIR) compartmental model that adds compartments to account for isolation and death and allows the infection rate to vary as a function of both mobility data collected from mobile phones and a latent time-varying factor that accounts for changes in behavior not captured by mobility data. Since confirmed-case data is unreliable, we infer the model's parameters conditioned on deaths data. We replace the exponential-waiting-time assumption of classic compartmental models with Erlang distributions, which allows for a more realistic model of the long lag between exposure and death. The mobility data gives us a leading indicator that can quickly detect changes in the pandemic's local growth rate and forecast changes in death rates weeks ahead of time. This is an analysis of observational data, so any causal interpretations of the model's inferences should be treated as suggestive at best; nonetheless, the model's inferred relationships between different kinds of trips and the infection rate do suggest some possible hypotheses about what kinds of activities might contribute most to COVID-19's spread.
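
A minimal sketch of the Erlang-waiting-time idea: splitting each compartment into k sub-compartments makes stage durations Erlang(k, k/mean) rather than exponential. The Python below is a deterministic SEIR core only; the paper adds isolation and death compartments and performs full Bayesian inference, and all names are illustrative.

```python
import numpy as np

def seir_erlang(beta, n_days, N=1e6, e0=10.0,
                k_E=2, k_I=2, mean_E=4.0, mean_I=6.0, dt=0.25):
    """Deterministic SEIR where E and I are each split into k stages.
    beta: function day -> time-varying infection rate (e.g. driven by
    mobility covariates). Returns new infections per day."""
    S, R = N - e0, 0.0
    E = np.zeros(k_E); E[0] = e0
    I = np.zeros(k_I)
    rE, rI = k_E / mean_E, k_I / mean_I          # per-stage exit rates
    daily_infections, acc = [], 0.0
    for step in range(int(n_days / dt)):
        new_E = beta(step * dt) * I.sum() * S / N
        flow_E, flow_I = rE * E, rI * I
        S -= dt * new_E
        E += dt * (np.concatenate(([new_E], flow_E[:-1])) - flow_E)
        I += dt * (np.concatenate(([flow_E[-1]], flow_I[:-1])) - flow_I)
        R += dt * flow_I[-1]
        acc += dt * new_E
        if (step + 1) % round(1 / dt) == 0:      # close out each day
            daily_infections.append(acc)
            acc = 0.0
    return np.array(daily_infections)
```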



Citations (34)


... We demonstrate the scalability and effectiveness of the proposed MCMC sampler through a synthetic data experiment. Further, an application to Montreal's Ecocounter bicycle data shows its comparable performance against standard Bayesian GLM, BKTR (Lanthier et al., 2023), and BayesNF (Saad et al., 2024). In summary, the key contributions of this paper include: ...

Reference:

Scalable Spatiotemporal Modeling for Bicycle Count Prediction
Scalable spatiotemporal prediction with Bayesian neural fields

... These models are explicitly trained to interpret and follow natural language instructions, enabling them to adapt flexibly to diverse tasks while maintaining coherent reasoning (Mishra et al. 2022; Wei et al. 2022). Models such as GPT, Claude, and Llama can now engage in sophisticated tasks including analysis, summarization, and complex problem solving, activities that previously required human expertise (Kojima et al. 2022; Srivastava et al. 2023). This capability to follow explicit instructions while drawing on broad knowledge has transformed these models from simple text generators into versatile tools for knowledge work, opening new possibilities for automating complex cognitive tasks that were previously considered beyond the reach of computational approaches. ...

Beyond the Imitation Game: Quantifying and extrapolating the capabilities of language models
  • Citing Preprint
  • June 2022

... Note that we did not correct for these sources of noise when assessing their detrimental effects on epidemic controllability. While several studies have focussed on estimating and compensating for under-reporting [11] and reporting delays [40,41,42], these approaches often require additional knowledge about the reporting process or orthogonal data sources [43]. It is often the case that these are not available or only become available later in epidemics so we preferred to characterise performance under the more practical scenario that little else is known about the epidemic than its time series of cases. ...

Statistical Deconvolution for Inference of Infection Time Series

Epidemiology

... Furthermore, the use of COVID-19 deaths or hospitalizations inevitably introduces a longer lag after infection so that 'real-time estimation' reflects transmission patterns as long as a month previously. Differentiating genuine reductions in transmission intensity from biases similar to censoring, due to the delayed presentation of cases that were recently infected, poses further difficulty for which statistical deconvolution approaches remain underdeveloped (Miller, Hannah, et al., 2020). As discussed below, this challenge limits the value of transmission intensity estimates in informing real-time public health policy changes. ...

Statistical deconvolution for inference of infection time series

... Music genre classification, a fundamental task in music information retrieval, remains of paramount importance [1], [2]. It serves as a means for humans to categorize and describe various music collections [3], offering applications in music recommendation [4] and music curation [5], among others. Traditional approaches to music genre classification have often focused solely on either the textual (lyrics) or audio modality [2], with some studies, like that of [6], concentrating on the audio modality, while others, such as [1], emphasize the lyrics/textual modality. ...

Large-Scale Weakly-Supervised Content Embeddings for Music Recommendation and Tagging
  • Citing Conference Paper
  • May 2020

... In all cases, the initial step size is set to 1, the target acceptance rate is 0.6, and the number of burn-in steps is 2,000. In particular, the No-U-Turn sampler is implemented using the TensorFlow Probability package [31]. ...

tfp.mcmc: Modern Markov Chain Monte Carlo Tools Built for Modern Hardware
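
The settings quoted above map directly onto TensorFlow Probability's MCMC API. A small self-contained example with a toy standard-normal target follows; the target and sample counts are our assumptions, while the sampler settings mirror the snippet.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Toy target; the citing paper's model is of course different.
target = tfd.Normal(loc=0., scale=1.)

nuts = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target.log_prob,
    step_size=1.0)                       # initial step size 1
adaptive_nuts = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=nuts,
    num_adaptation_steps=1600,
    target_accept_prob=0.6)              # target acceptance rate 0.6

samples = tfp.mcmc.sample_chain(
    num_results=1000,
    num_burnin_steps=2000,               # 2,000 burn-in steps
    current_state=tf.zeros([]),
    kernel=adaptive_nuts,
    trace_fn=None)
```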

... Target speech extraction (TSE) methods aim at isolating a specific speaker's voice from a multi-talker mixture. This process uses cues associated with the desired speaker, such as a prerecorded enrollment speech that highlights the speaker's vocal characteristics [1,2], a spatial cue indicating the direction from which the speaker is speaking [3], or video input capturing lip movements [4,5]. However, in real-world scenarios, these preregistered cues can vary significantly, raise privacy concerns, or even be absent, thereby limiting the practicality and effectiveness of TSE systems. ...

VoiceFilter: Targeted Voice Separation by Speaker-Conditioned Spectrogram Masking
  • Citing Conference Paper
  • September 2019

... Our SE deep model is no longer forced to explicitly estimate a single phase solution, overcoming the above-mentioned problems, such as sign indeterminacy, by focusing on the generation of a phase spectrogram consistent with its magnitude. While a loss function based on the consistency constraint is used in recent SE models [25], [27] through consistency projection [19], [33], a key difference is that these works still primarily rely on conventional phase loss functions, such as anti-wrapping or complex-domain loss, with the consistency-based loss treated merely as an auxiliary loss. In other words, the main objective remains the direct estimation of the original phase, leaving the challenges of phase estimation unresolved. ...

Differentiable Consistency Constraints for Improved Deep Speech Enhancement
  • Citing Conference Paper
  • May 2019

... Spectral mapping methods optimize the DNN model to map the T-F representation (or spectral feature) of the noisy speech to the clean T-F representation. The most commonly used spectral features for spectral mapping include log-power spectrum (LPS) [27], magnitude spectrum (MS) [11], [28]- [30], and complex spectrum (CS) [31], [32]. ...

Exploring Tradeoffs in Models for Low-Latency Speech Enhancement
  • Citing Conference Paper
  • September 2018
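
The spectral-mapping setup described in the snippet above is straightforward to sketch: transform to the time-frequency domain, estimate a mask (or clean spectrum) from noisy features, and resynthesize with the noisy phase. A minimal Python skeleton where mask_fn stands in for a trained DNN; all names are illustrative.

```python
import numpy as np
from scipy.signal import stft, istft

def enhance_with_mask(noisy, mask_fn, fs=16000, nperseg=512):
    """Skeleton of mask-based spectral enhancement: estimate a T-F mask
    from the noisy log-power spectrum, apply it to the magnitude, and
    resynthesize using the (unmodified) noisy phase."""
    f, t, Z = stft(noisy, fs=fs, nperseg=nperseg)
    mag, phase = np.abs(Z), np.angle(Z)
    lps = np.log(mag ** 2 + 1e-8)                 # log-power spectrum
    mask = np.clip(mask_fn(lps), 0.0, 1.0)        # DNN-estimated mask
    _, clean = istft(mask * mag * np.exp(1j * phase), fs=fs, nperseg=nperseg)
    return clean
```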