Implicit Neural Differential Model for Spatiotemporal Dynamics

Deepak Akhare, Pan Du, Tengfei Luo, Jian-Xun Wang
Department of Aerospace and Mechanical Engineering, University of Notre Dame, Notre Dame, IN, USA; Sibley School of Mechanical and Aerospace Engineering, Cornell University, Ithaca, NY, USA
Abstract

Hybrid neural–physics modeling frameworks through differentiable programming have emerged as powerful tools in scientific machine learning, enabling the integration of known physics with data-driven learning to improve prediction accuracy and generalizability. However, most existing hybrid frameworks rely on explicit recurrent formulations, which suffer from numerical instability and error accumulation during long-horizon forecasting. In this work, we introduce Im-PiNDiff, a novel implicit physics-integrated neural differentiable solver for stable and accurate modeling of spatiotemporal dynamics. Inspired by deep equilibrium models, Im-PiNDiff advances the state using implicit fixed-point layers, enabling robust long-term simulation while remaining fully end-to-end differentiable. To enable scalable training, we introduce a hybrid gradient propagation strategy that integrates adjoint-state methods with reverse-mode automatic differentiation. This approach eliminates the need to store intermediate solver states and decouples memory complexity from the number of solver iterations, significantly reducing training overhead. We further incorporate checkpointing techniques to manage memory in long-horizon rollouts. Numerical experiments on various spatiotemporal PDE systems, including advection–diffusion processes, Burgers’ dynamics, and multi-physics chemical vapor infiltration processes, demonstrate that Im-PiNDiff achieves superior predictive performance, enhanced numerical stability, and substantial reductions in memory and runtime cost relative to explicit and naive implicit baselines. This work provides a principled, efficient, and scalable framework for hybrid neural–physics modeling.

Keywords: Differentiable programming, Implicit neural networks, Scientific machine learning, Hybrid model, Grey-box model

1 Introduction

Computational science is experiencing a transformative shift, driven by advancements in numerical techniques, artificial intelligence (AI), and the growing availability of extensive datasets. At the core of this transformation lies scientific machine learning (SciML), a rapidly evolving field that synergistically integrates physics-based modeling with modern machine learning (ML). By combining data-driven insights with fundamental physical principles, SciML offers unprecedented opportunities to accurately model complex systems, enhance predictive capabilities, and optimize computational workflows across a wide range of scientific and engineering applications.

Notable methodologies underpinning SciML include physics-informed neural networks (PINNs) [1, 2, 3], neural operators [4, 5, 6], equation discovery techniques [7, 8, 9], and hybrid neural-physics models [10, 11, 12, 13, 14, 15, 16]. Among these, hybrid neural-physics models have attracted significant attention due to their explicit incorporation of domain knowledge into the learning architecture, addressing key limitations of purely data-driven approaches such as limited extrapolation and generalization capabilities. This integration ensures consistency with known physical laws while enabling the capture of uncharacterized system dynamics, thus balancing flexibility with reliability. Historically, however, most hybrid frameworks have employed weak coupling strategies, wherein ML models are trained offline and subsequently embedded into conventional numerical solvers [17, 18, 19]. Such approaches, while prevalent in turbulence modeling [20, 21, 22] and atmospheric simulations [23, 24], are fundamentally limited: their reliance on separately trained components prevents end-to-end optimization, restricting their robustness, generalization, and adaptability to complex, unseen scenarios.

Recognizing these challenges, recent trends advocate strongly integrated frameworks that enable end-to-end differentiable hybridization, leveraging indirect and sparse observational data. This integration is made possible by differentiable programming (DP) [25], which facilitates joint optimization of ML models and numerical solvers. Recent advances in differentiable physics and hybrid neural-physics models have demonstrated significant potential across various scientific domains [26, 27, 28, 29, 30, 10, 31, 13, 11, 12, 32]. For example, Kochkov et al. [10] utilized convolutional neural networks (CNNs) to accelerate differentiable computational fluid dynamics (CFD) solvers, while Huang et al. [30] embedded neural networks within a differentiable finite element solver to learn constitutive relations from indirect observations. Further advances by Wang and coworkers introduced a physics-integrated neural differentiable (PiNDiff) modeling framework, unifying neural operators with numerical PDE solvers to achieve enhanced generalization and accuracy [13, 11, 12, 32, 16].

Despite their considerable potential, current PiNDiff frameworks predominantly utilize explicit, auto-regressive recurrent architectures. Such explicit recurrent structures suffer from inherent numerical instability, error accumulation, and deteriorating performance in long-term predictions, particularly for stiff or chaotic systems, thereby limiting their practical applicability. Inspired by classical numerical analysis, where implicit methods offer superior stability properties, this paper proposes an innovative implicit neural differential model (Im-PiNDiff) for robust spatiotemporal dynamics prediction. By employing implicit neural network layers, our framework mitigates error accumulation and significantly enhances numerical stability and accuracy, enabling reliable long-term simulations.

However, adopting implicit neural architectures within differentiable frameworks introduces considerable computational and memory challenges. This is primarily due to the requirements of automatic differentiation (AD), which stores intermediate activations, computational graphs, and input data as buffers during forward propagation to compute gradients [33, 34]. These memory requirements grow rapidly in implicit learning architectures involving bilevel optimization, iterative solvers, and extended simulations, frequently leading to prohibitive training times [35, 32, 25]. To address these computational hurdles, we introduce a hybrid training strategy that combines adjoint-state methods with reverse-mode AD. The adjoint method decouples memory requirements from the number of solver iterations, significantly reducing computational overhead and memory usage. Our approach employs adjoint-based methods to compute and propagate gradients over the implicit layers while utilizing reverse-mode AD for the explicit model components. Further, we employ strategic checkpointing techniques to optimize memory usage, ensuring scalability and practical feasibility for large-scale, complex problems. In summary, the key contributions of this work include: (1) a novel implicit PiNDiff framework integrating implicit neural architectures, differentiable numerical PDE solvers, and conditional neural fields, enabling stable and accurate long-term predictions of complex spatiotemporal dynamics; (2) a hybrid gradient computation strategy that leverages adjoint-state methods and reverse-mode AD, significantly improving computational efficiency and reducing memory overhead; (3) numerical validation demonstrating the improved performance, stability, and computational feasibility of the proposed Im-PiNDiff framework. Together, these innovations represent a significant step toward the efficient and scalable application of implicit PiNDiff frameworks.

The rest of this paper is organized as follows: Section 2 details the proposed methodology, including the mathematical formulation, implicit neural differentiable model, and hybrid training strategies. Section 3 presents numerical experiments demonstrating the framework’s performance across a range of applications. Finally, Section 4 summarizes the contributions and outlines potential future research directions.

2 Methodology

2.1 Problem formulation

Most fundamental physical laws governing phenomena across diverse scientific and engineering disciplines, such as fluid dynamics, solid mechanics, heat transfer, electromagnetism, and quantum mechanics, are naturally expressed in the mathematical form of partial differential equations (PDEs). In practice, however, the exact forms of these PDEs often contain unknown or uncertain components due to incomplete understanding of underlying physics or inherent modeling limitations. Such scenarios can be represented by generic PDEs,

\begin{align}
\frac{\partial \phi}{\partial t} &= \mathscr{F}\big[\mathcal{K}(\phi(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}}),\, \mathcal{U}(\phi(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}})\big], && \tilde{\mathbf{x}} \in \Omega_{p,t}, \tag{1a}\\
\mathcal{BC}(\phi(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}}) &= 0, && \tilde{\mathbf{x}} \in \partial\Omega_{p,t}, \tag{1b}\\
\mathcal{IC}(\phi(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}}) &= 0, && \tilde{\mathbf{x}} \in \Omega_{p,t=0}, \tag{1c}
\end{align}

where the nonlinear functions $\mathcal{K}(\cdot)$ and $\mathcal{U}(\cdot)$ represent the known and unknown components of the PDEs for the state variable $\phi$, coupled via the nonlinear functional $\mathscr{F}(\cdot)$. The initial and boundary conditions are abstractly defined by the operators $\mathcal{IC}$ and $\mathcal{BC}$, respectively. These functions rely on a set of physical parameters, with the known ones denoted by $\boldsymbol{\lambda}_{\mathcal{K}}$ and the uncertain ones by $\boldsymbol{\lambda}_{\mathcal{U}}$. The spatial and temporal coordinates are denoted as $\tilde{\mathbf{x}} = \{\mathbf{x}, t\}$, with physical domain $\Omega_p$, boundary $\partial\Omega_p$, and time domain $[0, T]$, resulting in the spatiotemporal domain $\Omega_{p,t} \triangleq \Omega_p \times [0, T]$.

Due to incomplete PDE formulations, hybrid neural models based on DP, e.g., PiNDiff, integrate deep neural networks (DNNs) to approximate the unknown PDE components/operators or to enhance the known components through learnable operators. The neural network parameters $\boldsymbol{\theta}_{nn}$ and the uncertain physical parameters $\boldsymbol{\lambda}_{\mathcal{U}}$ form a unified set of trainable parameters $\boldsymbol{\theta} = [\boldsymbol{\theta}_{nn}, \boldsymbol{\lambda}_{\mathcal{U}}]^T$, optimized concurrently as part of a unified network enabled by DP. Training the PiNDiff model over a dataset $\mathcal{D} = \{\tilde{\mathbf{x}}_i, \phi_i\}_{i=1}^{N}$ is formulated as a PDE-constrained optimization problem,

\begin{align}
\min_{\boldsymbol{\theta}} \quad & J(\phi_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}), \mathcal{D}; \boldsymbol{\theta}) \tag{2a}\\
\text{s.t.} \quad & \frac{\partial \phi_{\boldsymbol{\theta}}(\tilde{\mathbf{x}})}{\partial t} = \mathscr{F}_{nn}\big[\mathcal{K}(\phi_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}}),\, \mathcal{U}_{nn}(\phi_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}}, \boldsymbol{\theta}_{nn}); \boldsymbol{\theta}_{nn}\big], && \tilde{\mathbf{x}} \in \Omega_{p,t}, \tag{2b}\\
& \mathcal{BC}(\phi_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}}) = 0, && \tilde{\mathbf{x}} \in \partial\Omega_{p,t}, \tag{2c}\\
& \mathcal{IC}(\phi_{\boldsymbol{\theta}}(\tilde{\mathbf{x}}); \boldsymbol{\lambda}_{\mathcal{K}}, \boldsymbol{\lambda}_{\mathcal{U}}) = 0, && \tilde{\mathbf{x}} \in \Omega_{p,t=0}, \tag{2d}
\end{align}

where $J$ is the objective function. Typically, existing PiNDiff models rely on explicit recurrent formulations, wherein the solution at each time step explicitly depends on the previous state. Although straightforward, these auto-regressive recurrent approaches commonly suffer from instability and error accumulation, especially during long-horizon predictions.

To address these limitations, we introduce an implicit PiNDiff model, termed Im-PiNDiff, inspired by implicit numerical schemes known for their superior numerical stability and robustness. Specifically, we replace explicit recurrent layers with an implicit layer formulation, analogous to the deep equilibrium model (DEQ) introduced by Bai et al. [36]. DEQ employs implicit neural computations through solving a fixed-point equilibrium, enabling effectively infinite-depth representations without explicit unrolling of iterative layers. Similarly, in the context of hybrid neural-physics modeling, Im-PiNDiff conceptualizes each time step as solving an implicit nonlinear equation rather than explicit forward stepping. Training the Im-PiNDiff model involves explicitly embedding PDE constraints, resulting in a PDE-constrained bilevel optimization problem formulated mathematically as follows:

\begin{align}
\min_{\boldsymbol{\theta}} \quad & J(\phi_{\boldsymbol{\theta}}(\mathbf{x}, t), \mathcal{D}; \boldsymbol{\theta}), \tag{3a}\\
\text{s.t.} \quad & \phi_{\boldsymbol{\theta}}(\mathbf{x}, t_{i+1}) = \phi_{\boldsymbol{\theta}}(\mathbf{x}, t_i) + \int_{t_i}^{t_{i+1}} \mathscr{F}_{nn}\big[\phi_{\boldsymbol{\theta}}(\mathbf{x}, t); \boldsymbol{\theta}\big]\, dt, && i = 0, \dots, T, \tag{3b}\\
& \phi_{\boldsymbol{\theta}}(\mathbf{x}_{BC}, t_i) = \mathcal{BC}\big(\phi_{\boldsymbol{\theta}}(\mathbf{x}, t_i); \boldsymbol{\theta}\big), && i = 0, \dots, T, \tag{3c}\\
& \phi_{\boldsymbol{\theta}}(\mathbf{x}, t_0) = \mathcal{IC}(\boldsymbol{\theta}), && \mathbf{x} \in \Omega_p, \tag{3d}
\end{align}

where the outer-level optimization minimizes the discrepancy between model predictions and observed data (the loss function), while the inner-level optimization solves a nonlinear equilibrium equation to determine the solution $\phi_{\boldsymbol{\theta}}(\mathbf{x}, t_{i+1})$ at each rollout step $t_{i+1}$. A schematic illustrating Im-PiNDiff and its training strategy is presented in Fig. 1.

Figure 1: Schematic of the Im-PiNDiff framework, where temporal states $\boldsymbol{\Phi}_t$ are advanced via implicit updates incorporating known and learned physics.

As shown in Fig. 1, propagating gradients through bilevel optimization poses significant challenges since standard AD frameworks typically require static computational graphs. Naively implementing unrolled differentiation with a fixed-iteration inner solver is computationally expensive and memory-intensive, severely limiting scalability. To circumvent this, we propose a hybrid training strategy integrating the adjoint-state implicit differentiation method with reverse-mode AD. Specifically, the adjoint-state approach propagates gradients through implicit equilibrium constraints, subsequently combined with reverse-mode AD gradients via the chain rule. A key benefit of adjoint-based gradient propagation is that it eliminates the necessity for storing intermediate computational nodes during forward propagation in implicit layers, significantly reducing memory usage. Additionally, we leverage checkpointing strategies to enhance memory efficiency during training. By combining adjoint methods with strategic checkpointing, our hybrid training strategy effectively balances computational performance and memory efficiency, enabling practical and scalable training of Im-PiNDiff models for large-scale scientific modeling problems.
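For the checkpointing component in particular, the idea can be illustrated with a minimal sketch, assuming JAX as the AD backend; `step` is a placeholder for the Im-PiNDiff time-step update, not the actual solver. `jax.checkpoint` discards a chunk's internal activations during the forward pass and recomputes them during the backward pass, so stored rollout memory scales with the number of chunks rather than the number of steps:

```python
import jax
import jax.numpy as jnp

def rollout(phi0, theta, step, n_chunks=10, chunk_len=20):
    """Checkpointed rollout of n_chunks * chunk_len steps."""
    def chunk(phi, theta):
        # One chunk of steps; its internals are recomputed on the
        # backward pass instead of being stored.
        for _ in range(chunk_len):
            phi = step(phi, theta)
        return phi

    chunk = jax.checkpoint(chunk)
    phi = phi0
    for _ in range(n_chunks):
        phi = chunk(phi, theta)
    return phi

# Usage: differentiate a 200-step rollout with bounded activation memory.
step = lambda phi, theta: phi + 0.01 * theta * jnp.tanh(phi)  # placeholder dynamics
grads = jax.grad(lambda th: jnp.sum(rollout(jnp.ones(64), th, step)))(0.5)
```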

2.2 Adjoint-based hybrid AD for efficient gradient propagation

Adjoint-based gradient computation methods have been increasingly adopted to enhance the efficiency and scalability of neural network training, especially for architectures involving implicitly defined operations [37, 38, 39, 40, 36]. In this study, we present an adjoint-based implicit differentiation approach for efficient Im-PiNDiff training, which requires solving a nonlinear inner optimization during forward propagation.

The fundamental distinction between explicit and implicit layers in PiNDiff models is illustrated in Fig. 2.

Figure 2: Comparison between the schematics of PiNDiff and Im-PiNDiff layers.

We formally define an implicit layer by:

\begin{equation}
f^{im}\big(\boldsymbol{\Phi}_t; \boldsymbol{\Phi}_{t-1}, \boldsymbol{\theta}\big) = 0, \tag{4}
\end{equation}

where $\boldsymbol{\Phi}_t \in \mathbb{R}^n$ is the discrete state vector at the current step $t$ and $\boldsymbol{\Phi}_{t-1}$ represents the known previous state, both depending implicitly on the parameters $\boldsymbol{\theta} \in \mathbb{R}^p$; the nonlinear implicit function is $f^{im}: \mathbb{R}^n \times \mathbb{R}^n \times \mathbb{R}^p \to \mathbb{R}^n$. Given the typically high dimensionality of the problem, iterative numerical root-finding algorithms are usually employed to solve for $\boldsymbol{\Phi}_t$. Thus, the forward pass through the implicit layer can be succinctly expressed as,

\begin{equation}
\boldsymbol{\Phi}_t^* \leftarrow \text{RootFind}\Big(f^{im}\big(\boldsymbol{\Phi}_t; \boldsymbol{\Phi}_{t-1}, \boldsymbol{\theta}\big) = 0\Big). \tag{5}
\end{equation}
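To make the forward pass concrete, below is a minimal sketch, assuming JAX; the implicit-Euler residual `f_im` and the right-hand side `rhs` are illustrative placeholders, not the paper's solvers. The implicit update is solved with a simple damped fixed-point iteration:

```python
import jax.numpy as jnp

def rhs(phi, theta):
    # Placeholder dynamics F(phi; theta); in Im-PiNDiff this would combine
    # known PDE operators with learned components.
    return theta * jnp.tanh(phi)

def f_im(phi_next, phi_prev, theta, dt=0.01):
    # Implicit-Euler residual: f^im = phi_next - phi_prev - dt * F(phi_next).
    return phi_next - phi_prev - dt * rhs(phi_next, theta)

def root_find(phi_prev, theta, n_iter=200):
    # Damped fixed-point (Picard) iteration with a fixed budget; Newton or
    # Anderson acceleration are drop-in alternatives for stiff problems.
    phi = phi_prev
    for _ in range(n_iter):
        phi = phi - f_im(phi, phi_prev, theta)  # equals phi_prev + dt * F(phi)
    return phi
```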

DP with standard AD computes gradients by systematically applying the chain rule to decompose a complex computer program into elementary operations. During forward propagation, DP evaluates and records intermediate variables and their computational dependencies, building a computational graph from the inputs to the outputs. In reverse-mode AD, gradients are computed by traversing this graph backwards: starting from the output, DP computes vector-Jacobian products (VJPs) at each node, sequentially applying the chain rule in reverse order. These VJPs represent the sensitivity of the output with respect to each intermediate variable. For the total loss $L(\boldsymbol{\theta})$ defined over the entire rollout trajectory of $T$ steps,

\begin{equation}
L(\boldsymbol{\theta}) = \sum_{t=0}^{T} \mathcal{L}_t(\boldsymbol{\Phi}_t; \mathcal{D}_t) + \mathcal{L}_{\text{regulate}}(\boldsymbol{\theta}), \tag{6}
\end{equation}

the gradient of the loss with respect to $\boldsymbol{\theta}$ can be computed using the chain rule,

\begin{equation}
\frac{dL}{d\boldsymbol{\theta}} = \frac{\partial L}{\partial \boldsymbol{\theta}} + \sum_{t=0}^{T} \frac{\partial L}{\partial \boldsymbol{\theta}}\bigg|_{\boldsymbol{\Phi}_t^*}, \tag{7}
\end{equation}

where

\begin{align}
\frac{\partial L}{\partial \boldsymbol{\theta}}\bigg|_{\boldsymbol{\Phi}_t^*} &= \frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*} \cdot \Bigg[\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\theta}}\Bigg], \tag{8a}\\
\frac{\partial L}{\partial \boldsymbol{\Phi}_t^*} &= \frac{\partial \mathcal{L}_t}{\partial \boldsymbol{\Phi}_t^*} + \frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*} \cdot \Bigg[\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\Phi}_t^*}\Bigg], \tag{8b}
\end{align}

where $\frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*}\Big[\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\theta}}\Big]$ and $\frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*}\Big[\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\Phi}_t^*}\Big]$ are VJPs. Given that the output $\boldsymbol{\Phi}_t^*$ from the implicit layer is obtained by solving a nonlinear equation through an iterative numerical solver, the process can be abstractly written as a sequence of intermediate iterates:

\begin{equation}
\Big\{\boldsymbol{\Phi}_t^{[0]}, \boldsymbol{\Phi}_t^{[1]}, \cdots, \boldsymbol{\Phi}_t^{[K]}\Big\} = \Big\{\text{RootFind}^{[k]}\Big(f_{\boldsymbol{\theta}}^{im}\big(\boldsymbol{\Phi}_t^{[0]}; \boldsymbol{\Phi}_{t-1}\big) = 0\Big)\Big\}_{k=0}^{K}, \tag{9}
\end{equation}

where $K$ denotes the number of iterations needed to converge to the equilibrium solution, $\boldsymbol{\Phi}_t^* \approx \boldsymbol{\Phi}_t^{[K]}$ with $K \gg 1$. If one directly applies standard reverse-mode AD to propagate gradients through this implicit layer, the AD engine must record and retain the full computational graph of all intermediate iterates $\boldsymbol{\Phi}_t^{[k]}$ to compute VJPs. This results in substantial memory overhead, especially when the number of iterations or the state dimension $n$ is large.

To address this limitation, we present a hybrid gradient computation strategy using the discrete adjoint-state method, which decouples gradient propagation from the forward iteration history. Instead of tracing all internal solver steps, the adjoint method efficiently computes the required VJP by solving a single linear system, avoiding the need to store intermediate iterates of the root-finding process. Concretely, for each time step $t \to t+1$, the Jacobians of $\boldsymbol{\Phi}_{t+1}^*$ with respect to $\boldsymbol{\Phi}_t^*$ and $\boldsymbol{\theta}$ are computed using the implicit function theorem [35],

\begin{align}
\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\theta}} &= -\left(\frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_{t+1}^*}\right)^{-1} \cdot \frac{\partial f^{im}}{\partial \boldsymbol{\theta}}, \tag{10a}\\
\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\Phi}_t^*} &= -\left(\frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_{t+1}^*}\right)^{-1} \cdot \frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_t^*}. \tag{10b}
\end{align}

Now we define the adjoint vector $\mathbf{w}_{t+1}^T \in \mathbb{R}^{1 \times n}$ as,

\begin{equation}
\mathbf{w}^T = -\frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*} \cdot \left(\frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_{t+1}^*}\right)^{-1}, \tag{11}
\end{equation}

and then the VJPs are expressed as,

\begin{align}
\frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*} \cdot \Bigg[\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\theta}}\Bigg] &= \mathbf{w}_{t+1}^T \cdot \frac{\partial f^{im}}{\partial \boldsymbol{\theta}}, \tag{12a}\\
\frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*} \cdot \Bigg[\frac{\partial \boldsymbol{\Phi}_{t+1}^*}{\partial \boldsymbol{\Phi}_t^*}\Bigg] &= \mathbf{w}_{t+1}^T \cdot \frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_t^*}, \tag{12b}
\end{align}

where $\mathbf{w}^T$ is obtained by solving the linear system,

\begin{equation}
\mathbf{w}^T \Bigg[\frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_{t+1}^*}\Bigg] = -\frac{\partial L}{\partial \boldsymbol{\Phi}_{t+1}^*}. \tag{13}
\end{equation}

This linear system is solved efficiently using iterative linear solvers (such as GMRES or conjugate gradient methods). Instead of explicitly forming the Jacobian matrices $\frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_t^*} \in \mathbb{R}^{n \times n}$ and $\frac{\partial f^{im}}{\partial \boldsymbol{\theta}} \in \mathbb{R}^{n \times p}$, we utilize VJP functions provided by AD tools to obtain $\mathbf{w}^T \frac{\partial f^{im}}{\partial \boldsymbol{\Phi}_t} \in \mathbb{R}^{1 \times n}$ directly. Modern AD libraries (e.g., JAX, PyTorch, TensorFlow) provide such VJP computations without ever materializing the full Jacobian, so these products are computed on-the-fly, greatly enhancing computational and memory efficiency (more details on the derivation of the hybrid adjoint backpropagation can be found in Appendix C).
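Concretely, the adjoint solve of Eq. (13) can be sketched as follows, assuming JAX and reusing the illustrative `f_im` residual from the earlier sketch. `jax.vjp` yields the matrix-free map $\mathbf{w} \mapsto \big(\partial f^{im}/\partial \boldsymbol{\Phi}_{t+1}^*\big)^T \mathbf{w}$, which GMRES uses to solve the (transposed) adjoint system; the same adjoint vector then feeds the VJPs of Eq. (12):

```python
import jax
from jax.scipy.sparse.linalg import gmres

def adjoint_solve(phi_next, phi_prev, theta, dL_dphi_next):
    # Matrix-free operator w -> (df^im/dphi_next)^T w via reverse-mode AD;
    # the n-by-n Jacobian is never materialized.
    _, vjp_wrt_next = jax.vjp(lambda p: f_im(p, phi_prev, theta), phi_next)
    A = lambda w: vjp_wrt_next(w)[0]
    w, _ = gmres(A, -dL_dphi_next)  # adjoint linear system of Eq. (13)
    # Eq. (12): reuse w to obtain gradients w.r.t. the previous state
    # and the trainable parameters.
    _, vjp_wrt_rest = jax.vjp(lambda pp, th: f_im(phi_next, pp, th),
                              phi_prev, theta)
    dL_dphi_prev, dL_dtheta = vjp_wrt_rest(w)
    return dL_dphi_prev, dL_dtheta
```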

The complete hybrid adjoint-based AD procedure for gradient back-propagation through an implicit layer can be summarized as follows: (1) solve the linear system iteratively for $\mathbf{w}^T$; (2) compute the gradients with respect to the parameters $\boldsymbol{\theta}$ and the state $\boldsymbol{\Phi}_t^*$ using the obtained adjoint vector, and pass them to the AD engine to continue backpropagation.

Figure 3: A schematic of forward and backward propagation through an implicit layer of the Im-PiNDiff model.

The schematic illustration of forward and backward passes through the implicit layer is provided in Fig. 3, and the complete algorithmic implementation steps are clearly outlined below in Algorithm 1.

Function Forward($\boldsymbol{\Phi}^*_{\text{init}}$, $\boldsymbol{\Phi}_t^*$, $\boldsymbol{\theta}$):
    $\boldsymbol{\Phi}_{t+1}^* \leftarrow \text{RootFind}\big(f^{im}(\boldsymbol{\Phi}_{t+1}; \boldsymbol{\Phi}_t^*, \boldsymbol{\theta}) = 0\big)$
    return $\boldsymbol{\Phi}_{t+1}^*$

Function Backward($\partial_{\boldsymbol{\Phi}_{t+1}^*}L$, $\boldsymbol{\Phi}_{t+1}^*$, $\boldsymbol{\Phi}_t^*$, $\boldsymbol{\theta}$):
    $\partial_{\boldsymbol{\Phi}_{t+1}^*}f^{im}(\cdot) \leftarrow \text{vjp}\big(f^{im}(\cdot\,; \boldsymbol{\Phi}_t^*, \boldsymbol{\theta})\big)$   ▷ $\partial_b a = \big[\frac{\partial a}{\partial b}\big]$
    $\partial_{\boldsymbol{\Phi}_t^*}f^{im}(\cdot) \leftarrow \text{vjp}\big(f^{im}(\boldsymbol{\Phi}_{t+1}^*; \cdot\,, \boldsymbol{\theta})\big)$
    $\partial_{\boldsymbol{\theta}}f^{im}(\cdot) \leftarrow \text{vjp}\big(f^{im}(\boldsymbol{\Phi}_{t+1}^*; \boldsymbol{\Phi}_t^*, \cdot)\big)$
    $\mathbf{w}_{t+1}^T \leftarrow \text{RootFind}\big(\partial_{\boldsymbol{\Phi}_{t+1}^*}f^{im}(\mathbf{w}^T) = -\partial_{\boldsymbol{\Phi}_{t+1}^*}L\big)$   ▷ $\partial_{(\cdot)}f^{im}(\mathbf{w}^T) = \mathbf{w}^T\big[\frac{\partial f^{im}}{\partial(\cdot)}\big]$
    $\partial_{\boldsymbol{\Phi}_{t+1}^*}L \cdot \big[\partial_{\boldsymbol{\Phi}_t^*}\boldsymbol{\Phi}_{t+1}^*\big] \leftarrow \partial_{\boldsymbol{\Phi}_t^*}f^{im}(\mathbf{w}_{t+1}^T)$
    $\partial_{\boldsymbol{\Phi}_{t+1}^*}L \cdot \big[\partial_{\boldsymbol{\theta}}\boldsymbol{\Phi}_{t+1}^*\big] \leftarrow \partial_{\boldsymbol{\theta}}f^{im}(\mathbf{w}_{t+1}^T)$
    return $\partial_{\boldsymbol{\Phi}_{t+1}^*}L \cdot \big[\partial_{\boldsymbol{\Phi}_t^*}\boldsymbol{\Phi}_{t+1}^*\big]$, $\partial_{\boldsymbol{\Phi}_{t+1}^*}L \cdot \big[\partial_{\boldsymbol{\theta}}\boldsymbol{\Phi}_{t+1}^*\big]$
Algorithm 1: Implicit layer with adjoint-based backpropagation
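A compact JAX sketch of Algorithm 1, assuming the illustrative `root_find` and `adjoint_solve` helpers from the earlier sketches, wires the two passes together with `jax.custom_vjp` so the implicit layer composes with reverse-mode AD over the rest of the model:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def implicit_step(phi_prev, theta):
    # Forward pass of Algorithm 1: solve f^im(phi_next; phi_prev, theta) = 0.
    return root_find(phi_prev, theta)

def implicit_step_fwd(phi_prev, theta):
    phi_next = root_find(phi_prev, theta)
    # Only the converged equilibrium is saved; no solver iterates are stored.
    return phi_next, (phi_next, phi_prev, theta)

def implicit_step_bwd(saved, dL_dphi_next):
    # Backward pass of Algorithm 1: adjoint linear solve followed by VJPs.
    phi_next, phi_prev, theta = saved
    return adjoint_solve(phi_next, phi_prev, theta, dL_dphi_next)

implicit_step.defvjp(implicit_step_fwd, implicit_step_bwd)

# Usage: gradients flow through the equilibrium, with memory independent of
# the number of root-finding iterations.
grad_theta = jax.grad(lambda th: jnp.sum(implicit_step(jnp.ones(64), th)))(0.5)
```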

2.3 Conditional neural fields for latent physics inference

To capture spatially and temporally varying latent physical quantities, such as unresolved PDE terms, parametric fields, or unmodeled operators, we incorporate conditional neural fields (CNFs) into the Im-PiNDiff framework. Neural fields (NF), also known as coordinate-based implicit neural representations, offer a flexible and expressive mechanism for modeling continuous functions and operator-valued mappings over space and time. These representations have gained significant traction in computer vision, graphics, and scientific machine learning due to their ability to encode high-frequency and nonstationary features in data with minimal inductive bias [41, 42, 43].

Within the context of PDE-constrained modeling, neural fields provide an elegant tool for parameterizing unknown coefficients or operators that vary over the spatial or spatiotemporal domain. Their coordinate-continuous nature makes them particularly well suited for this task, as they can be queried at arbitrary spatial or temporal resolutions, ensuring mesh-invariant predictions and generalization to unseen domains. In this work, we employ a conditional formulation of neural fields, wherein the predicted physical field, e.g., an unknown advection velocity $\mathbf{u}(\mathbf{x}, t)$ or diffusivity field $k(\mathbf{x})$, is conditioned on a latent embedding derived from auxiliary input. This conditioning allows the model to encode global contextual information, such as initial or boundary conditions, simulation settings, or observed response trajectories. As illustrated in Fig. 4, our architecture comprises three modules: a hypernetwork that maps condition vectors to latent codes, a linear projector that translates latent codes to NF parameters, and a base NF network that evaluates the inferred field at queried coordinates.

Figure 4: Architecture of the conditional neural field (CNF) module for latent physics inference. A hypernetwork maps a contextual input vector $\mathbf{c}$ to a latent embedding $\mathbf{h}$, which is linearly projected to generate the weights $\boldsymbol{\theta}_b$ of a SIREN-based neural field.

Specifically, given a conditioning vector $\mathbf{c}$ representing contextual input (e.g., initial/boundary condition encodings or low-dimensional representations of observed dynamics), a hypernetwork $\text{HyperNet}_{\boldsymbol{\theta}_h}$ produces a latent code $\mathbf{h} \in \mathbb{R}^{d_L}$,

$\mathbf{h}=\text{HyperNet}_{\boldsymbol{\theta}_{h}}(\mathbf{c}),$  (14)

which is subsequently projected via a learnable linear operator into the NF parameters $\boldsymbol{\theta}_{b}=\mathbf{W}_{\text{proj}}\cdot\mathbf{h}$. The base NF network, implemented as a sinusoidal representation network (SIREN) [42], then evaluates the spatial or spatiotemporal field as

$\mathbf{u}(\mathbf{x};\mathbf{c})=\text{SIREN}_{\boldsymbol{\theta}_{b}}(\mathbf{x}),$  (15)

where $\mathbf{u}(\cdot)$ represents the inferred latent field, such as a velocity or forcing term. The use of SIREN allows the model to resolve fine-scale variations in the latent field and capture complex physical patterns across domains with smooth and expressive function approximations.
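To illustrate Eqs. (14)-(15), the following is a minimal, self-contained sketch of the CNF module. All layer sizes, initializations, and the two-hidden-layer SIREN are illustrative assumptions rather than the paper's exact architecture.

```python
# Illustrative CNF sketch for Eqs. (14)-(15): hypernetwork -> latent code h
# -> linear projection to SIREN weights theta_b -> field query at x.
# All sizes and initializations are assumptions for demonstration.
import jax
import jax.numpy as jnp

d_c, d_h, d_in, d_hid, d_out = 16, 32, 2, 64, 1     # assumed dimensions

def init_linear(key, n_in, n_out):
    k1, _ = jax.random.split(key)
    return (jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in),
            jnp.zeros(n_out))

# SIREN weight shapes: two hidden sine layers plus a linear head.
shapes = [(d_in, d_hid), (d_hid, d_hid), (d_hid, d_out)]
n_siren = sum(i * o + o for i, o in shapes)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
hyper = [init_linear(k1, d_c, 64), init_linear(k2, 64, d_h)]
W_proj, b_proj = init_linear(k3, d_h, n_siren)

def hypernet(c):
    h = jnp.tanh(c @ hyper[0][0] + hyper[0][1])
    return h @ hyper[1][0] + hyper[1][1]            # latent code h, Eq. (14)

def unpack(theta_b):
    params, i = [], 0
    for n_in, n_out in shapes:
        W = theta_b[i:i + n_in * n_out].reshape(n_in, n_out); i += n_in * n_out
        b = theta_b[i:i + n_out]; i += n_out
        params.append((W, b))
    return params

def siren(theta_b, x, w0=30.0):
    (W1, b1), (W2, b2), (W3, b3) = unpack(theta_b)
    z = jnp.sin(w0 * (x @ W1 + b1))                 # periodic activations
    z = jnp.sin(w0 * (z @ W2 + b2))
    return z @ W3 + b3

def cnf(c, x):
    theta_b = hypernet(c) @ W_proj + b_proj         # linear projection to NF weights
    return siren(theta_b, x)                        # u(x; c), Eq. (15)

u = cnf(jnp.ones(d_c), jnp.ones((100, d_in)))       # field at 100 query points
```

Because `cnf` can be queried at arbitrary coordinates, the representation is mesh-invariant by construction.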

Crucially, the conditional neural field $\mathbf{u}_{\boldsymbol{\theta}_{b}}(\mathbf{x};\mathbf{c})$ is mixed directly with the known PDE operators to form the hybrid neural PDE operator $\mathscr{F}_{nn}[\cdot]$ in the Im-PiNDiff solver, allowing it to affect the system dynamics during training. Because the CNF module is fully differentiable and compatible with our adjoint-based gradient propagation strategy, end-to-end learning remains tractable and memory-efficient. This integration enables the identification of latent physics from sparse, indirect observations: only a small portion of the state trajectories $\{\boldsymbol{\Phi}_{t}\}_{t=0}^{T}$ is observable, and the physical field itself is hidden.
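As a hedged illustration of this mixing, the snippet below (which reuses `cnf` and `d_c` from the sketch above) forms a hybrid operator on a 1D periodic grid: the diffusion stencil encodes known physics, while the advection velocity is queried from the CNF. The grid, stencil, and coupling are simplified analogues of the paper's FV operator, not its implementation.

```python
# Illustrative hybrid operator F_nn on a 1D periodic grid, reusing `cnf` and
# `d_c` from the CNF sketch above: known diffusion physics plus a
# CNF-inferred advection velocity (assumed, simplified analogue of the FV operator).
def make_hybrid_rhs(cnf, c, x_cells, k, dx):
    u = cnf(c, x_cells)[:, 0]             # latent velocity at cell centers
    def F_nn(phi):
        dphi = (jnp.roll(phi, -1) - jnp.roll(phi, 1)) / (2 * dx)            # unknown advection
        d2phi = (jnp.roll(phi, -1) - 2 * phi + jnp.roll(phi, 1)) / dx**2    # known diffusion
        return -u * dphi + k * d2phi
    return F_nn

# (x, t) query points at t = 0, matching the CNF's 2D input convention above.
x_cells = jnp.stack([jnp.linspace(0.0, 1.0, 128), jnp.zeros(128)], axis=1)
F_nn = make_hybrid_rhs(cnf, jnp.ones(d_c), x_cells, k=0.01, dx=1.0 / 128)
```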

3 Results and Discussion

3.1 Forward and inverse modeling of spatiotemporal physics

We assessed the proposed Im-PiNDiff framework on two canonical spatiotemporal PDE systems: the advection-diffusion equation and the scalar Burgers’ equation. These case studies demonstrate the model’s capability to perform both forward prediction and inverse inference of latent physical fields under linear and nonlinear dynamics, respectively. In particular, the advection–diffusion system serves as a testbed for investigating the accuracy and stability of Im-PiNDiff in handling smooth transport-diffusion processes, while the Burgers’ equation probes its robustness in capturing nonlinear wave propagation and shock formation.

3.1.1 Advection–diffusion processes with steady advection fields

We begin with a 2D advection–diffusion system, which describes the spatiotemporal evolution of a scalar field $\phi(\mathbf{x},t)$ under the combined effects of directional transport and diffusion,

$\frac{\partial\phi}{\partial t}=-u_{x}\frac{\partial\phi}{\partial x}-u_{y}\frac{\partial\phi}{\partial y}+k\frac{\partial^{2}\phi}{\partial x^{2}}+k\frac{\partial^{2}\phi}{\partial y^{2}},$  (16)

where $k$ is the diffusion coefficient and $u_{x}$, $u_{y}$ are the advection coefficients in the $x$ and $y$ directions, respectively. Here, $u_{x}(\mathbf{x})$ and $u_{y}(\mathbf{x})$ are treated as unknown spatial fields to be inferred during Im-PiNDiff training, where partial observations of the state variable $\phi(\mathbf{x},t)$ are used as labels. These unknown fields are parameterized by CNFs.

To generate training and testing data, we employed a high-resolution finite-volume (FV) solver with fourth-order Runge–Kutta (RK4) time integration and a time step of $10^{-3}$ s. The physical domain of size $2\times 1$ was discretized using a $128\times 64$ Cartesian mesh. A set of ten initial conditions was generated by randomly sampling Gaussian processes (GPs) on a coarse $30\times 10$ grid and projecting them onto the simulation mesh. These GP fields were constructed using radial basis kernels with varying length scales to promote diversity in initial configurations. Ground-truth advection velocity fields were generated as smooth superpositions of cosine functions (details in Appendix D) and similarly projected to the simulation grid. For training, only a very limited number of state observations, specifically two snapshots of $\phi$, were provided, simulating a data-sparse regime. The model was then evaluated on unseen initial conditions to assess its ability to reconstruct the full spatiotemporal state ($\phi$) trajectories over extended prediction horizons, while simultaneously inferring the hidden advection velocity fields from these sparse, indirect observations.
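For concreteness, the snippet below sketches this initial-condition generation; the RBF length scale, jitter, and bilinear projection are illustrative assumptions.

```python
# Sketch of GP initial-condition generation: sample an RBF-kernel GP on a
# coarse 30 x 10 grid and project it bilinearly onto the 128 x 64 mesh.
# Length scale and jitter values are illustrative.
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def gp_sample_2d(key, nx=30, ny=10, length=0.2):
    xs, ys = jnp.meshgrid(jnp.linspace(0, 2, nx), jnp.linspace(0, 1, ny), indexing="ij")
    pts = jnp.stack([xs.ravel(), ys.ravel()], axis=1)
    d2 = ((pts[:, None, :] - pts[None, :, :]) ** 2).sum(-1)
    K = jnp.exp(-0.5 * d2 / length**2) + 1e-6 * jnp.eye(nx * ny)   # RBF kernel + jitter
    z = jax.random.normal(key, (nx * ny,))
    return (jnp.linalg.cholesky(K) @ z).reshape(nx, ny)            # one GP draw

def project_to_mesh(coarse, NX=128, NY=64):
    nx, ny = coarse.shape
    I, J = jnp.meshgrid(jnp.linspace(0, nx - 1, NX),
                        jnp.linspace(0, ny - 1, NY), indexing="ij")
    return map_coordinates(coarse, [I, J], order=1)                # bilinear projection

phi0 = project_to_mesh(gp_sample_2d(jax.random.PRNGKey(0)))        # 128 x 64 field
```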

The Im-PiNDiff model was constructed using the FV discretization of the governing PDEs, where the unknown spatial advection velocity fields are modeled as CNFs parameterized continuously over space and conditioned on time. To achieve stable and accurate long-term state propagation, the temporal autoregression is carried out through the implicit layer using a second-order Crank–Nicolson scheme,

$\boldsymbol{\Phi}_{t_{i+1}}(\boldsymbol{\theta})=\boldsymbol{\Phi}_{t_{i}}(\boldsymbol{\theta})+\frac{t_{i+1}-t_{i}}{2}\Big[\mathscr{F}_{nn}\big[\boldsymbol{\Phi}_{t_{i+1}}(\boldsymbol{\theta});\boldsymbol{\theta}\big]+\mathscr{F}_{nn}\big[\boldsymbol{\Phi}_{t_{i}}(\boldsymbol{\theta});\boldsymbol{\theta}\big]\Big],\quad i=0,\dots,T,$  (17)

propagating the state from $t_{i}$ to $t_{i+1}$. The neural model uses a relatively large time step of $10^{-2}$ s, an order of magnitude larger than the step size used in generating the training data. At each time step, Eq. (17) requires solving a nonlinear system due to its implicit dependence on $\boldsymbol{\Phi}_{t_{i+1}}$, which is accomplished via the biconjugate gradient stabilized (BiCGStab) method. To enable efficient end-to-end training, we applied the hybrid adjoint-based AD backpropagation algorithm introduced in Section 2.
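A minimal sketch of one such update is given below, assuming (as in this linear advection–diffusion case) that $\mathscr{F}_{nn}$ acts linearly on $\boldsymbol{\Phi}$, so the Crank–Nicolson system reduces to a single linear solve; the tolerance and warm start are illustrative choices.

```python
# Sketch of one Crank-Nicolson update, Eq. (17), solved matrix-free with
# BiCGStab. rhs_op plays the role of F_nn (e.g., the hybrid operator from the
# Section 2.3 sketch); for nonlinear dynamics this solve would sit inside an
# outer root-finding loop such as `implicit_step` above.
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

def crank_nicolson_step(rhs_op, phi_t, dt):
    b = phi_t + 0.5 * dt * rhs_op(phi_t)              # explicit half of Eq. (17)
    A = lambda phi: phi - 0.5 * dt * rhs_op(phi)      # implicit half as a linear map
    phi_next, _ = bicgstab(A, b, x0=phi_t, tol=1e-8)  # warm-started Krylov solve
    return phi_next
```

Warm-starting the Krylov solve from the previous state typically cuts the number of inner iterations substantially, since consecutive states differ only by one time step.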

We first evaluated the Im-PiNDiff framework for the scenario where the advection velocity fields were assumed to be spatially varying but temporally invariant, i.e., $u_{x}=u_{x}(\mathbf{x})$ and $u_{y}=u_{y}(\mathbf{x})$. The model was trained using only two snapshots of the state field: the initial condition $\boldsymbol{\Phi}_{0}$ and an observation at $t=0.05$ s, as indicated in the top-right panel of Fig. 5.

Figure 5: Forward prediction and inverse inference using Im-PiNDiff on the advection–diffusion problem with steady advection fields. The model is trained using only two snapshots $\{\boldsymbol{\phi}_{t}\}_{t=0,0.05}$. Left: inferred steady advection fields $u_{x}(\mathbf{x})$ and $u_{y}(\mathbf{x})$ compared with ground truth. Right: predicted scalar field $\phi(\mathbf{x},t)$ at future times, showing excellent agreement with ground truth across the forecast horizon.

After 2000 epochs of training, the model produced accurate forward predictions of the scalar field $\phi(\mathbf{x},t)$ over the extended horizon $t=0.05, 0.1, 0.15, 0.2$ s and simultaneously inferred the underlying steady advection fields. As shown in Fig. 5, the predicted scalar field trajectories closely match the ground truth with a relative error of approximately 2%, while the advection fields $u_{x}$ and $u_{y}$ are accurately inferred as well, with 3% error compared to the true fields.

Figure 6: Im-PiNDiff predictions of $\phi(\mathbf{x},t)$ for testing initial conditions with time-invariant advection fields.

The trained model was further tested on out-of-distribution (OOD) initial conditions generated using radial basis kernels with characteristic length scales different from those used during training. As illustrated in Fig. 6, the model accurately predicted the spatiotemporal evolution of the scalar field $\phi(\mathbf{x},t)$ over the full time horizon $[0,0.2]$ s, despite having never seen such initial states during training. The predictions remain stable and accurate across time, with relative errors generally below 2%, indicating strong robustness to randomly generated unseen initial conditions.

These results highlight the model's capacity to perform robust forward/inverse modeling from extremely limited data, accurately reconstructing both state dynamics and hidden physics. Notably, no direct observations of the advection fields were used during training, underscoring the effectiveness of the CNF parameterization and the hybrid adjoint-AD gradient propagation through the entire solver program in inferring latent physics from sparse, indirect measurements.

3.1.2 Advection–diffusion processes with dynamic advection fields

To further evaluate the robustness of Im-PiNDiff in modeling non-stationary physics, we consider a more challenging setting where the advection fields vary in both space and time, i.e., $u_{x}(\mathbf{x},t)$ and $u_{y}(\mathbf{x},t)$. This setting poses a significant challenge for spatiotemporal inversion, as the time-varying latent fields are inferred from sparse and irregular data.

Figure 7: Forward prediction and inverse inference using Im-PiNDiff on the advection–diffusion problem with time-varying advection fields.
Figure 8: Im-PiNDiff predictions of $\phi(\mathbf{x},t)$ for testing initial conditions with time-varying advection fields.

The training set consists of eight snapshots sampled at nonuniform time intervals $t_{\text{train}}=\{0, 0.05, 0.102, 0.15, 0.201, 0.25, 0.298, 0.35\}$ s over a total simulation horizon of $T=0.4$ s. This choice reflects realistic scenarios where sensor data may be irregularly sampled in time. The model is trained for 10,000 epochs to jointly learn both the forward state evolution $\phi(\mathbf{x},t)$ and the hidden time-varying advection fields $u_{x}(\mathbf{x},t)$ and $u_{y}(\mathbf{x},t)$. The learned model is then used to reconstruct full spatiotemporal fields at unobserved time points $t=\{0.02, 0.07, 0.12, 0.17, 0.22, 0.27, 0.32, 0.37, 0.4\}$ s, which are also sampled irregularly.

As shown in Fig. 7, the Im-PiNDiff model accurately recovers the ground-truth dynamics and latent velocity fields. The left panels show the inferred $u_{x}$ and $u_{y}$ fields in comparison with their ground-truth counterparts, along with relative error maps across time. Despite the inherent ill-posedness of recovering time-dependent vector fields from limited state observations, the inferred advection fields capture the spatiotemporal structures reasonably well, with a relative error around 10%. The right panels show the predicted scalar field $\phi(\mathbf{x},t)$, achieving high accuracy with a relative error consistently below 2% across all prediction times. Generalization performance is further evaluated using OOD initial conditions, generated from unseen Gaussian process realizations. The predictions on these OOD cases, shown in Fig. 8, demonstrate excellent agreement with ground truth, confirming that the Im-PiNDiff model generalizes robustly to new initializations with high accuracy.

3.1.3 Scalar Burgers’ dynamics with spatially varying viscosity fields

We further evaluate the Im-PiNDiff framework on the 2D scalar Burgers’ equation to test its capability in recovering spatially varying latent physical parameters under nonlinear dynamics,

$\frac{\partial u}{\partial t}=-u\frac{\partial u}{\partial x}-u\frac{\partial u}{\partial y}+\nu\frac{\partial^{2}u}{\partial x^{2}}+\nu\frac{\partial^{2}u}{\partial y^{2}},$  (18)

where $u(\mathbf{x},t)$ is the scalar velocity field and $\nu(\mathbf{x})$ denotes the unknown spatially varying viscosity field. In this setting, the convective nonlinearity leads to steep gradients and localized structures, making the recovery of latent viscosity fields from sparse observations particularly challenging. Note that the viscosity is set to be of order $10^{-2}$; higher values would suppress the convective dynamics and render the solution diffusion-dominated.

The Im-PiNDiff model is trained using only four snapshots of the velocity field, without any direct supervision on the viscosity. The CNF module is tasked with recovering $\nu(\mathbf{x})$ from the observed dynamics.

Figure 9: Inference of the spatially varying viscosity field $\nu(\mathbf{x})$ and prediction of the velocity field $u(\mathbf{x},t)$ for the 2D scalar Burgers' equation using Im-PiNDiff. The model is trained on sparse observations of $u$ and successfully reconstructs both the hidden viscosity field and the spatiotemporal dynamics.

Figure 9 shows the results of Im-PiNDiff for the scalar Burgers' dynamics with spatially varying viscosity. The left panel presents the inferred steady viscosity field $\nu(\mathbf{x})$ in comparison with the ground truth, alongside its corresponding relative error. Although recovering viscosity from indirect state observations is highly ill-posed, particularly in this regime where $\nu(\mathbf{x})$ is of order $10^{-2}$ and exerts only a weak influence on the dynamics, the inferred $\nu(\mathbf{x})$ captures the dominant spatial patterns and exhibits reasonable structural agreement with the true field. Some deviation is observed in regions of low sensitivity, resulting in higher relative errors (up to 28%). The low sensitivity is evident from the velocity prediction results. As shown in the right panel, the model achieves highly accurate predictions of the velocity field, with a relative error around 4%. These results demonstrate that Im-PiNDiff remains robust in recovering hidden spatial parameters while preserving predictive fidelity in nonlinear PDE systems.

Further, the trained model was tested on unseen initial conditions.

Figure 10: Im-PiNDiff predictions of Burgers' dynamics on two testing (unseen) initial conditions.

As shown in Fig. 10, the model accurately predicts the spatiotemporal evolution of the velocity field $u(\mathbf{x},t)$ across both test cases, achieving relative errors consistently around 4%. These results confirm that the proposed framework generalizes robustly to new initializations, despite being trained under limited data and an ill-posed inference setting.

3.2 Temporal error accumulation and stability analysis

An important motivation for the proposed Im-PiNDiff framework lies in its ability to perform robust learning and stable long-horizon forecasting, while mitigating error accumulation. This advantage stems from the use of implicit autoregressive forward passes, which have long been favored in classical numerical analysis for their superior stability properties, especially in stiff or convection-dominated systems. To systematically quantify and compare the temporal error accumulation behavior, we conduct a controlled study using the advection–diffusion problem with steady advection fields. We benchmark the performance of Im-PiNDiff against its explicit counterpart, referred to as Ex-PiNDiff, which follows the original PiNDiff formulation with a recurrent explicit forward pass [13]. Both models are trained on the same dataset using the same model architecture, differing only in the time-stepping mechanism. We investigate two key aspects: (1) error accumulation over time for a fixed time-step size and (2) sensitivity of the final prediction error to varying temporal resolutions. The results are shown in Fig. 11.

(a) Accumulation of relative error in $\Phi_{t}$ over simulation time. (b) $L_1$ relative error in $\Phi_{0.2}$ at the final time for various time-step sizes.
Figure 11: Accumulation of relative error in the $\phi$ fields for the Ex-PiNDiff (Ex) and Im-PiNDiff (Im) models, evaluated on various test cases of the advection–diffusion problem.

Figure 11(a) reports the accumulation of $L_1$-norm relative error in the scalar field $\phi$ over the simulation horizon of $T=0.2$ s, comparing three models: Ex-PiNDiff with $\Delta t=10^{-2}$ s (black), Ex-PiNDiff with $\Delta t=10^{-3}$ s (blue), and Im-PiNDiff with $\Delta t=10^{-2}$ s (red). As the simulation progresses, the explicit model with coarse time steps ($\Delta t=10^{-2}$ s) exhibits significant error growth, reflecting compounding integration and learning errors. In contrast, the implicit Im-PiNDiff maintains a nearly constant error profile throughout the rollout, highlighting its temporal stability. Notably, the implicit model achieves comparable or better accuracy than the explicit model even when trained at a tenfold coarser temporal resolution. Figure 11(b) further quantifies the effect of time-step size on final prediction accuracy by plotting the $L_1$ relative error in $\phi$ at $t=0.2$ s across various $\Delta t$ values. The explicit model shows a steep increase in error as $\Delta t$ increases, underscoring its sensitivity to temporal resolution. In contrast, Im-PiNDiff remains remarkably robust, showing negligible degradation in accuracy even as $\Delta t$ grows. This result confirms that the implicit layer stabilizes long-horizon predictions, suppresses numerical drift, and allows for larger autoregressive steps without compromising predictive quality.

Taken together, these findings illustrate the key strength of the Im-PiNDiff framework in preserving predictive accuracy and stability over extended time horizons. By decoupling learning stability from time-step resolution, Im-PiNDiff not only ensures physically consistent forecasting but also enables significant computational savings, a benefit explored further in the subsequent section on computational cost analysis.

3.3 Computational cost analysis

To understand the efficiency and scalability of Im-PiNDiff models, we analyze the computational and memory complexity of Im-PiNDiff relative to Ex-PiNDiff under various settings. In the Ex-PiNDiff model, the system evolves through explicit autoregressive updates, forming a forward trajectory of the form,

$\boldsymbol{\Phi}_{0}\rightarrow\boldsymbol{\Phi}_{\Delta t}\rightarrow\dots\rightarrow\boldsymbol{\Phi}_{t}\rightarrow\boldsymbol{\Phi}_{t+\Delta t}\rightarrow\dots\rightarrow\boldsymbol{\Phi}_{T-\Delta t}\rightarrow\boldsymbol{\Phi}_{T}\rightarrow L,$  (19)

where each transition step requires storing the full state $\boldsymbol{\Phi}_{t}\in\mathbb{R}^{n}$ and its associated computational graph for backpropagation through time. In contrast, the Im-PiNDiff model propagates the state implicitly by solving a nonlinear fixed-point equation at each time step. A naive AD-based implementation unrolls $K$ solver iterations per step as

$\boldsymbol{\Phi}_{0}\rightarrow\big\{\boldsymbol{\Phi}_{\Delta t}^{[0]}\cdots\boldsymbol{\Phi}_{\Delta t}^{[K]}\big\}\rightarrow\boldsymbol{\Phi}_{\Delta t}^{*}\cdots\big\{\boldsymbol{\Phi}_{t}^{[0]}\cdots\boldsymbol{\Phi}_{t}^{[K]}\big\}\rightarrow\boldsymbol{\Phi}_{t}^{*}\cdots\big\{\boldsymbol{\Phi}_{T}^{[0]}\cdots\boldsymbol{\Phi}_{T}^{[K]}\big\}\rightarrow\boldsymbol{\Phi}_{T}^{*}\rightarrow L,$  (20)

where each $\boldsymbol{\Phi}_{t}^{[k]}$ denotes the $k$-th iteration of the nonlinear solver at time $t$. This naive Im-PiNDiff approach incurs memory and runtime costs that scale linearly with both the number of time steps and the number of solver iterations $K$. To overcome this limitation, the proposed hybrid training strategy combines adjoint-state methods with reverse-mode AD, avoiding the need to store intermediate iterations by solving a single linear system via the implicit function theorem. The number of solver steps $\tilde{K}$ can be determined dynamically via convergence criteria. This approach decouples memory usage from $\tilde{K}$ while still retaining the flexibility of accurate, adaptive inner solvers, yielding significant gains in both scalability and efficiency. Table 1 summarizes the computational complexity of different strategies in terms of memory usage and training time.

Model               Memory footprint            Training time
Ex-PiNDiff          $O\big((T/\Delta t)\,n\big)$       $O\big((T/\Delta t)\,n^{3}\big)$
Naive Im-PiNDiff    $O\big((T/\Delta t)\,Kn\big)$      $O\big((T/\Delta t)\,Kn^{3}\big)$
Im-PiNDiff          $O\big((T/\Delta t)\,n\big)$       $O\big((T/\Delta t)\,\tilde{K}n^{3}\big)$
Table 1: Computational cost for Ex-PiNDiff and Im-PiNDiff models.

The key observation is that the memory footprint of the hybrid Im-PiNDiff model is independent of $\tilde{K}$, unlike the naive version. This enables substantial memory savings without compromising expressivity or stability. Furthermore, since $\tilde{K}$ is adaptively determined, the total training time is often lower than that of the fixed-$K$ naive strategy.

Figure 12: Memory consumption and training time for the advection–diffusion case. (Ex: Ex-PiNDiff, Im: Im-PiNDiff, Im-Cp: Im-PiNDiff w/ checkpointing.)

Figure 12 presents the empirical computational cost for the advection–diffusion case, showing peak memory and wall-clock training time per epoch over a range of time-step sizes. For small $\Delta t$ values, Im-PiNDiff and Ex-PiNDiff exhibit similar memory usage, but the implicit model typically incurs longer training times due to iterative solver overhead. However, Im-PiNDiff supports significantly larger $\Delta t$ while maintaining predictive accuracy. For instance, to achieve a relative error below 0.05%, Ex-PiNDiff requires $\Delta t=1\times 10^{-4}$ s, whereas Im-PiNDiff retains this accuracy even at $\Delta t=2\times 10^{-3}$ s. At this larger step size, the number of rollout steps is reduced by a factor of 20, yielding a 14× reduction in memory usage and a 3× speedup in training time. For most practical problems, the effective inequality $(T/\Delta t_{\text{Im-PiNDiff}})\tilde{K}<(T/\Delta t_{\text{Ex-PiNDiff}})$ holds, yielding superior performance. In the case of the naive Im-PiNDiff model, the memory usage increases linearly with the specified inner iteration count $K$; for $K=32$, it consumes around 34 GB. Since training time scales directly with simulation execution time, the findings on training time apply equally to inference time. To further reduce memory requirements, we apply checkpointing to the implicit model (Im-Cp). This technique stores selected intermediate states and recomputes others during the backward pass, significantly reducing memory usage while incurring a modest increase in computational time. As shown in Fig. 12, checkpointing flattens the memory growth curve across increasing simulation lengths, making long-horizon training feasible on resource-constrained hardware.
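The checkpointed rollout can be sketched in a few lines; `jax.checkpoint` (remat) and the segment sizes below illustrate the technique rather than reproduce the exact implementation.

```python
# Minimal sketch of the checkpointed rollout (Im-Cp), assuming a generic
# differentiable one-step function. jax.checkpoint stores only the
# segment-boundary states; segment interiors are recomputed on the backward
# pass, trading roughly one extra forward sweep for an O(seg_len) reduction
# in stored states. Segment sizes are illustrative.
import jax
import jax.numpy as jnp

def checkpointed_rollout(step, phi0, n_segments=10, seg_len=20):
    @jax.checkpoint
    def segment(phi, _):
        for _ in range(seg_len):
            phi = step(phi)
        return phi, phi              # (carry, per-segment snapshot)
    phi_T, snapshots = jax.lax.scan(segment, phi0, jnp.arange(n_segments))
    return phi_T, snapshots

# e.g., step = lambda phi: crank_nicolson_step(F_nn, phi, 1e-2)
```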

To further demonstrate the scalability and versatility of the hybrid adjoint-based gradient propagation strategy, we investigated a multi-physics reaction–diffusion problem arising in Chemical Vapor Infiltration (CVI) modeling. Specifically, we applied our approach to the PiNDiff-CVI framework developed in [14], which simulates porous infiltration by solving two tightly coupled PDEs: an elliptic Poisson equation governing the steady-state molarity distribution and a hyperbolic transport equation modeling time-evolving deposition dynamics. The elliptic component introduces an additional layer of complexity, as it necessitates an inner numerical solve at each time step, leading to a nested bilevel optimization structure (more details can be found in Appendix E). In the original implementation of PiNDiff-CVI [14], the Poisson solver was handled using a fixed number of unrolled iterations, incurring considerable memory overhead during training due to the explicit differentiation of each inner step. In contrast, we replaced this naive backpropagation scheme with our proposed adjoint-based differentiation strategy, which computes gradients through the elliptic solver using the implicit function theorem, thus avoiding the need to store intermediate iterates. For the hyperbolic time integration component, we employed checkpointing to reduce memory requirements by recomputing selected intermediate states during the backward pass. This combination of adjoint differentiation and checkpointing significantly improves both efficiency and scalability. The comparative performance of the original Ex-PiNDiff, Im-PiNDiff, and Im-PiNDiff with checkpointing (Im-Cp) is summarized in Fig. 13.

Figure 13: Comparison of computational cost for Ex-PiNDiff, Im-PiNDiff, and Im-PiNDiff-Cp for CVI modeling. (Ex: Ex-PiNDiff, Im: Im-PiNDiff, Im-Cp: Im-PiNDiff w/ checkpointing.)

The hybrid Im-PiNDiff model achieved nearly a 2× reduction in both memory consumption and wall-clock training time relative to the explicit baseline, without sacrificing predictive accuracy. Moreover, the integration of checkpointing further reduced memory usage to approximately 696 MB, an order of magnitude lower than the Ex-PiNDiff baseline, while incurring only a modest increase in computational time. These results highlight the practicality of hybrid gradient propagation strategies for large-scale, stiff, or tightly coupled PDE systems, where traditional AD approaches are often prohibitive.

4 Conclusion

This work presents a hybrid neural-physics modeling framework, termed Im-PiNDiff, for complex spatiotemporal dynamics. By introducing implicit neural differential layers into the PiNDiff architecture, we address the challenge of numerical instability and error accumulation inherent in explicit recurrent formulations. The use of implicit time stepping significantly improves the temporal stability and long-horizon predictive accuracy.

A key innovation of the proposed framework lies in its hybrid gradient propagation strategy, which integrates adjoint-based implicit differentiation with reverse-mode AD. This approach decouples gradient computation from the number of solver iterations, enabling memory-efficient training without compromising accuracy. Moreover, we incorporate checkpointing schemes to further reduce the peak memory footprint, making the framework viable for large-scale, long-time simulations on memory-constrained hardware. Together, these algorithmic advances allow Im-PiNDiff to scale to previously intractable problem regimes, achieving superior efficiency and stability over traditional AD-based implementations.

For latent physics inference, we leverage CNFs to parameterize spatially and temporally varying physical quantities, which enables the model to recover unobserved fields or operators from sparse, indirect measurements, expanding the applicability of PiNDiff models to scenarios where direct supervision is unavailable or limited. Extensive numerical experiments, including linear advection–diffusion, nonlinear Burgers' dynamics, and multiphysics CVI processes, demonstrate the proposed framework's effectiveness in both forward and inverse modeling tasks under challenging, data-sparse conditions.

Overall, Im-PiNDiff represents a significant step toward enabling stable, efficient, and generalizable hybrid modeling for real-world scientific systems. The combination of implicit architectures, adjoint-based training, and neural field parameterizations offers a flexible and robust paradigm for next-generation scientific machine learning. Future extensions will explore adaptive solvers, stochastic PDEs, and coupling with experimental data streams to further broaden the utility of this framework in data-constrained and multiscale physical modeling settings.

Declaration of competing interests

The authors declare no competing interests.

Data availability

All data needed to evaluate the conclusions in the paper are either present in the paper or can be regenerated by the code provided.

Acknowledgment

The authors would like to acknowledge the funds from the Air Force Office of Scientific Research (AFOSR), United States of America under award number FA9550-22-1-0065. JXW would also like to acknowledge the funding support from the Office of Naval Research under award number N00014-23-1-2071 and the National Science Foundation under award number OAC-2047127 in supporting this study.

References

  • [1] M. Raissi, P. Perdikaris, G. Karniadakis, Physics-informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations, Journal of Computational Physics 378 (2019) 686–707.
  • [2] L. Sun, H. Gao, S. Pan, J.-X. Wang, Surrogate modeling for fluid flows based on physics-constrained deep learning without simulation data, Computer Methods in Applied Mechanics and Engineering 361 (2020) 112732.
  • [3] L. Sun, J.-X. Wang, Physics-constrained Bayesian neural network for fluid flow reconstruction with sparse and noisy data, Theoretical and Applied Mechanics Letters 10 (3) (2020) 161–169.
  • [4] Z. Li, N. B. Kovachki, K. Azizzadenesheli, K. Bhattacharya, A. Stuart, A. Anandkumar, et al., Fourier neural operator for parametric partial differential equations, in: International Conference on Learning Representations, 2021.
  • [5] L. Lu, P. Jin, G. Pang, Z. Zhang, G. E. Karniadakis, Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators, Nature Machine Intelligence 3 (3) (2021) 218–229.
  • [6] S. Wang, H. Wang, P. Perdikaris, Learning the solution operator of parametric partial differential equations with physics-informed DeepONets, Science Advances 7 (40) (2021) eabi8605.
  • [7] S. L. Brunton, J. L. Proctor, J. N. Kutz, Discovering governing equations from data by sparse identification of nonlinear dynamical systems, Proceedings of the National Academy of Sciences 113 (15) (2016) 3932–3937.
  • [8] Z. Chen, Y. Liu, H. Sun, Physics-informed learning of governing equations from scarce data, Nature Communications 12 (1) (2021) 1–13.
  • [9] L. Sun, D. Z. Huang, H. Sun, J.-X. Wang, Bayesian spline learning for equation discovery of nonlinear dynamics with quantified uncertainty, in: NeurIPS, PMLR, 2022.
  • [10] D. Kochkov, J. A. Smith, A. Alieva, Q. Wang, M. P. Brenner, S. Hoyer, Machine learning–accelerated computational fluid dynamics, Proceedings of the National Academy of Sciences 118 (21) (2021) e2101784118.
  • [11] X.-Y. Liu, M. Zhu, L. Lu, H. Sun, J.-X. Wang, Multi-resolution partial differential equations preserved learning framework for spatiotemporal dynamics, Communications Physics 7 (1) (2024) 31.
  • [12] X. Fan, J.-X. Wang, Differentiable hybrid neural modeling for fluid-structure interaction, Journal of Computational Physics 496 (2024) 112584.
  • [13] D. Akhare, T. Luo, J.-X. Wang, Physics-integrated neural differentiable (PiNDiff) model for composites manufacturing, Computer Methods in Applied Mechanics and Engineering 406 (2023) 115902.
  • [14] D. Akhare, Z. Chen, R. Gulotty, T. Luo, J.-X. Wang, Probabilistic physics-integrated neural differentiable modeling for isothermal chemical vapor infiltration process, npj Computational Materials 10 (1) (2024) 120.
  • [15] D. Akhare, T. Luo, J.-X. Wang, DiffHybrid-UQ: Uncertainty quantification for differentiable hybrid neural modeling, arXiv preprint arXiv:2401.00161 (2023).
  • [16] X. Fan, D. Akhare, J.-X. Wang, Neural differentiable modeling with diffusion-based super-resolution for two-dimensional spatiotemporal turbulence, arXiv preprint arXiv:2406.20047 (2024).
  • [17] J. Tompson, K. Schlachter, P. Sprechmann, K. Perlin, Accelerating Eulerian fluid simulation with convolutional networks, in: International Conference on Machine Learning, PMLR, 2017, pp. 3424–3433.
  • [18] R. Vinuesa, S. L. Brunton, Enhancing computational fluid dynamics with machine learning, Nature Computational Science 2 (6) (2022) 358–366.
  • [19] N. Margenberg, R. Jendersie, C. Lessig, T. Richter, DNN-MG: A hybrid neural network/finite element method with applications to 3D simulations of the Navier–Stokes equations, Computer Methods in Applied Mechanics and Engineering 420 (2024) 116692.
  • [20] J. Wang, J. Wu, H. Xiao, A physics-informed machine learning approach of improving RANS predicted Reynolds stresses, in: 55th AIAA Aerospace Sciences Meeting, 2017, p. 1712.
  • [21] K. Duraisamy, G. Iaccarino, H. Xiao, Turbulence modeling in the age of data, Annual Review of Fluid Mechanics 51 (2019) 357–377.
  • [22] J.-X. Wang, J. Huang, L. Duan, H. Xiao, Prediction of Reynolds stresses in high-Mach-number turbulent boundary layers using physics-informed machine learning, Theoretical and Computational Fluid Dynamics 33 (1) (2019) 1–19.
  • [23] L. Zanna, T. Bolton, Data-driven equation discovery of ocean mesoscale closures, Geophysical Research Letters 47 (17) (2020) e2020GL088376.
  • [24] M. A. Mendez, A. Ianiro, B. R. Noack, S. L. Brunton, Data-Driven Fluid Mechanics: Combining First Principles and Machine Learning, Cambridge University Press, 2023.
  • [25] R. Newbury, J. Collins, K. He, J. Pan, I. Posner, D. Howard, A. Cosgun, A review of differentiable simulators, IEEE Access (2024).
  • [26] B. List, L.-W. Chen, K. Bali, N. Thuerey, Differentiability in unrolled training of neural physics simulators on transient dynamics, Computer Methods in Applied Mechanics and Engineering 433 (2025) 117441.
  • [27] A. M. Schweidtmann, D. Zhang, M. von Stosch, A review and perspective on hybrid modelling methodologies, Digital Chemical Engineering (2023) 100136.
  • [28] M. Innes, A. Edelman, K. Fischer, C. Rackauckas, E. Saba, V. B. Shah, W. Tebbutt, A differentiable programming system to bridge machine learning and scientific computing, arXiv preprint arXiv:1907.07587 (2019).
  • [29] F. D. A. Belbute-Peres, T. Economon, Z. Kolter, Combining differentiable PDE solvers and graph neural networks for fluid flow prediction, in: International Conference on Machine Learning, PMLR, 2020, pp. 2402–2411.
  • [30] D. Z. Huang, K. Xu, C. Farhat, E. Darve, Learning constitutive relations from indirect observations using deep neural networks, Journal of Computational Physics 416 (2020) 109491.
  • [31] B. List, L.-W. Chen, N. Thuerey, Learned turbulence modelling with differentiable fluid solvers: Physics-based loss functions and optimisation horizons, Journal of Fluid Mechanics 949 (2022) A25.
  • [32] X. Fan, J.-X. Wang, Differentiable hybrid neural modeling for fluid-structure interaction, Journal of Computational Physics 496 (2024) 112584.
  • [33] C. C. Margossian, A review of automatic differentiation and its efficient implementation, Wiley Interdisciplinary Reviews: Data Mining and Knowledge Discovery 9 (4) (2019) e1305.
  • [34] A. G. Baydin, B. A. Pearlmutter, A. A. Radul, J. M. Siskind, Automatic differentiation in machine learning: A survey, Journal of Machine Learning Research 18 (153) (2018) 1–43.
  • [35] M. Blondel, V. Roulet, The elements of differentiable programming, arXiv preprint arXiv:2403.14606 (2024).
  • [36] S. Bai, J. Z. Kolter, V. Koltun, Deep equilibrium models, Advances in Neural Information Processing Systems 32 (2019).
  • [37] J. Pan, J. H. Liew, V. Y. Tan, J. Feng, H. Yan, AdjointDPM: Adjoint sensitivity method for gradient backpropagation of diffusion probabilistic models, arXiv preprint arXiv:2307.10711 (2023).
  • [38] T. Matsubara, Y. Miyatake, T. Yaguchi, The symplectic adjoint method: Memory-efficient backpropagation of neural-network-based differential equations, IEEE Transactions on Neural Networks and Learning Systems (2023).
  • [39] K. Fidkowski, Adjoint-based adaptive training of deep neural networks, in: AIAA Aviation 2021 Forum, 2021, p. 2904.
  • [40] R. T. Chen, Y. Rubanova, J. Bettencourt, D. K. Duvenaud, Neural ordinary differential equations, Advances in Neural Information Processing Systems 31 (2018).
  • [41] Y. Xie, T. Takikawa, S. Saito, O. Litany, S. Yan, N. Khan, F. Tombari, J. Tompkin, V. Sitzmann, S. Sridhar, Neural fields in visual computing and beyond, in: Computer Graphics Forum, Vol. 41, Wiley Online Library, 2022, pp. 641–676.
  • [42] V. Sitzmann, J. Martel, A. Bergman, D. Lindell, G. Wetzstein, Implicit neural representations with periodic activation functions, Advances in Neural Information Processing Systems 33 (2020) 7462–7473.
  • [43] P. Du, M. H. Parikh, X. Fan, X.-Y. Liu, J.-X. Wang, Conditional neural field latent diffusion model for generating spatiotemporal turbulence, Nature Communications 15 (1) (2024) 10416.

Appendix A Backpropagation for Autoregressive model

Figure 14: Forward and backpropagation for the autoregressive model

Consider a one-step function $\boldsymbol{\Phi}_{t+1}=f_{\boldsymbol{\theta}}(\boldsymbol{\Phi}_{t})$ with $\boldsymbol{\Phi}_{t}\in\mathbb{R}^{n}$, called at every time step to generate the temporal dynamics. The Ex-PiNDiff model's forward computation to generate temporal dynamics up to time $T$ can be expressed as:

$\boldsymbol{\Phi}_{T}=f_{\boldsymbol{\theta}}(\boldsymbol{\Phi}_{T-1})=f_{\boldsymbol{\theta}}(f_{\boldsymbol{\theta}}(\boldsymbol{\Phi}_{T-2}))=\cdots=f_{\boldsymbol{\theta}}(\cdots f_{\boldsymbol{\theta}}(f_{\boldsymbol{\theta}}(\boldsymbol{\Phi}_{0})))$  (21)

or

$\boldsymbol{\Phi}_{T}=f^{[T]}_{\boldsymbol{\theta}}\cdot f^{[T-1]}_{\boldsymbol{\theta}}\cdot f^{[T-2]}_{\boldsymbol{\theta}}\cdots f^{[1]}_{\boldsymbol{\theta}}(\boldsymbol{\Phi}_{0}).$  (22)

Let’s represent the forward trajectory of states as

$\boldsymbol{\Phi}_{0}\rightarrow\boldsymbol{\Phi}_{1}\cdots\boldsymbol{\Phi}_{t}\rightarrow\boldsymbol{\Phi}_{t+1}\cdots\boldsymbol{\Phi}_{T-1}\rightarrow\boldsymbol{\Phi}_{T}\rightarrow L.$  (23)

The total loss $L(\boldsymbol{\theta})$ is defined over the entire rollout trajectory of $T$ steps as

$L(\boldsymbol{\theta})=\sum_{t=0}^{T}\mathcal{L}_{t}(\boldsymbol{\Phi}_{t};\mathcal{D}_{t})+\mathcal{L}_{\text{regulate}}(\boldsymbol{\theta}),$  (24)

and the gradient of the loss with respect to $\boldsymbol{\theta}$ can be computed using the chain rule,

$\frac{dL}{d\boldsymbol{\theta}}=\frac{\partial L}{\partial\boldsymbol{\theta}}+\sum_{t=0}^{T}\frac{\partial L}{\partial\boldsymbol{\theta}}\bigg|_{\boldsymbol{\Phi}_{t}},$  (25)

where

$\frac{\partial L}{\partial\boldsymbol{\theta}}\bigg|_{\boldsymbol{\Phi}_{t}}=\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}}\cdot\Big[\frac{\partial\boldsymbol{\Phi}_{t+1}}{\partial\boldsymbol{\theta}}\Big]=\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}}\cdot\Big[\frac{\partial f}{\partial\boldsymbol{\theta}}\Big],$  (26a)
$\frac{\partial L}{\partial\boldsymbol{\Phi}_{t}}=\frac{\partial\mathcal{L}_{t}}{\partial\boldsymbol{\Phi}_{t}}+\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}}\cdot\Big[\frac{\partial\boldsymbol{\Phi}_{t+1}}{\partial\boldsymbol{\Phi}_{t}}\Big]=\frac{\partial\mathcal{L}_{t}}{\partial\boldsymbol{\Phi}_{t}}+\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}}\cdot\Big[\frac{\partial f}{\partial\boldsymbol{\Phi}_{t}}\Big].$  (26b)

Here, $\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}}\big[\frac{\partial f}{\partial\boldsymbol{\theta}}\big]$ and $\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}}\big[\frac{\partial f}{\partial\boldsymbol{\Phi}_{t}}\big]$ are vector–Jacobian products (VJPs), and $[\cdot]$ represents a Jacobian matrix.
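The recursion in Eqs. (25)-(26) corresponds to the reverse sweep sketched below for a generic one-step function $f$ and per-step loss; the names and stored-trajectory interface are assumptions, and this is what reverse-mode AD computes implicitly when the rollout is unrolled.

```python
# Sketch of reverse accumulation for Eqs. (25)-(26): walk the stored
# trajectory backwards, carrying the adjoint lam = dL/dphi_t and
# accumulating parameter gradients via VJPs. Interface names are assumed,
# and the explicit regularization term of Eq. (24) is omitted for brevity.
import jax
import jax.numpy as jnp

def backprop_through_time(f, loss_t, traj, theta):
    # traj = [phi_0, ..., phi_T] stored during the forward pass.
    T = len(traj) - 1
    lam = jax.grad(loss_t)(traj[T])                  # dL/dphi_T
    g_theta = jax.tree_util.tree_map(jnp.zeros_like, theta)
    for t in reversed(range(T)):
        _, vjp = jax.vjp(f, traj[t], theta)          # VJPs of Eq. (26)
        d_phi, d_theta = vjp(lam)
        g_theta = jax.tree_util.tree_map(jnp.add, g_theta, d_theta)
        lam = jax.grad(loss_t)(traj[t]) + d_phi      # adjoint update, Eq. (26b)
    return g_theta                                   # dL/dtheta, Eq. (25)
```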

Appendix B Gradient backpropagation for implicit-PiNDiff with naive AD

In the case of Im-PiNDiff, we employ a $\text{RootFind}(\cdot)$ method to obtain the state at each time step by solving $f_{\boldsymbol{\theta}}^{im}\big(\boldsymbol{\Phi}_{t};\boldsymbol{\Phi}_{t-1}^{*}\big)=0$, which can be expressed as

$\boldsymbol{\Phi}_{t}^{*}\leftarrow\text{RootFind}\Big(f_{\boldsymbol{\theta}}^{im}\big(\boldsymbol{\Phi}_{t};\boldsymbol{\Phi}_{t-1}^{*}\big)=0\Big),\quad t=1,\cdots,T.$  (27)

or

$\boldsymbol{\Phi}_{T}^{*}\leftarrow\text{RootFind}(f_{\boldsymbol{\theta}}^{im})\cdot\text{RootFind}(f_{\boldsymbol{\theta}}^{im})\cdots\text{RootFind}\Big(f_{\boldsymbol{\theta}}^{im}\big(\boldsymbol{\Phi}_{1};\boldsymbol{\Phi}_{0}\big)=0\Big).$  (28)

The iterative solver $\text{RootFind}(\cdot)$ generates a sequence of iterates

\{\boldsymbol{\Phi}_{t}^{[0]},\boldsymbol{\Phi}_{t}^{[1]},\cdots,\boldsymbol{\Phi}_{t}^{[K]}\}=\Big\{\text{RootFind}^{[k]}\Big(f_{\boldsymbol{\theta}}^{im}\big(\boldsymbol{\Phi}_{t}^{[0]};\boldsymbol{\Phi}_{t-1}\big)=0\Big)\Big\}_{k=0}^{K} \qquad (29)

that converges to the final solution satisfying $f_{\boldsymbol{\theta}}^{im}\big(\boldsymbol{\Phi}_{t}^{[K]};\boldsymbol{\Phi}_{t-1}\big)=0$. For naive AD, we consider a fixed number of iterations $K$ for the $\text{RootFind}(\cdot)$ algorithm. The forward propagation of the state then takes the form

\boldsymbol{\Phi}_{0}\rightarrow\big\{\boldsymbol{\Phi}_{1}^{[0]}\cdots\boldsymbol{\Phi}_{1}^{[K]}\big\}\rightarrow\boldsymbol{\Phi}_{1}^{*}\cdots\big\{\boldsymbol{\Phi}_{t}^{[0]}\cdots\boldsymbol{\Phi}_{t}^{[K]}\big\}\rightarrow\boldsymbol{\Phi}_{t}^{*}\cdots\big\{\boldsymbol{\Phi}_{T}^{[0]}\cdots\boldsymbol{\Phi}_{T}^{[K]}\big\}\rightarrow\boldsymbol{\Phi}_{T}^{*} \qquad (30)

and the total loss $L(\boldsymbol{\theta})$, defined over the entire rollout trajectory of $T$ steps, is

L(\boldsymbol{\theta})=\sum_{t=0}^{T}\mathcal{L}_{t}(\boldsymbol{\Phi}_{t}^{*};\mathcal{D}_{t})+\mathcal{L}_{\text{regulate}}(\boldsymbol{\theta}). \qquad (31)

The gradient of the loss with respect to $\boldsymbol{\theta}$ for the naive Im-PiNDiff is

\frac{dL}{d\boldsymbol{\theta}}=\frac{\partial L}{\partial\boldsymbol{\theta}}+\sum_{t=1}^{T}\frac{\partial L}{\partial\boldsymbol{\theta}}\bigg|_{\boldsymbol{\Phi}_{t}^{*}}=\frac{\partial L}{\partial\boldsymbol{\theta}}+\sum_{t=1}^{T}\sum_{k=0}^{K}\frac{\partial L}{\partial\boldsymbol{\theta}}\bigg|_{\boldsymbol{\Phi}_{t}^{[k]}}, \qquad (32)

where AD stores every intermediate iterate $\big\{\boldsymbol{\Phi}_{t}^{[0]}\cdots\boldsymbol{\Phi}_{t}^{[K]}\big\}$, in addition to $\boldsymbol{\Phi}_{t}^{*}$, at every time step, resulting in substantial memory overhead.
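To illustrate the memory issue, here is a hedged Python/JAX sketch of the naive approach: the root finder is unrolled for a fixed $K$ iterations so that reverse-mode AD can trace it, forcing every iterate onto the AD tape. The fixed-point map g and all constants are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

K = 50  # fixed number of fixed-point iterations, as in Eq. (29)

def g(phi, phi_prev, theta):
    # Placeholder fixed-point map: a root of f^im corresponds to phi = g(phi).
    return phi_prev + 0.1 * jnp.tanh(theta * phi)

def rootfind_unrolled(phi_prev, theta):
    # Naive AD: unroll all K iterations so reverse mode can trace them;
    # every iterate Phi^[k] is stored for the backward pass.
    phi = phi_prev
    for _ in range(K):
        phi = g(phi, phi_prev, theta)
    return phi

def rollout_loss(theta, phi0, data):
    phi, loss = phi0, 0.0
    for d in data:  # T implicit steps, as in Eq. (30)
        phi = rootfind_unrolled(phi, theta)
        loss = loss + jnp.sum((phi - d) ** 2)
    return loss

# Memory scales with T * K: the tape keeps all intermediate iterates.
grad_theta = jax.grad(rollout_loss)(0.5, jnp.zeros(64), [jnp.ones(64)] * 10)
```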

Appendix C Derivation of adjoint-based VJP for implicit layers

To save memory, we need to compute $\frac{\partial L}{\partial\boldsymbol{\theta}}\big|_{\boldsymbol{\Phi}_{t}^{*}}$ without storing the intermediate iterates $\big\{\boldsymbol{\Phi}_{t}^{[0]}\cdots\boldsymbol{\Phi}_{t}^{[K]}\big\}$. Recall that $\frac{\partial L}{\partial\boldsymbol{\theta}}\big|_{\boldsymbol{\Phi}_{t}^{*}}=\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\Big[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\theta}}\Big]$ and $\frac{\partial L}{\partial\boldsymbol{\Phi}_{t}^{*}}=\frac{\partial\mathcal{L}_{t}}{\partial\boldsymbol{\Phi}_{t}^{*}}+\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\Big[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\Phi}_{t}^{*}}\Big]$.
We therefore need to compute the VJPs $\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\Big[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\theta}}\Big]$ and $\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\Big[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\Phi}_{t}^{*}}\Big]$ for the implicit layer. The adjoint-state method computes these VJPs as the solution of a linear system, which is derived using the implicit function theorem.

Implicit function theorem [35]:

For a function $f(\boldsymbol{\omega},\boldsymbol{\lambda}):W\times\Lambda\rightarrow W$ that is continuously differentiable in a neighborhood of $(\boldsymbol{\omega}_{0},\boldsymbol{\lambda}_{0})$ with $f(\boldsymbol{\omega}_{0},\boldsymbol{\lambda}_{0})=\mathbf{0}$ and $\frac{\partial}{\partial\boldsymbol{\omega}}f(\boldsymbol{\omega},\boldsymbol{\lambda})$ invertible, there exists a neighborhood of $\boldsymbol{\lambda}_{0}$ on which there is a function $\boldsymbol{\omega}^{*}(\boldsymbol{\lambda})$ such that

  • $\boldsymbol{\omega}^{*}(\boldsymbol{\lambda}_{0})=\boldsymbol{\omega}_{0}$,

  • $f(\boldsymbol{\omega}^{*}(\boldsymbol{\lambda}),\boldsymbol{\lambda})=0$ for all $\boldsymbol{\lambda}$ in the neighborhood,

  • $\frac{\partial}{\partial\boldsymbol{\lambda}}\boldsymbol{\omega}^{*}(\boldsymbol{\lambda})=-\Big[\frac{\partial}{\partial\boldsymbol{\omega}}f(\boldsymbol{\omega}^{*}(\boldsymbol{\lambda}),\boldsymbol{\lambda})\Big]^{-1}\frac{\partial}{\partial\boldsymbol{\lambda}}f(\boldsymbol{\omega}^{*}(\boldsymbol{\lambda}),\boldsymbol{\lambda})$.
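As a quick scalar sanity check of the third property (our own illustrative example, not part of the original derivation), take $f(\omega,\lambda)=\omega^{2}-\lambda$ with root $\omega^{*}(\lambda)=\sqrt{\lambda}$:

\frac{\partial \omega^{*}}{\partial \lambda}=-\left[\frac{\partial f}{\partial \omega}\right]^{-1}\frac{\partial f}{\partial \lambda}=-(2\omega^{*})^{-1}(-1)=\frac{1}{2\sqrt{\lambda}},

which matches differentiating the explicit solution directly.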

Consider $f^{im}\big(\boldsymbol{\Phi}_{t+1};\boldsymbol{\Phi}_{t}^{*},\boldsymbol{\theta}\big)$ with $(\boldsymbol{\Phi}_{t}^{*},\boldsymbol{\theta})$ playing the role of $\boldsymbol{\lambda}$; the implicit function theorem then gives

\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\theta}}=-\left[\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\right]^{-1}\cdot\frac{\partial f^{im}}{\partial\boldsymbol{\theta}}, \qquad (33a)
\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\Phi}_{t}^{*}}=-\left[\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\right]^{-1}\cdot\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t}^{*}}. \qquad (33b)

Multiplying both equations on the left by the row vector $\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}$, we get

\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\Bigg[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\theta}}\Bigg]=-\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\left[\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\right]^{-1}\cdot\frac{\partial f^{im}}{\partial\boldsymbol{\theta}}, \qquad (34a)
\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\Bigg[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\Phi}_{t}^{*}}\Bigg]=-\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\left[\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\right]^{-1}\cdot\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t}^{*}}. \qquad (34b)

Now we define the adjoint vector $\mathbf{w}_{t+1}^{T}\in\mathbb{R}^{1\times n}$ as

\mathbf{w}_{t+1}^{T}=-\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\left[\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\right]^{-1}, \qquad (35)

which is obtained by solving the linear system,

\mathbf{w}_{t+1}^{T}\bigg[\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\bigg]=-\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}. \qquad (36)

This linear system can be solved efficiently using iterative numerical linear solvers such as GMRES or the conjugate gradient method. Finally, the VJPs are expressed as

\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\Bigg[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\theta}}\Bigg]=\mathbf{w}_{t+1}^{T}\cdot\frac{\partial f^{im}}{\partial\boldsymbol{\theta}}, \qquad (37a)
\frac{\partial L}{\partial\boldsymbol{\Phi}_{t+1}^{*}}\cdot\Bigg[\frac{\partial\boldsymbol{\Phi}_{t+1}^{*}}{\partial\boldsymbol{\Phi}_{t}^{*}}\Bigg]=\mathbf{w}_{t+1}^{T}\cdot\frac{\partial f^{im}}{\partial\boldsymbol{\Phi}_{t}^{*}}. \qquad (37b)

With Eqs. (36)–(37), we can compute $\frac{\partial L}{\partial\boldsymbol{\theta}}\big|_{\boldsymbol{\Phi}_{t}^{*}}$ for every implicit layer without storing the intermediate iterates, thereby reducing the memory requirement of AD backpropagation.
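A hedged Python/JAX sketch of this adjoint-based implicit layer is given below: the forward pass runs a black-box root finder and saves only the converged solution, while the backward pass solves Eq. (36) matrix-free with GMRES before forming the VJPs of Eqs. (37a)–(37b). The residual f_im and its fixed-point iteration are illustrative assumptions, not the paper's actual solver.

```python
import jax
import jax.numpy as jnp

# Illustrative implicit residual: f_im(phi; phi_prev, theta) = 0.
def f_im(phi, phi_prev, theta):
    return phi - phi_prev - 0.1 * jnp.tanh(theta * phi)

@jax.custom_vjp
def implicit_step(phi_prev, theta):
    # Forward: any black-box root finder works; a simple fixed-point
    # iteration is used here. AD does not trace these iterations.
    phi = phi_prev
    for _ in range(100):
        phi = phi_prev + 0.1 * jnp.tanh(theta * phi)
    return phi

def implicit_step_fwd(phi_prev, theta):
    phi = implicit_step(phi_prev, theta)
    return phi, (phi, phi_prev, theta)  # only the solution is saved

def implicit_step_bwd(res, g):
    phi, phi_prev, theta = res
    # Solve w^T [df_im/dphi] = -dL/dphi (Eq. (36)) matrix-free with GMRES;
    # vjp_phi(v) evaluates v^T (df_im/dphi) without forming the Jacobian.
    _, vjp_phi = jax.vjp(lambda p: f_im(p, phi_prev, theta), phi)
    w, _ = jax.scipy.sparse.linalg.gmres(lambda v: vjp_phi(v)[0], -g)
    # VJPs w.r.t. phi_prev and theta, Eqs. (37a)-(37b).
    _, vjp_rest = jax.vjp(lambda pp, th: f_im(phi, pp, th), phi_prev, theta)
    return vjp_rest(w)

implicit_step.defvjp(implicit_step_fwd, implicit_step_bwd)
```

With this custom VJP, the backward memory cost per step is independent of the number of root-finding iterations, which is the point of the adjoint treatment.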

Appendix D Spatiotemporal fields using a combination of sinusoidal functions

To make the inference task more challenging, a combination of sinusoidal functions is superimposed to generate more complex spatiotemporal fields for the advection velocity and viscosity.

f(\mathbf{x},t)=\sum_{i=1}^{n_{w}}A_{i}\sin\big(2\pi(k_{i}\mathbf{x}+\omega_{i}t)+p_{i}\big) \qquad (38)

where the parameters are randomly sampled as $A_{i}\sim\mathcal{U}(-2,2)$, $k_{i}\sim\mathcal{U}(0,2)$, $\omega_{i}\sim\mathcal{U}(-1,1)$, and $p_{i}\sim\mathcal{U}(0,2)$.
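For reference, a short Python sketch of this field generator follows; the grid, the number of modes, and the seed are illustrative choices, with the sampling ranges taken from Eq. (38) and the distributions above.

```python
import numpy as np

def random_sinusoidal_field(x, t, n_w=5, seed=0):
    # Superpose n_w randomly parameterized sinusoids, Eq. (38).
    rng = np.random.default_rng(seed)
    A = rng.uniform(-2.0, 2.0, n_w)
    k = rng.uniform(0.0, 2.0, n_w)
    w = rng.uniform(-1.0, 1.0, n_w)
    p = rng.uniform(0.0, 2.0, n_w)
    X, T = np.meshgrid(x, t, indexing="ij")
    f = np.zeros_like(X)
    for i in range(n_w):
        f += A[i] * np.sin(2 * np.pi * (k[i] * X + w[i] * T) + p[i])
    return f  # shape (len(x), len(t))
```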

Appendix E PiNDiff modeling for chemical vapor infiltration

Here we describe in more detail the chemical vapor infiltration (CVI) reaction–diffusion system used in this study. CVI is a materials processing technique used to fabricate composite materials. In this method, a porous preform, usually made of fibers such as carbon, is exposed to a reactive gas mixture at elevated temperatures inside a furnace. The gaseous precursors diffuse into the pores of the preform and undergo chemical reactions, typically decomposition or reduction, to deposit a solid material onto the internal surfaces of the preform. Over time, this gradual deposition fills the pores and forms a dense matrix around the fibers without significantly disturbing the structure. CVI allows precise control over material composition and microstructure, making it suitable for producing high-temperature, corrosion-resistant components used in aerospace, energy, and defense applications. A PiNDiff-CVI model based on foundational physics was developed to simulate the CVI process [14]; its governing equations are

D_{\text{eff}}\nabla^{2}C=KS_{\text{v}}C, \qquad (39a)
\rho_{\text{s}}\frac{d\varepsilon}{dt}=-qM_{\text{s}}KS_{\text{v}}C. \qquad (39b)

In the above equations, $C=C(\mathbf{x},t)$ denotes the effective molarity field (mol m$^{-3}$) of all reactive gases, $\varepsilon=\varepsilon(\mathbf{x},t)$ is the porosity of the preform, $q$ represents a constant stoichiometric coefficient, $M_{\text{s}}$ is the molar mass (kg mol$^{-1}$), and $\rho_{\text{s}}$ is the density (kg m$^{-3}$) of the deposited solid (carbon or SiC). $D_{\text{eff}}=D_{\text{eff}}(\mathbf{x},t)$ represents the effective diffusion coefficient field, $K=K(\mathbf{x},t)$ is the deposition reaction rate, and $S_{\text{v}}=S_{\text{v}}(\mathbf{x},t)$ corresponds to the surface-to-volume ratio.

Figure 15: Foundational physics of the CVI process with neural operator approximations

The details of this solver are described by Akhare et al. [14]. In the PiNDiff-CVI model, the underlying transport and reaction functions are unknown and are modeled as the operators $D_{\text{eff}}$, $K$, and $S_{\text{v}}$ using DNNs, as shown in Fig. 15. Equation (39a) is an elliptic Poisson-type equation governing the steady-state molarity distribution, solved iteratively at each time step using the point-Jacobi method (inner optimization), while Equation (39b) is a hyperbolic transport equation modeling the time-evolving deposition dynamics, integrated explicitly with the Euler scheme. Previously, a naive approach used a fixed number of iterations to solve Equation (39a), which was necessary to construct the static computational graph required for gradient calculation and resulted in large memory usage. By leveraging adjoint-based backpropagation, we eliminate the need to save the intermediate iterates, thereby reducing memory usage.
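To make the inner/outer structure concrete, the following Python/JAX sketch shows one point-Jacobi sweep for a 1D discretization of Eq. (39a) and the explicit Euler update of Eq. (39b); the discretization, the Dirichlet boundary value, and the treatment of $D_{\text{eff}}$, $K$, and $S_{\text{v}}$ as given arrays are our own simplifying assumptions, not the PiNDiff-CVI implementation.

```python
import jax.numpy as jnp

def jacobi_sweep(C, D_eff, K, Sv, dx):
    # One point-Jacobi sweep for a 1D discretization of Eq. (39a):
    #   D_eff (C[i-1] - 2 C[i] + C[i+1]) / dx^2 = K Sv C[i],
    # rearranged for the diagonal unknown C[i].
    num = D_eff * (jnp.roll(C, 1) + jnp.roll(C, -1)) / dx**2
    den = 2.0 * D_eff / dx**2 + K * Sv
    C_new = num / den
    # Dirichlet boundaries: fixed surface molarity (assumed value of 1.0).
    return C_new.at[0].set(1.0).at[-1].set(1.0)

def porosity_update(eps, C, K, Sv, q, M_s, rho_s, dt):
    # Explicit Euler step for the deposition equation, Eq. (39b).
    return eps - dt * q * M_s * K * Sv * C / rho_s
```

In the hybrid solver, the Jacobi sweep plays the role of the inner RootFind iteration, so the adjoint treatment of Appendix C applies directly to it.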