# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Optional, Sequence
import einops
import einops.layers.torch
import torch
import torch.nn.functional as F
from nemo.collections.common.parts.utils import activation_registry, mask_sequence_tensor
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType
from nemo.utils import logging
class SpectrogramNoiseConditionalScoreNetworkPlusPlus(NeuralModule):
    """This model handles complex-valued inputs by stacking real and imaginary components.
    Stacked tensor is processed using NCSN++ and the output is projected to generate real
    and imaginary components of the output channels.

    Args:
        in_channels: number of input complex-valued channels
        out_channels: number of output complex-valued channels
    """

    def __init__(self, *, in_channels: int = 1, out_channels: int = 1, **kwargs):
        super().__init__()

        # Validate channel counts for this estimator
        if in_channels < 1:
            raise ValueError(
                f'Number of input channels needs to be larger or equal to one, current value {in_channels}'
            )
        if out_channels < 1:
            raise ValueError(
                f'Number of output channels needs to be larger or equal to one, current value {out_channels}'
            )

        self.in_channels = in_channels
        self.out_channels = out_channels

        # NCSN++ operates on real-valued tensors, so each complex channel is
        # represented as two real channels (real part, imaginary part).
        ncsnpp_kwargs = dict(kwargs)
        ncsnpp_kwargs['in_channels'] = ncsnpp_kwargs['out_channels'] = 2 * self.in_channels
        self.ncsnpp = NoiseConditionalScoreNetworkPlusPlus(**ncsnpp_kwargs)

        # 1x1 conv mapping the stacked NCSN++ output to real/imag parts of the output channels
        self.output_projection = torch.nn.Conv2d(
            in_channels=2 * self.in_channels, out_channels=2 * self.out_channels, kernel_size=1
        )

        logging.debug('Initialized %s with', self.__class__.__name__)
        logging.debug('\tin_channels: %s', self.in_channels)
        logging.debug('\tout_channels: %s', self.out_channels)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
            "condition": NeuralType(('B',), FloatType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(self, input, input_length=None, condition=None):
        """Run NCSN++ on the real/imag-stacked input and return a complex-valued output.

        Args:
            input: complex-valued tensor, shape (B, C, D, T)
            input_length: optional valid-length per example, shape (B,)
            condition: optional scalar condition, shape (B,)
        """
        B, C_in, D, T = input.shape
        if C_in != self.in_channels:
            raise RuntimeError(f'Unexpected input channel size {C_in}, expected {self.in_channels}')

        # Interleave real/imag per channel: (B, C, D, T) complex -> (B, 2C, D, T) real,
        # channel order (c0_re, c0_im, c1_re, c1_im, ...)
        real_imag = torch.stack([input.real, input.imag], dim=2)
        stacked = real_imag.reshape(B, 2 * C_in, D, T)

        # Run the real-valued score network
        hidden, output_length = self.ncsnpp(input=stacked, input_length=input_length, condition=condition)

        # Project to 2 * out_channels maps: first half holds real parts, second half imaginary
        projected = self.output_projection(hidden)
        real_part, imag_part = projected.chunk(2, dim=1)

        # Recombine into a complex-valued signal, shape (B, out_channels, D, T)
        output = torch.complex(real_part, imag_part)
        return output, output_length
class NoiseConditionalScoreNetworkPlusPlus(NeuralModule):
    """Implementation of Noise Conditional Score Network (NCSN++) architecture.

    References:
        - Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021
        - Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018
    """

    def __init__(
        self,
        nonlinearity: str = "swish",
        in_channels: int = 2,  # number of channels in the input image
        out_channels: int = 2,  # number of channels in the output image
        channels: Sequence[int] = (128, 128, 256, 256, 256),  # number of channels at start + at every resolution
        num_res_blocks: int = 2,
        num_resolutions: int = 4,
        init_scale: float = 1e-5,
        conditioned_on_time: bool = False,
        fourier_embedding_scale: float = 16.0,
        dropout_rate: float = 0.0,
        pad_time_to: Optional[int] = None,
        pad_dimension_to: Optional[int] = None,
        **_,
    ):
        # Network topology is a flavor of UNet, example chart for num_resolutions=4
        #
        # 1: Image → Image/2 → Image/4 → Image/8
        #      ↓        ↓         ↓         ↓
        # 2: Hidden → Hidden/2 → Hidden/4 → Hidden/8
        #      ↓        ↓         ↓         ↓
        # 3: Hidden ← Hidden/2 ← Hidden/4 ← Hidden/8
        #      ↓        ↓         ↓         ↓
        # 4: Image  ← Image/2  ← Image/4  ← Image/8
        #
        # Horizontal arrows in (1) are downsampling
        # Vertical arrows from (1) to (2) are channel upconversions
        #
        # Horizontal arrows in (2) are blocks with downsampling where necessary
        # Horizontal arrows in (3) are blocks with upsampling where necessary
        #
        # Vertical arrows from (1) to (2) are downsampling and channel upconversions
        # Vertical arrows from (2) to (3) are sums connections (also with / sqrt(2))
        # Vertical arrows from (3) to (4) are channel downconversions
        # Horizontal arrows in (4) are upsampling and addition
        super().__init__()

        # same nonlinearity is used throughout the whole network
        self.activation: torch.nn.Module = activation_registry[nonlinearity]()
        self.init_scale: float = init_scale

        self.downsample = torch.nn.Upsample(scale_factor=0.5, mode="bilinear")
        self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.num_res_blocks = num_res_blocks
        self.num_resolutions = num_resolutions
        self.conditioned_on_time = conditioned_on_time

        # padding setup: both spatial dims are padded to a multiple of 2**num_resolutions
        # by default, so that repeated halving/doubling round-trips to the original size
        self.pad_time_to = pad_time_to or 2**self.num_resolutions
        self.pad_dimension_to = pad_dimension_to or 2**self.num_resolutions

        if self.conditioned_on_time:
            # Fourier features of the condition, expanded by an MLP to channels[0] * 4,
            # which is the embedding width the ResNet blocks expect (see block_params below)
            self.time_embedding = torch.nn.Sequential(
                GaussianFourierProjection(embedding_size=self.channels[0], scale=fourier_embedding_scale),
                torch.nn.Linear(self.channels[0] * 2, self.channels[0] * 4),
                self.activation,
                torch.nn.Linear(self.channels[0] * 4, self.channels[0] * 4),
            )

        # 1x1 convs converting the raw input to hidden width, one per resolution
        self.input_pyramid = torch.nn.ModuleList()
        for ch in self.channels[:-1]:
            self.input_pyramid.append(torch.nn.Conv2d(in_channels=self.in_channels, out_channels=ch, kernel_size=1))

        # each block takes an image and outputs an image
        # possibly changes number of channels
        # output blocks ("reverse" path of the unet) reuse outputs of input blocks ("forward" path)
        # so great care must be taken to in/out channels of each block
        # resolutions are handled in `forward`
        block_params = {
            "activation": self.activation,
            "dropout_rate": dropout_rate,
            "init_scale": self.init_scale,
            "diffusion_step_embedding_dim": channels[0] * 4 if self.conditioned_on_time else None,
        }
        self.input_blocks = torch.nn.ModuleList()
        for in_ch, out_ch in zip(self.channels[:-1], self.channels[1:]):
            for n in range(num_res_blocks):
                # only the first block at each resolution changes the channel count
                block = ResnetBlockBigGANPlusPlus(in_ch=in_ch if n == 0 else out_ch, out_ch=out_ch, **block_params)
                self.input_blocks.append(block)

        self.output_blocks = torch.nn.ModuleList()
        for in_ch, out_ch in zip(reversed(self.channels[1:]), reversed(self.channels[:-1])):
            for n in reversed(range(num_res_blocks)):
                # mirror of input_blocks: only the last block at each resolution changes the channel count
                block = ResnetBlockBigGANPlusPlus(in_ch=in_ch, out_ch=out_ch if n == 0 else in_ch, **block_params)
                self.output_blocks.append(block)

        # 1x1 convs converting hidden channels back to output image channels, one per resolution
        self.projection_blocks = torch.nn.ModuleList()
        for ch in self.channels[:-1]:
            self.projection_blocks.append(torch.nn.Conv2d(ch, out_channels, kernel_size=1))

        assert len(self.input_pyramid) == self.num_resolutions
        assert len(self.input_blocks) == self.num_resolutions * self.num_res_blocks
        assert len(self.output_blocks) == self.num_resolutions * self.num_res_blocks
        assert len(self.projection_blocks) == self.num_resolutions

        self.init_weights_()

        logging.debug('Initialized %s with', self.__class__.__name__)
        logging.debug('\tin_channels: %s', self.in_channels)
        logging.debug('\tout_channels: %s', self.out_channels)
        logging.debug('\tchannels: %s', self.channels)
        logging.debug('\tnum_res_blocks: %s', self.num_res_blocks)
        logging.debug('\tnum_resolutions: %s', self.num_resolutions)
        logging.debug('\tconditioned_on_time: %s', self.conditioned_on_time)
        logging.debug('\tpad_time_to: %s', self.pad_time_to)
        logging.debug('\tpad_dimension_to: %s', self.pad_dimension_to)

    def init_weights_(self):
        """Xavier-initialize all Linear/Conv2d weights (zero biases), then apply a
        scaled-gain Xavier init to the output projections and recurse into submodules
        that define their own ``init_weights_``."""
        for module in self.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)

        # torch.nn submodules with scaled init
        for module in self.projection_blocks:
            torch.nn.init.xavier_uniform_(module.weight, gain=self.init_scale)

        # non-torch.nn submodules can have their own init schemes
        for module in self.modules():
            if module is self:
                continue

            if hasattr(module, "init_weights_"):
                module.init_weights_()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
            "condition": NeuralType(('B',), FloatType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), VoidType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self, *, input: torch.Tensor, input_length: Optional[torch.Tensor], condition: Optional[torch.Tensor] = None
    ):
        """Forward pass of the model.

        Args:
            input: input tensor, shape (B, C, D, T)
            input_length: length of the valid time steps for each example in the batch, shape (B,)
            condition: scalar condition (time) for the model, will be embedded using `self.time_embedding`
        """
        assert input.shape[1] == self.in_channels

        # apply padding at the input
        # NOTE(review): `pad_input` is not defined in this chunk — presumably it pads D/T up to
        # multiples of `pad_dimension_to`/`pad_time_to`; confirm against the full file.
        *_, D, T = input.shape
        input = self.pad_input(input=input)

        if input_length is None:
            # assume all time frames are valid
            input_length = torch.LongTensor([input.shape[-1]] * input.shape[0]).to(input.device)

        # `lengths` tracks the valid time extent at the current resolution; it is halved/doubled
        # alongside the feature maps, while `input_length` is preserved for the return value
        lengths = input_length

        if condition is not None:
            if len(condition.shape) != 1:
                raise ValueError(
                    f"Expected conditon to be a 1-dim tensor, got a {len(condition.shape)}-dim tensor of shape {tuple(condition.shape)}"
                )
            if condition.shape[0] != input.shape[0]:
                raise ValueError(
                    f"Condition {tuple(condition.shape)} and input {tuple(input.shape)} should match along the batch dimension"
                )

            # embedding operates on log of the condition; assumes condition > 0
            condition = self.time_embedding(torch.log(condition))

        # downsample and project input image to add later in the downsampling path
        pyramid = [input]
        for resolution_num in range(self.num_resolutions - 1):
            pyramid.append(self.downsample(pyramid[-1]))
        pyramid = [block(image) for image, block in zip(pyramid, self.input_pyramid)]

        # downsampling path
        history = []
        hidden = torch.zeros_like(pyramid[0])
        input_blocks = iter(self.input_blocks)
        for resolution_num, image in enumerate(pyramid):
            # merge the projected input at this resolution into the running hidden state
            hidden = (hidden + image) / math.sqrt(2.0)
            hidden = mask_sequence_tensor(hidden, lengths)

            for _ in range(self.num_res_blocks):
                hidden = next(input_blocks)(hidden, condition)
                hidden = mask_sequence_tensor(hidden, lengths)
                # every res-block output is kept as a skip connection for the upsampling path
                history.append(hidden)

            final_resolution = resolution_num == self.num_resolutions - 1
            if not final_resolution:
                hidden = self.downsample(hidden)
                lengths = (lengths / 2).ceil().long()

        # upsampling path
        to_project = []
        for residual, block in zip(reversed(history), self.output_blocks):
            if hidden.shape != residual.shape:
                # resolution boundary: collect the coarse hidden state for output projection,
                # then upsample before consuming the next skip connection
                to_project.append(hidden)
                hidden = self.upsample(hidden)
                lengths = (lengths * 2).long()

            hidden = (hidden + residual) / math.sqrt(2.0)
            hidden = block(hidden, condition)
            hidden = mask_sequence_tensor(hidden, lengths)

        to_project.append(hidden)

        # projecting to images: one projection per resolution, all resized to the padded input size
        images = []
        for tensor, projection in zip(to_project, reversed(self.projection_blocks)):
            image = projection(tensor)
            images.append(F.interpolate(image, size=input.shape[-2:]))  # TODO write this loop using self.upsample

        result = sum(images)

        assert result.shape[-2:] == input.shape[-2:]

        # remove padding
        result = result[:, :, :D, :T]
        return result, input_length
class GaussianFourierProjection(NeuralModule):
    """Gaussian Fourier embeddings for input scalars.

    The input scalars are typically time or noise levels.
    """

    def __init__(self, embedding_size: int = 256, scale: float = 1.0):
        super().__init__()
        # Random frequencies are drawn once at construction and kept frozen (not trained)
        self.W = torch.nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B',), FloatType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'D'), VoidType()),
        }

    def forward(self, input):
        """Map each scalar to [sin(2*pi*w*x), cos(2*pi*w*x)] features, shape (B, 2 * embedding_size)."""
        # outer product of the scalar batch with the frequency vector, scaled to angular frequency
        angles = torch.outer(input, self.W) * 2 * math.pi
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
[docs]
class ResnetBlockBigGANPlusPlus(torch.nn.Module):
"""Implementation of a ResNet block for the BigGAN model.
References:
- Song et al., Score-Based Generative Modeling through Stochastic Differential Equations, NeurIPS 2021
- Brock et al., Large scale GAN training for high fidelity natural image synthesis, ICLR 2018
"""
def __init__(
self,
activation: torch.nn.Module,
in_ch: int,
out_ch: int,
diffusion_step_embedding_dim: Optional[int] = None,
init_scale: float = 1e-5,
dropout_rate: float = 0.1,
in_num_groups: Optional[int] = None,
out_num_groups: Optional[int] = None,
eps: float = 1e-6,
):
"""
Args:
activation (torch.nn.Module): activation layer (ReLU, SiLU, etc)
in_ch (int): number of channels in the input image
out_ch (int, optional): number of channels in the output image
diffusion_step_embedding_dim (int, optional): dimension of diffusion timestep embedding. Defaults to None (no embedding).
dropout_rate (float, optional): dropout rate. Defaults to 0.1.
init_scale (float, optional): scaling for weight initialization. Defaults to 0.0.
in_num_groups (int, optional): num_groups in the first GroupNorm. Defaults to min(in_ch // 4, 32)
out_num_groups (int, optional): num_groups in the second GroupNorm. Defaults to min(out_ch // 4, 32)
eps (float, optional): eps parameter of GroupNorms. Defaults to 1e-6.
"""
super().__init__()
in_num_groups = in_num_groups or min(in_ch // 4, 32)
out_num_groups = out_num_groups or min(out_ch // 4, 32)
self.init_scale = init_scale
self.input_block = torch.nn.Sequential(
torch.nn.GroupNorm(num_groups=in_num_groups, num_channels=in_ch, eps=eps),
activation,
)
self.middle_conv = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
if diffusion_step_embedding_dim is not None:
self.diffusion_step_projection = torch.nn.Sequential(
activation,
torch.nn.Linear(diffusion_step_embedding_dim, out_ch),
einops.layers.torch.Rearrange("batch dim -> batch dim 1 1"),
)
self.output_block = torch.nn.Sequential(
torch.nn.GroupNorm(num_groups=out_num_groups, num_channels=out_ch, eps=eps),
activation,
torch.nn.Dropout(dropout_rate),
torch.nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1),
)
if in_ch != out_ch:
self.residual_projection = torch.nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=1)
self.act = activation
self.in_ch = in_ch
self.out_ch = out_ch
self.init_weights_()
[docs]
def init_weights_(self):
"""Weight initialization"""
for module in self.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
# a single Conv2d is initialized with gain
torch.nn.init.xavier_uniform_(self.output_block[-1].weight, gain=self.init_scale)
[docs]
def forward(self, x: torch.Tensor, diffusion_time_embedding: Optional[torch.Tensor] = None):
"""Forward pass of the model.
Args:
x: input tensor
diffusion_time_embedding: embedding of the diffusion time step
Returns:
Output tensor
"""
h = self.input_block(x)
h = self.middle_conv(h)
if diffusion_time_embedding is not None:
h = h + self.diffusion_step_projection(diffusion_time_embedding)
h = self.output_block(h)
if x.shape != h.shape: # matching number of channels
x = self.residual_projection(x)
return (x + h) / math.sqrt(2.0)