Source code for nemo.collections.asr.modules.rnnt

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from omegaconf import DictConfig

from nemo.collections.asr.modules import rnnt_abstract
from nemo.collections.asr.parts.submodules import stateless_net
from nemo.collections.asr.parts.utils import adapter_utils, rnnt_utils
from nemo.collections.common.parts import rnn
from nemo.core.classes import adapter_mixins, typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AdapterModuleMixin
from nemo.core.neural_types import (
    AcousticEncodedRepresentation,
    ElementType,
    EmbeddedTextType,
    LabelsType,
    LengthsType,
    LogprobsType,
    LossType,
    NeuralType,
    SpectrogramType,
)
from nemo.utils import logging


class StatelessTransducerDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable):
    """A Stateless Neural Network Transducer Decoder / Prediction Network.

    An RNN-T Decoder/Prediction stateless network that simply takes concatenation of embeddings
    of the history tokens as the output.

    Args:
        prednet: A dict-like object which contains the following key-value pairs.
            pred_hidden: int specifying the hidden dimension of the prediction net.
            dropout: float, set to 0.0 by default. Forwarded to the underlying `StatelessNet`.
        vocab_size: int, specifying the vocabulary size of the embedding layer of the Prediction network,
            excluding the RNNT blank token.
        context_size: int, specifying the size of the history context used for this decoder.
        normalization_mode: Can be either None, 'layer'. By default, is set to None.
            Forwarded to the underlying `StatelessNet`.
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return {
            "targets": NeuralType(('B', 'T'), LabelsType()),
            "target_length": NeuralType(tuple('B'), LengthsType()),
            "states": [NeuralType(('B', 'T'), LabelsType(), optional=True)],
        }

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return {
            "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
            "prednet_lengths": NeuralType(tuple('B'), LengthsType()),
            "states": [NeuralType(('B', 'T'), LabelsType(), optional=True)],
        }

    def input_example(self, max_batch=1, max_dim=1):
        """
        Generates input examples for tracing etc.

        Args:
            max_batch: Batch size of the generated example.
            max_dim: Label length of the generated example.

        Returns:
            A tuple of input examples: (targets, target_length, states).
        """
        length = max_dim
        device = next(self.parameters()).device
        # Targets are filled with the blank index, i.e. the padding token of the embedding.
        targets = torch.full(fill_value=self.blank_idx, size=(max_batch, length), dtype=torch.int32).to(device)
        target_length = torch.randint(0, length, size=(max_batch,), dtype=torch.int32).to(device)
        states = tuple(self.initialize_state(targets.float()))
        return (targets, target_length, states)

    def _prepare_for_export(self, **kwargs):
        """Switch the module to export mode (forward() stops prepending the SOS step)."""
        self._rnnt_export = True
        super()._prepare_for_export(**kwargs)

    def __init__(
        self,
        prednet: Dict[str, Any],
        vocab_size: int,
        context_size: int = 1,
        normalization_mode: Optional[str] = None,
    ):
        # Required arguments
        self.pred_hidden = prednet['pred_hidden']
        self.blank_idx = vocab_size
        self.context_size = context_size

        # Initialize the model (blank token increases vocab size by 1)
        super().__init__(vocab_size=vocab_size, blank_idx=self.blank_idx, blank_as_pad=True)

        # Optional arguments
        dropout = prednet.get('dropout', 0.0)

        self.prediction = self._predict_modules(
            **{
                "context_size": context_size,
                "vocab_size": vocab_size,
                "emb_dim": self.pred_hidden,
                "blank_idx": self.blank_idx,
                "normalization_mode": normalization_mode,
                "dropout": dropout,
            }
        )
        self._rnnt_export = False

    @typecheck()
    def forward(self, targets, target_length, states=None):
        """
        Run the prediction network over a batch of target sequences.

        Args:
            targets: Label tensor of shape (B, U).
            target_length: Lengths tensor of shape (B,); returned unchanged.
            states: Optional decoder state, forwarded to `predict()`.

        Returns:
            A tuple (g, target_length, state) where g has shape (B, D, U') with
            U' = U + 1 unless the module is in export mode.
        """
        # y: (B, U)
        y = rnn.label_collate(targets)

        # state maintenance is unnecessary during training forward call
        # to get state, use .predict() method.
        # In export mode the start-of-sequence step is not prepended.
        add_sos = not self._rnnt_export

        g, state = self.predict(y, state=states, add_sos=add_sos)  # (B, U, D)
        g = g.transpose(1, 2)  # (B, D, U)

        return g, target_length, state

    def predict(
        self,
        y: Optional[torch.Tensor] = None,
        state: Optional[List[torch.Tensor]] = None,
        add_sos: bool = True,
        batch_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Stateful prediction of scores and state for a tokenset.

        Here:
        B - batch size
        U - label length
        C - context size for stateless decoder
        D - total embedding size

        Args:
            y: Optional torch tensor of shape [B, U] of dtype long which will be passed to the Embedding.
                If None, creates a zero tensor of shape [B, 1, D] which mimics output of pad-token on Embedding.

            state: An optional one-element list of one tensor. The tensor is used to store previous context labels.
                The tensor uses type long and is of shape [B, C].

            add_sos: bool flag, whether a zero vector describing a "start of signal" token should be
                prepended to the above "y" tensor. When set, output size is (B, U + 1, D).

            batch_size: An optional int, only used when `y` is None. It gives the batch size of the
                zero tensor that is created in place of the embeddings. When it is None as well, the
                batch size is inferred from `state` (its first dimension), or defaults to 1 if `state`
                is also None.

        Returns:
            A tuple  (g, state) such that -

            If add_sos is False:

                g:
                    (B, U, D)

                state:
                    [(B, C)] storing the history context including the new words in y.

            If add_sos is True:

                g:
                    (B, U + 1, D)

                state:
                    [(B, C)] storing the history context including the new words in y.

            Note: when `y` is None the incoming `state` is returned unchanged.
        """
        # Get device and dtype of current module
        _p = next(self.parameters())
        device = _p.device
        dtype = _p.dtype

        # If y is not None, it is of shape [B, U] with dtype long.
        if y is not None:
            if y.device != device:
                y = y.to(device)

            y, state = self.prediction(y, state)

        else:
            # Y is not provided, assume zero tensor with shape [B, 1, D] is required
            # Emulates output of embedding of pad token.
            if batch_size is None:
                # The stateless state is [B, C]: the batch dimension is dim 0
                # (dim 1 is the context size, not the batch size).
                B = 1 if state is None else state[0].size(0)
            else:
                B = batch_size

            y = torch.zeros((B, 1, self.pred_hidden), device=device, dtype=dtype)

        # Prepend blank "start of sequence" symbol (zero tensor)
        if add_sos:
            B, U, D = y.shape
            start = torch.zeros((B, 1, D), device=y.device, dtype=y.dtype)
            y = torch.cat([start, y], dim=1).contiguous()  # (B, U + 1, D)

        return y, state

    def _predict_modules(self, **kwargs):
        """
        Prepare the trainable parameters of the Prediction Network.

        Args:
            **kwargs: Keyword arguments forwarded verbatim to `stateless_net.StatelessNet`.
                `__init__` passes: context_size, vocab_size (excluding the blank token),
                emb_dim, blank_idx, normalization_mode and dropout.

        Returns:
            The constructed `StatelessNet` module.
        """
        return stateless_net.StatelessNet(**kwargs)

    def score_hypothesis(
        self, hypothesis: rnnt_utils.Hypothesis, cache: Dict[Tuple[int], Any]
    ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
        """
        Similar to the predict() method, instead this method scores a Hypothesis during beam search.
        Hypothesis is a dataclass representing one hypothesis in a Beam Search.

        Args:
            hypothesis: Refer to rnnt_utils.Hypothesis.
            cache: Dict which contains a cache to avoid duplicate computations.
                It is keyed by the full token sequence (as a tuple) and is updated in place.

        Returns:
            Returns a tuple (y, states, lm_token) such that:
            y is a torch.Tensor of shape [1, 1, D] representing the score of the last token in the Hypothesis.
            state is the decoder state returned by predict() for this hypothesis
            (None when the last token is the blank token, since predict() is then called without a state).
            lm_token is the final integer token of the hypothesis.
        """
        if hypothesis.dec_state is not None:
            device = hypothesis.dec_state[0].device
        else:
            _p = next(self.parameters())
            device = _p.device

        # parse "blank" tokens in hypothesis
        blank_state = len(hypothesis.y_sequence) > 0 and hypothesis.y_sequence[-1] == self.blank_idx

        # Convert last token of hypothesis to torch.Tensor
        target = torch.full([1, 1], fill_value=hypothesis.y_sequence[-1], device=device, dtype=torch.long)
        lm_token = target[:, -1]  # [1]

        # Convert current hypothesis into a tuple to preserve in cache
        sequence = tuple(hypothesis.y_sequence)

        if sequence in cache:
            y, new_state = cache[sequence]
        else:
            # Obtain score for target token and new states
            if blank_state:
                y, new_state = self.predict(None, state=None, add_sos=False, batch_size=1)  # [1, 1, D]
            else:
                y, new_state = self.predict(
                    target, state=hypothesis.dec_state, add_sos=False, batch_size=1
                )  # [1, 1, D]

            y = y[:, -1:, :]  # Extract just last state : [1, 1, D]
            cache[sequence] = (y, new_state)

        return y, new_state, lm_token

    def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
        """
        Create the initial decoder state for a batch.

        Args:
            y: A torch.Tensor; only its first dimension (batch size) and its device are used.

        Returns:
            A one-element list holding a long tensor of shape [B, context_size - 1]
            filled with the blank index.
        """
        batch = y.size(0)
        # state contains context_size - 1 elements for each utterance in batch,
        # consistent with the state returned from StatelessNet.forward
        state = [
            torch.full([batch, self.context_size - 1], fill_value=self.blank_idx, dtype=torch.long, device=y.device)
        ]
        return state

    def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]):
        """
        Creates a stacked decoder states to be passed to prediction network.

        Args:
            decoder_states (list of list of torch.Tensor): one state per sample (B entries); the first
                tensor of every entry is stacked along a new leading batch dimension.

        Returns:
            batch_states (list of torch.Tensor): batch of decoder states [[B x C]]
        """
        new_state = torch.stack([s[0] for s in decoder_states])

        return [new_state]

    def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> Optional[List[torch.Tensor]]:
        """Get decoder state from batch of states, for given id.

        Args:
            batch_states (list): batch of decoder states [(B, C)]
            idx (int): index to extract state from batch of states

        Returns:
            (list): decoder states for given id [(C)], or None if `batch_states` is None.
        """
        if batch_states is None:
            return None

        # beam search code assumes the batch_states tensor is always of float type, so need conversion
        states = batch_states[0][idx].long()
        return [states]

    def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:
        """Concatenate a batch of decoder state to a packed state.

        Args:
            batch_states (list): batch of decoder states B x ([(C)]

        Returns:
            (list): decoder states [(B x C)]
        """
        # Each per-sample state list is stacked to [1, C]; the results are concatenated to [B, C].
        batch_list = [torch.stack(sample_states) for sample_states in batch_states]
        state_tensor = torch.cat(batch_list, 0)

        return [state_tensor]

    @classmethod
    def batch_replace_states_mask(
        cls,
        src_states: tuple[torch.Tensor, torch.Tensor] | list[torch.Tensor],
        dst_states: tuple[torch.Tensor, torch.Tensor] | list[torch.Tensor],
        mask: torch.Tensor,
        other_src_states: Optional[tuple[torch.Tensor, torch.Tensor] | list[torch.Tensor]] = None,
    ):
        """
        Replaces states in `dst_states` with states from `src_states` based on the given `mask`.
        Only the first tensor of each state container is used.

        Args:
            mask (torch.Tensor): When True, selects values from `src_states`, otherwise `out`
                or `other_src_states`(if provided).
            src_states (tuple[torch.Tensor, torch.Tensor]): Values selected at indices where `mask` is True.
            dst_states (tuple[torch.Tensor, torch.Tensor], optional): The output states (written in place).
            other_src_states (tuple[torch.Tensor, torch.Tensor], optional): Values selected at indices
                where `mask` is False.

        Note:
            This operation is performed without CPU-GPU synchronization by using `torch.where`.
        """
        other = other_src_states if other_src_states is not None else dst_states
        # same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking
        torch.where(mask.unsqueeze(-1), src_states[0], other[0], out=dst_states[0])

    @classmethod
    def batch_replace_states_all(
        cls,
        src_states: list[torch.Tensor],
        dst_states: list[torch.Tensor],
        batch_size: int | None = None,
    ):
        """
        Replace states in dst_states with states from src_states (in place).

        Args:
            src_states: source states; only the first tensor is read.
            dst_states: destination states; only the first tensor is written.
            batch_size: if given, only the first `batch_size` rows are copied.
        """
        if batch_size is None:
            dst_states[0].copy_(src_states[0])
        else:
            dst_states[0][:batch_size].copy_(src_states[0][:batch_size])

    @classmethod
    def clone_state(cls, state: list[torch.Tensor]) -> list[torch.Tensor]:
        """Return copy of the states"""
        return [sub_state.clone() for sub_state in state]

    @classmethod
    def batch_split_states(cls, batch_states: list[torch.Tensor]) -> list[list[torch.Tensor]]:
        """
        Split states into a list of states.
        Useful for splitting the final state for converting results of the decoding algorithm to Hypothesis class.

        Returns:
            One entry per tensor in `batch_states`; each entry is the sequence of size-1 chunks of that
            tensor along dim 0.
        """
        # NOTE(review): the result is indexed [state tensor][sample], whereas `batch_unsplit_states`
        # consumes [sample][state tensor] - confirm against callers before relying on them as inverses.
        return [sub_state.split(1, dim=0) for sub_state in batch_states]

    @classmethod
    def batch_unsplit_states(
        cls, batch_states: list[list[torch.Tensor]], device=None, dtype=None
    ) -> list[torch.Tensor]:
        """
        Concatenate a batch of decoder state to a packed state. Inverse of `batch_split_states`.

        Args:
            batch_states: per-sample states; the first tensor of each is stacked along a new dim 0.
            device: optional device passed to `Tensor.to` for the result.
            dtype: optional dtype passed to `Tensor.to` for the result.
        """
        return [
            torch.stack([state[0] for state in batch_states], dim=0).to(device=device, dtype=dtype),
        ]

    def batch_copy_states(
        self,
        old_states: List[torch.Tensor],
        new_states: List[torch.Tensor],
        ids: List[int],
        value: Optional[float] = None,
    ) -> List[torch.Tensor]:
        """Copy states from new state to old state at certain indices.

        Args:
            old_states: packed decoder states
                single element list of (B x C)

            new_states: packed decoder states
                single element list of (B x C)

            ids (list): List of indices to copy states at.

            value (optional float): If a value should be copied instead of a state slice, a float should be provided

        Returns:
            batch of decoder states with partial copy at ids (or a specific value).
                (B x C)
        """
        if value is None:
            old_states[0][ids, :] = new_states[0][ids, :]
        else:
            # Previously `value` was silently ignored; fill the selected rows with it instead.
            # The scalar is cast to the dtype of the state tensor on assignment.
            old_states[0][ids, :] = value

        return old_states

    def mask_select_states(
        self, states: Optional[List[torch.Tensor]], mask: torch.Tensor
    ) -> Optional[List[torch.Tensor]]:
        """
        Return states by mask selection

        Args:
            states: states for the batch
            mask: boolean mask for selecting states; batch dimension should be the same as for states

        Returns:
            states filtered by mask (None if `states` is None)
        """
        if states is None:
            return None
        return [states[0][mask]]

    def batch_score_hypothesis(
        self,
        hypotheses: List[rnnt_utils.Hypothesis],
        cache: Dict[Tuple[int], Any],
    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
        """
        Used for batched beam search algorithms. Similar to score_hypothesis method.

        Args:
            hypotheses: List of Hypotheses. Refer to rnnt_utils.Hypothesis.
            cache: Dict which contains a cache to avoid duplicate computations.
                It is keyed by the full token sequence (as a tuple) and is updated in place.

        Returns:
            Returns a tuple (batch_dec_out, batch_dec_states) such that:
                batch_dec_out: a list of torch.Tensor [1, D] representing the prediction network outputs for the last tokens in the Hypotheses.
                batch_dec_states: a list with one decoder state per hypothesis, each as returned by
                    `batch_select_state` (or as stored in the cache).

        Raises:
            ValueError: if `hypotheses` is empty.
        """
        final_batch = len(hypotheses)
        if final_batch == 0:
            raise ValueError("No hypotheses was provided for the batch!")

        _p = next(self.parameters())
        device = _p.device

        tokens = []
        to_process = []
        final = [None for _ in range(final_batch)]

        # For each hypothesis, cache the last token of the sequence and the current states
        for final_idx, hyp in enumerate(hypotheses):
            sequence = tuple(hyp.y_sequence)

            if sequence in cache:
                final[final_idx] = cache[sequence]
            else:
                tokens.append(hyp.y_sequence[-1])
                to_process.append((sequence, hyp.dec_state))

        if to_process:
            batch = len(to_process)

            # convert list of tokens to torch.Tensor, then reshape.
            tokens = torch.tensor(tokens, device=device, dtype=torch.long).view(batch, -1)
            dec_states = self.batch_initialize_states([d_state for _, d_state in to_process])

            dec_outputs, dec_states = self.predict(
                tokens, state=dec_states, add_sos=False, batch_size=batch
            )  # [B, 1, D], [(B, C)]

        # Update final states and cache shared by entire batch.
        # Unfilled slots of `final` appear in the same order as the entries of `to_process`.
        processed_idx = 0
        for final_idx in range(final_batch):
            if to_process and final[final_idx] is None:
                # Select sample's state from the batch state list
                new_state = self.batch_select_state(dec_states, processed_idx)

                # Cache [1, D] scores of the current y_j, and its corresponding state
                final[final_idx] = (dec_outputs[processed_idx], new_state)
                cache[to_process[processed_idx][0]] = (dec_outputs[processed_idx], new_state)

                processed_idx += 1

        return [dec_out for dec_out, _ in final], [dec_states for _, dec_states in final]
[docs] class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMixin): """A Recurrent Neural Network Transducer Decoder / Prediction Network (RNN-T Prediction Network). An RNN-T Decoder/Prediction network, comprised of a stateful LSTM model. Args: prednet: A dict-like object which contains the following key-value pairs. pred_hidden: int specifying the hidden dimension of the prediction net. pred_rnn_layers: int specifying the number of rnn layers. Optionally, it may also contain the following: forget_gate_bias: float, set by default to 1.0, which constructs a forget gate initialized to 1.0. Reference: [An Empirical Exploration of Recurrent Network Architectures](http://proceedings.mlr.press/v37/jozefowicz15.pdf) t_max: int value, set to None by default. If an int is specified, performs Chrono Initialization of the LSTM network, based on the maximum number of timesteps `t_max` expected during the course of training. Reference: [Can recurrent neural networks warp time?](https://openreview.net/forum?id=SJcKhk-Ab) weights_init_scale: Float scale of the weights after initialization. Setting to lower than one sometimes helps reduce variance between runs. hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for the default behaviour. dropout: float, set to 0.0 by default. Optional dropout applied at the end of the final LSTM RNN layer. vocab_size: int, specifying the vocabulary size of the embedding layer of the Prediction network, excluding the RNNT blank token. normalization_mode: Can be either None, 'batch' or 'layer'. By default, is set to None. Defines the type of normalization applied to the RNN layer. random_state_sampling: bool, set to False by default. When set, provides normal-distribution sampled state tensors instead of zero tensors during training. Reference: [Recognizing long-form speech using streaming end-to-end models](https://arxiv.org/abs/1910.11455) blank_as_pad: bool, set to True by default. 
When set, will add a token to the Embedding layer of this prediction network, and will treat this token as a pad token. In essence, the RNNT pad token will be treated as a pad token, and the embedding layer will return a zero tensor for this token. It is set by default as it enables various batch optimizations required for batched beam search. Therefore, it is not recommended to disable this flag. """ @property def input_types(self): """Returns definitions of module input ports.""" return { "targets": NeuralType(('B', 'T'), LabelsType()), "target_length": NeuralType(tuple('B'), LengthsType()), "states": [NeuralType(('D', 'B', 'D'), ElementType(), optional=True)], # must always be last } @property def output_types(self): """Returns definitions of module output ports.""" return { "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), "prednet_lengths": NeuralType(tuple('B'), LengthsType()), "states": [NeuralType((('D', 'B', 'D')), ElementType(), optional=True)], # must always be last }
[docs] def input_example(self, max_batch=1, max_dim=1): """ Generates input examples for tracing etc. Returns: A tuple of input examples. """ length = max_dim targets = torch.full(fill_value=self.blank_idx, size=(max_batch, length), dtype=torch.int32).to( next(self.parameters()).device ) target_length = torch.randint(0, length, size=(max_batch,), dtype=torch.int32).to( next(self.parameters()).device ) states = tuple(self.initialize_state(targets.float())) return (targets, target_length, states)
def _prepare_for_export(self, **kwargs): self._rnnt_export = True super()._prepare_for_export(**kwargs) def __init__( self, prednet: Dict[str, Any], vocab_size: int, normalization_mode: Optional[str] = None, random_state_sampling: bool = False, blank_as_pad: bool = True, ): # Required arguments self.pred_hidden = prednet['pred_hidden'] self.pred_rnn_layers = prednet["pred_rnn_layers"] self.blank_idx = vocab_size # Initialize the model (blank token increases vocab size by 1) super().__init__(vocab_size=vocab_size, blank_idx=self.blank_idx, blank_as_pad=blank_as_pad) # Optional arguments forget_gate_bias = prednet.get('forget_gate_bias', 1.0) t_max = prednet.get('t_max', None) weights_init_scale = prednet.get('weights_init_scale', 1.0) hidden_hidden_bias_scale = prednet.get('hidden_hidden_bias_scale', 0.0) dropout = prednet.get('dropout', 0.0) self.random_state_sampling = random_state_sampling self.prediction = self._predict_modules( vocab_size=vocab_size, # add 1 for blank symbol pred_n_hidden=self.pred_hidden, pred_rnn_layers=self.pred_rnn_layers, forget_gate_bias=forget_gate_bias, t_max=t_max, norm=normalization_mode, weights_init_scale=weights_init_scale, hidden_hidden_bias_scale=hidden_hidden_bias_scale, dropout=dropout, rnn_hidden_size=prednet.get("rnn_hidden_size", -1), ) self._rnnt_export = False
[docs] @typecheck() def forward(self, targets, target_length, states=None): # y: (B, U) y = rnn.label_collate(targets) # state maintenance is unnecessary during training forward call # to get state, use .predict() method. if self._rnnt_export: add_sos = False else: add_sos = True g, states = self.predict(y, state=states, add_sos=add_sos) # (B, U, D) g = g.transpose(1, 2) # (B, D, U) return g, target_length, states
[docs] def predict( self, y: Optional[torch.Tensor] = None, state: Optional[List[torch.Tensor]] = None, add_sos: bool = True, batch_size: Optional[int] = None, ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """ Stateful prediction of scores and state for a (possibly null) tokenset. This method takes various cases into consideration : - No token, no state - used for priming the RNN - No token, state provided - used for blank token scoring - Given token, states - used for scores + new states Here: B - batch size U - label length H - Hidden dimension size of RNN L - Number of RNN layers Args: y: Optional torch tensor of shape [B, U] of dtype long which will be passed to the Embedding. If None, creates a zero tensor of shape [B, 1, H] which mimics output of pad-token on EmbeddiNg. state: An optional list of states for the RNN. Eg: For LSTM, it is the state list length is 2. Each state must be a tensor of shape [L, B, H]. If None, and during training mode and `random_state_sampling` is set, will sample a normal distribution tensor of the above shape. Otherwise, None will be passed to the RNN. add_sos: bool flag, whether a zero vector describing a "start of signal" token should be prepended to the above "y" tensor. When set, output size is (B, U + 1, H). batch_size: An optional int, specifying the batch size of the `y` tensor. Can be infered if `y` and `state` is None. But if both are None, then batch_size cannot be None. 
Returns: A tuple (g, hid) such that - If add_sos is False: g: (B, U, H) hid: (h, c) where h is the final sequence hidden state and c is the final cell state: h (tensor), shape (L, B, H) c (tensor), shape (L, B, H) If add_sos is True: g: (B, U + 1, H) hid: (h, c) where h is the final sequence hidden state and c is the final cell state: h (tensor), shape (L, B, H) c (tensor), shape (L, B, H) """ # Get device and dtype of current module _p = next(self.parameters()) device = _p.device dtype = _p.dtype # If y is not None, it is of shape [B, U] with dtype long. if y is not None: if y.device != device: y = y.to(device) # (B, U) -> (B, U, H) y = self.prediction["embed"](y) else: # Y is not provided, assume zero tensor with shape [B, 1, H] is required # Emulates output of embedding of pad token. if batch_size is None: B = 1 if state is None else state[0].size(1) else: B = batch_size y = torch.zeros((B, 1, self.pred_hidden), device=device, dtype=dtype) # Prepend blank "start of sequence" symbol (zero tensor) if add_sos: B, U, H = y.shape start = torch.zeros((B, 1, H), device=y.device, dtype=y.dtype) y = torch.cat([start, y], dim=1).contiguous() # (B, U + 1, H) else: start = None # makes del call later easier # If in training mode, and random_state_sampling is set, # initialize state to random normal distribution tensor. if state is None: if self.random_state_sampling and self.training: state = self.initialize_state(y) # Forward step through RNN y = y.transpose(0, 1) # (U + 1, B, H) g, hid = self.prediction["dec_rnn"](y, state) g = g.transpose(0, 1) # (B, U + 1, H) del y, start, state # Adapter module forward step if self.is_adapter_available(): g = self.forward_enabled_adapters(g) return g, hid
def _predict_modules(
    self,
    vocab_size,
    pred_n_hidden,
    pred_rnn_layers,
    forget_gate_bias,
    t_max,
    norm,
    weights_init_scale,
    hidden_hidden_bias_scale,
    dropout,
    rnn_hidden_size,
):
    """
    Build the trainable sub-modules of the Prediction Network.

    Args:
        vocab_size: Vocab size (excluding the blank token).
        pred_n_hidden: Hidden size of the RNNs.
        pred_rnn_layers: Number of RNN layers.
        forget_gate_bias: Whether to perform unit forget gate bias.
        t_max: Whether to perform Chrono LSTM init.
        norm: Type of normalization to perform in RNN.
        weights_init_scale: Float scale of the weights after initialization. Setting to lower than one
            sometimes helps reduce variance between runs.
        hidden_hidden_bias_scale: Float scale for the hidden-to-hidden bias scale. Set to 0.0 for
            the default behaviour.
        dropout: Whether to apply dropout to RNN.
        rnn_hidden_size: the hidden size of the RNN; when it is not positive, `pred_n_hidden` is used.

    Returns:
        A `torch.nn.ModuleDict` with the keys "embed" (token embedding) and "dec_rnn" (the RNN).
    """
    # With blank-as-pad, one extra embedding row is reserved for the blank index, which acts as padding.
    if self.blank_as_pad:
        embedding = torch.nn.Embedding(vocab_size + 1, pred_n_hidden, padding_idx=self.blank_idx)
    else:
        embedding = torch.nn.Embedding(vocab_size, pred_n_hidden)

    # A projection back to `pred_n_hidden` is only requested when the RNN is wider than that.
    effective_hidden = rnn_hidden_size if rnn_hidden_size > 0 else pred_n_hidden
    projection_size = pred_n_hidden if pred_n_hidden < rnn_hidden_size else 0

    decoder_rnn = rnn.rnn(
        input_size=pred_n_hidden,
        hidden_size=effective_hidden,
        num_layers=pred_rnn_layers,
        norm=norm,
        forget_gate_bias=forget_gate_bias,
        t_max=t_max,
        dropout=dropout,
        weights_init_scale=weights_init_scale,
        hidden_hidden_bias_scale=hidden_hidden_bias_scale,
        proj_size=projection_size,
    )

    return torch.nn.ModuleDict({"embed": embedding, "dec_rnn": decoder_rnn})
def initialize_state(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create an initial LSTM state with the same dtype and device as the input `y`.
    LSTM accepts a tuple of 2 tensors as a state.

    The state is drawn from a standard normal distribution when the module is in training mode
    and `random_state_sampling` is set; otherwise it is all zeros.

    Args:
        y: A torch.Tensor whose first dimension is the batch size, and whose dtype and device
            the generated states will share.

    Returns:
        Tuple of 2 tensors, each of shape [L, B, H], where

            L = Number of RNN layers

            B = Batch size

            H = Hidden size of RNN.
    """
    state_shape = (self.pred_rnn_layers, y.size(0), self.pred_hidden)
    sample_randomly = self.random_state_sampling and self.training
    make_tensor = torch.randn if sample_randomly else torch.zeros

    first = make_tensor(*state_shape, dtype=y.dtype, device=y.device)
    second = make_tensor(*state_shape, dtype=y.dtype, device=y.device)
    return (first, second)
def score_hypothesis(
    self, hypothesis: rnnt_utils.Hypothesis, cache: Dict[Tuple[int], Any]
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
    """
    Similar to the predict() method, instead this method scores a Hypothesis during beam search.
    Hypothesis is a dataclass representing one hypothesis in a Beam Search.

    Args:
        hypothesis: Refer to rnnt_utils.Hypothesis.
        cache: Dict which contains a cache to avoid duplicate computations. It is keyed by the
            hypothesis token sequence (as a tuple) and is updated in place on a cache miss.

    Returns:
        Returns a tuple (y, states, lm_token) such that:
        y is a torch.Tensor of shape [1, 1, H] representing the score of the last token in the Hypothesis.
        state is a list of RNN states, each of shape [L, 1, H].
        lm_token is a long tensor of shape [1] holding the final token of the hypothesis.
    """
    # Prefer the device of the stored decoder state; fall back to the module's own device.
    if hypothesis.dec_state is None:
        device = next(self.parameters()).device
    else:
        device = hypothesis.dec_state[0].device

    # A hypothesis that ends in the blank token is scored without token or state.
    ends_with_blank = len(hypothesis.y_sequence) > 0 and hypothesis.y_sequence[-1] == self.blank_idx

    # Convert last token of hypothesis to torch.Tensor
    target = torch.full([1, 1], fill_value=hypothesis.y_sequence[-1], device=device, dtype=torch.long)
    lm_token = target[:, -1]  # [1]

    # The full token sequence (as a hashable tuple) identifies the hypothesis in the cache.
    cache_key = tuple(hypothesis.y_sequence)
    if cache_key in cache:
        y, new_state = cache[cache_key]
        return y, new_state, lm_token

    # Cache miss: obtain score for target token and new states
    if ends_with_blank:
        y, new_state = self.predict(None, state=None, add_sos=False, batch_size=1)  # [1, 1, H]
    else:
        y, new_state = self.predict(
            target, state=hypothesis.dec_state, add_sos=False, batch_size=1
        )  # [1, 1, H]

    y = y[:, -1:, :]  # Extract just last state : [1, 1, H]
    cache[cache_key] = (y, new_state)

    return y, new_state, lm_token
def batch_score_hypothesis(
    self,
    hypotheses: List[rnnt_utils.Hypothesis],
    cache: Dict[Tuple[int], Any],
) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
    """
    Used for batched beam search algorithms. Similar to score_hypothesis method.

    Args:
        hypotheses: List of Hypotheses. Refer to rnnt_utils.Hypothesis.
        cache: Dict which contains a cache to avoid duplicate computations. It is keyed by the
            hypothesis token sequence (as a tuple) and is updated in place for every cache miss.

    Returns:
        Returns a tuple (batch_dec_out, batch_dec_states) such that:
            batch_dec_out: a list of torch.Tensor [1, H] representing the prediction network outputs for the last tokens in the Hypotheses.
            batch_dec_states: a list of list of RNN states, each of shape [L, B, H]. Represented as B x List[states].

    Raises:
        ValueError: If `hypotheses` is empty.
    """
    num_hypotheses = len(hypotheses)
    if num_hypotheses == 0:
        raise ValueError("No hypotheses was provided for the batch!")

    device = next(self.parameters()).device

    # `results[i]` holds the (output, state) pair of hypothesis i; None marks "still to be computed".
    results = [None] * num_hypotheses
    pending_tokens = []
    pending = []  # (cache key, decoder state) of every hypothesis missing from the cache

    for slot, hyp in enumerate(hypotheses):
        key = tuple(hyp.y_sequence)
        if key in cache:
            results[slot] = cache[key]
            continue
        pending_tokens.append(hyp.y_sequence[-1])
        pending.append((key, hyp.dec_state))

    if pending:
        num_pending = len(pending)

        # Last tokens of all cache misses as a [num_pending, 1] long tensor.
        token_batch = torch.tensor(pending_tokens, device=device, dtype=torch.long).view(num_pending, -1)
        packed_states = self.batch_initialize_states([dec_state for _, dec_state in pending])

        dec_out, dec_states = self.predict(
            token_batch, state=packed_states, add_sos=False, batch_size=num_pending
        )  # [B, 1, H], B x List([L, 1, H])

    # Fill the remaining slots in order; row `r` of the batched output belongs to the r-th cache miss.
    open_slots = [slot for slot, entry in enumerate(results) if entry is None]
    for row, slot in enumerate(open_slots):
        # Select sample's state from the batch state list
        new_state = self.batch_select_state(dec_states, row)

        # Store the [1, H] scores of the current y_j together with its state, locally and in the cache
        results[slot] = (dec_out[row], new_state)
        cache[pending[row][0]] = (dec_out[row], new_state)

    return [out for out, _ in results], [states for _, states in results]
def batch_initialize_states(self, decoder_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """
    Creates a stacked decoder states to be passed to prediction network

    Args:
        decoder_states (list of list of list of torch.Tensor): list of decoder states
            [B, C, L, H]
                - B: Batch size.
                - C: e.g., for LSTM, this is 2: hidden and cell states
                - L: Number of layers in prediction RNN.
                - H: Dimensionality of the hidden state.

    Returns:
        batch_states (list of torch.Tensor): batch of decoder states
            [C x torch.Tensor[L x B x H]
    """
    # Each sample becomes one [C, L, H] tensor ...
    per_sample = [torch.stack(sample_state) for sample_state in decoder_states]
    # ... and the samples are joined on a new batch axis placed right before H: [C, L, B, H].
    packed = torch.stack(per_sample, dim=2)
    # Iterating the leading axis yields the C separate [L, B, H] state tensors.
    return list(packed)
def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]:
    """Get decoder state from batch of states, for given id.

    Args:
        batch_states (list): batch of decoder states
            ([L x (B, H)], [L x (B, H)])

        idx (int): index to extract state from batch of states

    Returns:
        (list): decoder states for given id
            ([L x (1, H)], [L x (1, H)]),
            or None when `batch_states` is None.
    """
    if batch_states is None:
        return None

    selected = []
    for packed_state in batch_states:
        # Index the second (batch) axis of every packed state tensor.
        selected.append(packed_state[:, idx])
    return selected
@classmethod
def batch_aggregate_states_beam(
    cls,
    src_states: tuple[torch.Tensor, torch.Tensor],
    batch_size: int,
    beam_size: int,
    indices: torch.Tensor,
    dst_states: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Aggregates decoder states based on the given indices.

    Args:
        src_states (Tuple[torch.Tensor, torch.Tensor]): source states of
            shape `([L x (batch_size * beam_size, H)], [L x (batch_size * beam_size, H)])`
        batch_size (int): The size of the batch.
        beam_size (int): The size of the beam.
        indices (torch.Tensor): A tensor of shape `(batch_size, beam_size)` containing
            the indices in beam that map the source states to the destination states.
        dst_states (Optional[Tuple[torch.Tensor, torch.Tensor]]): If provided, the method
            updates these tensors in-place.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: `dst_states` itself when it was provided (after being
        filled in place); otherwise a new pair of gathered states with the same flat layout as
        `src_states`.

    Note:
        - The `indices` tensor is expanded to match the shape of the source states
        during the gathering operation.
    """
    hidden_src, cell_src = src_states[0], src_states[1]
    num_layers = hidden_src.shape[0]
    hidden_dim = hidden_src.shape[-1]

    # The flat (batch * beam) axis is viewed as separate batch and beam axes for the gather.
    beam_shape = torch.Size((num_layers, batch_size, beam_size, hidden_dim))
    flat_shape = torch.Size((num_layers, batch_size * beam_size, hidden_dim))

    # Broadcast the (batch, beam) indices over the layer and hidden axes.
    gather_index = indices[None, :, :, None].expand(beam_shape)

    if dst_states is None:
        # Out-of-place: gather along the beam axis and flatten back to the source layout.
        gathered_hidden = torch.gather(hidden_src.view(beam_shape), dim=2, index=gather_index)
        gathered_cell = torch.gather(cell_src.view(beam_shape), dim=2, index=gather_index)
        return (gathered_hidden.view(flat_shape), gathered_cell.view(flat_shape))

    # In-place: write the gathered values straight into views of the destination tensors.
    for source, destination in ((hidden_src, dst_states[0]), (cell_src, dst_states[1])):
        torch.gather(source.view(beam_shape), dim=2, index=gather_index, out=destination.view(beam_shape))
    return dst_states
def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:
    """Concatenate a batch of decoder state to a packed state.

    Args:
        batch_states (list): batch of decoder states
            B x ([L x (H)], [L x (H)])

    Returns:
        (list): decoder states
            (L x B x H, L x B x H)
    """
    packed_states = []
    num_state_kinds = len(batch_states[0])

    for state_id in range(num_state_kinds):
        per_sample = []
        for sample_states in batch_states:
            entry = sample_states[state_id]
            # A state may arrive as a list of per-layer tensors; turn it into one [L, H] tensor.
            if not isinstance(entry, torch.Tensor):
                entry = torch.stack(entry)
            per_sample.append(entry)

        # [B, L, H] -> [L, B, H]
        packed_states.append(torch.stack(per_sample, dim=0).transpose(1, 0))

    return packed_states
@classmethod
def batch_replace_states_mask(
    cls,
    src_states: Tuple[torch.Tensor, torch.Tensor],
    dst_states: Tuple[torch.Tensor, torch.Tensor],
    mask: torch.Tensor,
    other_src_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
    """
    Replaces states in `dst_states` with states from `src_states` based on the given `mask`.

    Args:
        mask (torch.Tensor): When True, selects values from `src_states`, otherwise `out` or `other_src_states`(if provided).
        src_states (Tuple[torch.Tensor, torch.Tensor]): Values selected at indices where `mask` is True.
        dst_states (Tuple[torch.Tensor, torch.Tensor])): The output states.
        other_src_states (Tuple[torch.Tensor, torch.Tensor], optional): Values selected at indices where `mask` is False.

    Note:
        This operation is performed without CPU-GPU synchronization by using `torch.where`.
    """
    # same as `dst_states[i][mask] = src_states[i][mask]`, but non-blocking
    fallback_states = dst_states if other_src_states is None else other_src_states

    # we need to cast, since LSTM is calculated in fp16 even if autocast to bfloat16 is enabled
    target_dtype = dst_states[0].dtype

    # The mask gains a leading and a trailing singleton axis so it broadcasts against the states.
    condition = mask.unsqueeze(0).unsqueeze(-1)

    for state_id in (0, 1):
        torch.where(
            condition,
            src_states[state_id].to(target_dtype),
            fallback_states[state_id].to(target_dtype),
            out=dst_states[state_id],
        )
[docs] @classmethod def batch_replace_states_all( cls, src_states: Tuple[torch.Tensor, torch.Tensor], dst_states: Tuple[torch.Tensor, torch.Tensor], batch_size: int | None = None, ): """Replace states in dst_states with states from src_states""" if batch_size is None: dst_states[0].copy_(src_states[0]) dst_states[1].copy_(src_states[1]) else: dst_states[0][:, :batch_size].copy_(src_states[0][:, :batch_size]) dst_states[1][:, :batch_size].copy_(src_states[1][:, :batch_size])
@classmethod
def clone_state(cls, state: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Return a copy of the states: a new pair holding clones of the first two tensors of `state`."""
    first_copy = state[0].clone()
    second_copy = state[1].clone()
    return first_copy, second_copy
@classmethod
def batch_split_states(
    cls, batch_states: tuple[torch.Tensor, torch.Tensor]
) -> list[tuple[torch.Tensor, torch.Tensor]]:
    """
    Split states into a list of states.
    Useful for splitting the final state for converting results of the decoding algorithm to Hypothesis class.

    Args:
        batch_states: pair of packed state tensors that are split along their second axis.

    Returns:
        One pair of tensors per size-1 chunk of the second axis, with that axis squeezed away.
    """
    first_chunks = batch_states[0].split(1, dim=1)
    second_chunks = batch_states[1].split(1, dim=1)

    per_sample = []
    for first_chunk, second_chunk in zip(first_chunks, second_chunks):
        per_sample.append((first_chunk.squeeze(1), second_chunk.squeeze(1)))
    return per_sample
@classmethod
def batch_unsplit_states(
    cls, batch_states: list[tuple[torch.Tensor, torch.Tensor]], device=None, dtype=None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Concatenate a batch of decoder state to a packed state. Inverse of `batch_split_states`.

    Args:
        batch_states (list): batch of decoder states
            B x ([L x (H)], [L x (H)])
        device: optional device passed to `Tensor.to` for the packed states.
        dtype: optional dtype passed to `Tensor.to` for the packed states.

    Returns:
        (tuple): decoder states
            (L x B x H, L x B x H)
    """
    first_states = [sample_state[0] for sample_state in batch_states]
    second_states = [sample_state[1] for sample_state in batch_states]

    # The batch axis is inserted at position 1, then both results are moved / cast as requested.
    packed_first = torch.stack(first_states, dim=1).to(device=device, dtype=dtype)
    packed_second = torch.stack(second_states, dim=1).to(device=device, dtype=dtype)
    return (packed_first, packed_second)
def batch_copy_states(
    self,
    old_states: List[torch.Tensor],
    new_states: List[torch.Tensor],
    ids: List[int],
    value: Optional[float] = None,
) -> List[torch.Tensor]:
    """Copy states from new state to old state at certain indices.

    The tensors in `old_states` are modified in place.

    Args:
        old_states(list): packed decoder states
            (L x B x H, L x B x H)

        new_states: packed decoder states
            (L x B x H, L x B x H). Only read when `value` is None.

        ids (list): List of indices to copy states at.

        value (optional float): If a value should be copied instead of a state slice, a float should be provided

    Returns:
        batch of decoder states with partial copy at ids (or a specific value).
            (L x B x H, L x B x H)
    """
    for state_id, old_state in enumerate(old_states):
        if value is None:
            old_state[:, ids, :] = new_states[state_id][:, ids, :]
        else:
            # Assign the constant directly. Zeroing via `*= 0.0` followed by `+= value` would leave
            # NaN behind wherever the old state held NaN or Inf, since NaN * 0 and Inf * 0 are NaN.
            old_state[:, ids, :] = value

    return old_states
def mask_select_states(
    self, states: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return states by mask selection

    Args:
        states: states for the batch
        mask: boolean mask for selecting states; batch dimension should be the same as for states

    Returns:
        states filtered by mask
    """
    # LSTM in PyTorch returns a tuple of 2 tensors as a state; the mask indexes their second axis.
    first_state, second_state = states[0], states[1]
    return first_state[:, mask], second_state[:, mask]
# Adapter method overrides
def add_adapter(self, name: str, cfg: DictConfig):
    """
    Register an adapter after rewriting its config via `_update_adapter_cfg_input_dim`.

    Args:
        name: name of the adapter, forwarded to the parent implementation.
        cfg: adapter config; the updated config is what the parent implementation receives.
    """
    # Update the config with correct input dim, then let the parent class add the adapter.
    updated_cfg = self._update_adapter_cfg_input_dim(cfg)
    super().add_adapter(name=name, cfg=updated_cfg)
def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
    """Return `cfg` as updated by `adapter_utils.update_adapter_cfg_input_dim` with `module_dim=self.pred_hidden`."""
    return adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.pred_hidden)
[docs] class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin): """A Recurrent Neural Network Transducer Joint Network (RNN-T Joint Network). An RNN-T Joint network, comprised of a feedforward model. Args: jointnet: A dict-like object which contains the following key-value pairs. encoder_hidden: int specifying the hidden dimension of the encoder net. pred_hidden: int specifying the hidden dimension of the prediction net. joint_hidden: int specifying the hidden dimension of the joint net activation: Activation function used in the joint step. Can be one of ['relu', 'tanh', 'sigmoid']. Optionally, it may also contain the following: dropout: float, set to 0.0 by default. Optional dropout applied at the end of the joint net. num_classes: int, specifying the vocabulary size that the joint network must predict, excluding the RNNT blank token. vocabulary: Optional list of strings/tokens that comprise the vocabulary of the joint network. Unused and kept only for easy access for character based encoding RNNT models. log_softmax: Optional bool, set to None by default. If set as None, will compute the log_softmax() based on the value provided. preserve_memory: Optional bool, set to False by default. If the model crashes due to the memory intensive joint step, one might try this flag to empty the tensor cache in pytorch. Warning: This will make the forward-backward pass much slower than normal. It also might not fix the OOM if the GPU simply does not have enough memory to compute the joint. fuse_loss_wer: Optional bool, set to False by default. Fuses the joint forward, loss forward and wer forward steps. In doing so, it trades of speed for memory conservation by creating sub-batches of the provided batch of inputs, and performs Joint forward, loss forward and wer forward (optional), all on sub-batches, then collates results to be exactly equal to results from the entire batch. 
When this flag is set, prior to calling forward, the fields `loss` and `wer` (either one) *must* be set using the `RNNTJoint.set_loss()` or `RNNTJoint.set_wer()` methods. Further, when this flag is set, the following argument `fused_batch_size` *must* be provided as a non negative integer. This value refers to the size of the sub-batch. When the flag is set, the input and output signature of `forward()` of this method changes. Input - in addition to `encoder_outputs` (mandatory argument), the following arguments can be provided. - decoder_outputs (optional). Required if loss computation is required. - encoder_lengths (required) - transcripts (optional). Required for wer calculation. - transcript_lengths (optional). Required for wer calculation. - compute_wer (bool, default false). Whether to compute WER or not for the fused batch. Output - instead of the usual `joint` log prob tensor, the following results can be returned. - loss (optional). Returned if decoder_outputs, transcripts and transript_lengths are not None. - wer_numerator + wer_denominator (optional). Returned if transcripts, transcripts_lengths are provided and compute_wer is set. fused_batch_size: Optional int, required if `fuse_loss_wer` flag is set. Determines the size of the sub-batches. Should be any value below the actual batch size per GPU. masking_prob: Optional float, indicating the probability of masking out decoder output in HAINAN (Hybrid Autoregressive Inference Transducer) model, described in https://arxiv.org/pdf/2410.02597 Default to -1.0, which runs standard Joint network computation; if > 0, then masking out decoder output with the specified probability. 
""" @property def input_types(self): """Returns definitions of module input ports.""" return { "encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), "encoder_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), "transcripts": NeuralType(('B', 'T'), LabelsType(), optional=True), "transcript_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), "compute_wer": NeuralType(optional=True), } @property def output_types(self): """Returns definitions of module output ports.""" if not self._fuse_loss_wer: return { "outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), } else: return { "loss": NeuralType(elements_type=LossType(), optional=True), "wer": NeuralType(elements_type=ElementType(), optional=True), "wer_numer": NeuralType(elements_type=ElementType(), optional=True), "wer_denom": NeuralType(elements_type=ElementType(), optional=True), } def _prepare_for_export(self, **kwargs): self._fuse_loss_wer = False self.log_softmax = False super()._prepare_for_export(**kwargs)
[docs] def input_example(self, max_batch=1, max_dim=8192): """ Generates input examples for tracing etc. Returns: A tuple of input examples. """ B, T, U = max_batch, max_dim, max_batch encoder_outputs = torch.randn(B, self.encoder_hidden, T).to(next(self.parameters()).device) decoder_outputs = torch.randn(B, self.pred_hidden, U).to(next(self.parameters()).device) return (encoder_outputs, decoder_outputs)
@property def disabled_deployment_input_names(self): """Implement this method to return a set of input names disabled for export""" return set(["encoder_lengths", "transcripts", "transcript_lengths", "compute_wer"]) def __init__( self, jointnet: Dict[str, Any], num_classes: int, num_extra_outputs: int = 0, vocabulary: Optional[List] = None, log_softmax: Optional[bool] = None, preserve_memory: bool = False, fuse_loss_wer: bool = False, fused_batch_size: Optional[int] = None, experimental_fuse_loss_wer: Any = None, masking_prob: float = -1.0, ): super().__init__() self.vocabulary = vocabulary self._vocab_size = num_classes self._num_extra_outputs = num_extra_outputs self._num_classes = num_classes + 1 + num_extra_outputs # 1 is for blank self.masking_prob = masking_prob if self.masking_prob > 0.0: assert self.masking_prob < 1.0, "masking_prob must be between 0 and 1" if experimental_fuse_loss_wer is not None: # Override fuse_loss_wer from deprecated argument fuse_loss_wer = experimental_fuse_loss_wer self._fuse_loss_wer = fuse_loss_wer self._fused_batch_size = fused_batch_size if fuse_loss_wer and (fused_batch_size is None): raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!") self._loss = None self._wer = None # Log softmax should be applied explicitly only for CPU self.log_softmax = log_softmax self.preserve_memory = preserve_memory if preserve_memory: logging.warning( "`preserve_memory` was set for the Joint Model. Please be aware this will severely impact " "the forward-backward step time. It also might not solve OOM issues if the GPU simply " "does not have enough memory to compute the joint." 
) # Required arguments self.encoder_hidden = jointnet['encoder_hidden'] self.pred_hidden = jointnet['pred_hidden'] self.joint_hidden = jointnet['joint_hidden'] self.activation = jointnet['activation'] # Optional arguments dropout = jointnet.get('dropout', 0.0) self.pred, self.enc, self.joint_net = self._joint_net_modules( num_classes=self._num_classes, # add 1 for blank symbol pred_n_hidden=self.pred_hidden, enc_n_hidden=self.encoder_hidden, joint_n_hidden=self.joint_hidden, activation=self.activation, dropout=dropout, ) # Flag needed for RNNT export support self._rnnt_export = False # to change, requires running ``model.temperature = T`` explicitly self.temperature = 1.0
[docs] @typecheck() def forward( self, encoder_outputs: torch.Tensor, decoder_outputs: Optional[torch.Tensor], encoder_lengths: Optional[torch.Tensor] = None, transcripts: Optional[torch.Tensor] = None, transcript_lengths: Optional[torch.Tensor] = None, compute_wer: bool = False, ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]: # encoder = (B, D, T) # decoder = (B, D, U) if passed, else None encoder_outputs = encoder_outputs.transpose(1, 2) # (B, T, D) if decoder_outputs is not None: decoder_outputs = decoder_outputs.transpose(1, 2) # (B, U, D) if not self._fuse_loss_wer: if decoder_outputs is None: raise ValueError( "decoder_outputs passed is None, and `fuse_loss_wer` is not set. " "decoder_outputs can only be None for fused step!" ) out = self.joint(encoder_outputs, decoder_outputs) # [B, T, U, V + 1] return out else: # At least the loss module must be supplied during fused joint if self._loss is None or self._wer is None: raise ValueError("`fuse_loss_wer` flag is set, but `loss` and `wer` modules were not provided! ") # If fused joint step is required, fused batch size is required as well if self._fused_batch_size is None: raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!") # When using fused joint step, both encoder and transcript lengths must be provided if (encoder_lengths is None) or (transcript_lengths is None): raise ValueError( "`fuse_loss_wer` is set, therefore encoder and target lengths " "must be provided as well!" ) losses = [] wers, wer_nums, wer_denoms = [], [], [] target_lengths = [] batch_size = int(encoder_outputs.size(0)) # actual batch size # Iterate over batch using fused_batch_size steps for batch_idx in range(0, batch_size, self._fused_batch_size): begin = batch_idx end = min(begin + self._fused_batch_size, batch_size) # Extract the sub batch inputs # sub_enc = encoder_outputs[begin:end, ...] # sub_transcripts = transcripts[begin:end, ...] 
sub_enc = encoder_outputs.narrow(dim=0, start=begin, length=int(end - begin)) sub_transcripts = transcripts.narrow(dim=0, start=begin, length=int(end - begin)) sub_enc_lens = encoder_lengths[begin:end] sub_transcript_lens = transcript_lengths[begin:end] # Sub transcripts does not need the full padding of the entire batch # Therefore reduce the decoder time steps to match max_sub_enc_length = sub_enc_lens.max() max_sub_transcript_length = sub_transcript_lens.max() if decoder_outputs is not None: # Reduce encoder length to preserve computation # Encoder: [sub-batch, T, D] -> [sub-batch, T', D]; T' < T if sub_enc.shape[1] != max_sub_enc_length: sub_enc = sub_enc.narrow(dim=1, start=0, length=int(max_sub_enc_length)) # sub_dec = decoder_outputs[begin:end, ...] # [sub-batch, U, D] sub_dec = decoder_outputs.narrow(dim=0, start=begin, length=int(end - begin)) # [sub-batch, U, D] # Reduce decoder length to preserve computation # Decoder: [sub-batch, U, D] -> [sub-batch, U', D]; U' < U if sub_dec.shape[1] != max_sub_transcript_length + 1: sub_dec = sub_dec.narrow(dim=1, start=0, length=int(max_sub_transcript_length + 1)) # Perform joint => [sub-batch, T', U', V + 1] sub_joint = self.joint(sub_enc, sub_dec) del sub_dec # Reduce transcript length to correct alignment # Transcript: [sub-batch, L] -> [sub-batch, L']; L' <= L if sub_transcripts.shape[1] != max_sub_transcript_length: sub_transcripts = sub_transcripts.narrow(dim=1, start=0, length=int(max_sub_transcript_length)) # Compute sub batch loss # preserve loss reduction type loss_reduction = self.loss.reduction # override loss reduction to sum self.loss.reduction = None # compute and preserve loss loss_batch = self.loss( log_probs=sub_joint, targets=sub_transcripts, input_lengths=sub_enc_lens, target_lengths=sub_transcript_lens, ) losses.append(loss_batch) target_lengths.append(sub_transcript_lens) # reset loss reduction type self.loss.reduction = loss_reduction else: losses = None # Update WER for sub batch if 
compute_wer: sub_enc = sub_enc.transpose(1, 2) # [B, T, D] -> [B, D, T] sub_enc = sub_enc.detach() sub_transcripts = sub_transcripts.detach() # Update WER on each process without syncing if self.training: original_sync = self.wer._to_sync self.wer._to_sync = False self.wer.update( predictions=sub_enc, predictions_lengths=sub_enc_lens, targets=sub_transcripts, targets_lengths=sub_transcript_lens, ) # Sync and all_reduce on all processes, compute global WER wer, wer_num, wer_denom = self.wer.compute() self.wer.reset() if self.training: self.wer._to_sync = original_sync wers.append(wer) wer_nums.append(wer_num) wer_denoms.append(wer_denom) del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens # Reduce over sub batches if losses is not None: losses = self.loss.reduce(losses, target_lengths) # Collect sub batch wer results if compute_wer: wer = sum(wers) / len(wers) wer_num = sum(wer_nums) wer_denom = sum(wer_denoms) else: wer = None wer_num = None wer_denom = None return losses, wer, wer_num, wer_denom
[docs] def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor: """ Project the encoder output to the joint hidden dimension. Args: encoder_output: A torch.Tensor of shape [B, T, D] Returns: A torch.Tensor of shape [B, T, H] """ return self.enc(encoder_output)
[docs] def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor: """ Project the Prediction Network (Decoder) output to the joint hidden dimension. Args: prednet_output: A torch.Tensor of shape [B, U, D] Returns: A torch.Tensor of shape [B, U, H] """ return self.pred(prednet_output)
[docs] def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor: r""" Compute the joint step of the network after projection. Here, B = Batch size T = Acoustic model timesteps U = Target sequence length H1, H2 = Hidden dimensions of the Encoder / Decoder respectively H = Hidden dimension of the Joint hidden step. V = Vocabulary size of the Decoder (excluding the RNNT blank token). NOTE: The implementation of this model is slightly modified from the original paper. The original paper proposes the following steps : (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- \*1 \*1 -> Forward through joint final [B, T, U, V + 1]. We instead split the joint hidden into joint_hidden_enc and joint_hidden_dec and act as follows: enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- \*1 dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- \*2 (\*1, \*2) -> Sum [B, T, U, H] -> Forward through joint final [B, T, U, V + 1]. Args: f: Output of the Encoder model. A torch.Tensor of shape [B, T, H1] g: Output of the Decoder model. A torch.Tensor of shape [B, U, H2] Returns: Logits / log softmaxed tensor of shape (B, T, U, V + 1). 
""" f = f.unsqueeze(dim=2) # (B, T, 1, H) g = g.unsqueeze(dim=1) # (B, 1, U, H) if self.training and self.masking_prob > 0: [B, _, U, _] = g.shape rand = torch.rand([B, 1, U, 1]).to(g.device) rand = torch.gt(rand, self.masking_prob) g = g * rand inp = f + g # [B, T, U, H] del f, g # Forward adapter modules on joint hidden if self.is_adapter_available(): inp = self.forward_enabled_adapters(inp) res = self.joint_net(inp) # [B, T, U, V + 1] del inp if self.preserve_memory: torch.cuda.empty_cache() # If log_softmax is automatic if self.log_softmax is None: if not res.is_cuda: # Use log softmax only if on CPU if self.temperature != 1.0: res = (res / self.temperature).log_softmax(dim=-1) else: res = res.log_softmax(dim=-1) else: if self.log_softmax: if self.temperature != 1.0: res = (res / self.temperature).log_softmax(dim=-1) else: res = res.log_softmax(dim=-1) return res
def _joint_net_modules(self, num_classes, pred_n_hidden, enc_n_hidden, joint_n_hidden, activation, dropout): """ Prepare the trainable modules of the Joint Network Args: num_classes: Number of output classes (vocab size) excluding the RNNT blank token. pred_n_hidden: Hidden size of the prediction network. enc_n_hidden: Hidden size of the encoder network. joint_n_hidden: Hidden size of the joint network. activation: Activation of the joint. Can be one of [relu, tanh, sigmoid] dropout: Dropout value to apply to joint. """ pred = torch.nn.Linear(pred_n_hidden, joint_n_hidden) enc = torch.nn.Linear(enc_n_hidden, joint_n_hidden) if activation not in ['relu', 'sigmoid', 'tanh']: raise ValueError("Unsupported activation for joint step - please pass one of " "[relu, sigmoid, tanh]") activation = activation.lower() if activation == 'relu': activation = torch.nn.ReLU(inplace=True) elif activation == 'sigmoid': activation = torch.nn.Sigmoid() elif activation == 'tanh': activation = torch.nn.Tanh() layers = ( [activation] + ([torch.nn.Dropout(p=dropout)] if dropout else []) + [torch.nn.Linear(joint_n_hidden, num_classes)] ) return pred, enc, torch.nn.Sequential(*layers) # Adapter method overrides
[docs] def add_adapter(self, name: str, cfg: DictConfig): # Update the config with correct input dim cfg = self._update_adapter_cfg_input_dim(cfg) # Add the adapter super().add_adapter(name=name, cfg=cfg)
def _update_adapter_cfg_input_dim(self, cfg: DictConfig): cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.joint_hidden) return cfg @property def num_classes_with_blank(self): return self._num_classes @property def num_extra_outputs(self): return self._num_extra_outputs @property def loss(self): return self._loss
[docs] def set_loss(self, loss): if not self._fuse_loss_wer: raise ValueError("Attempting to set loss module even though `fuse_loss_wer` is not set!") self._loss = loss
@property
def wer(self):
    """WER metric module currently stored in ``self._wer`` (may be None)."""
    metric = self._wer
    return metric
def set_wer(self, wer):
    """
    Attach the WER metric module used by the fused joint step.

    Args:
        wer: Metric module to store on ``self._wer``.

    Raises:
        ValueError: If ``self._fuse_loss_wer`` is falsy.
    """
    if self._fuse_loss_wer:
        self._wer = wer
        return

    raise ValueError("Attempting to set WER module even though `fuse_loss_wer` is not set!")
@property
def fuse_loss_wer(self):
    """Flag stored in ``self._fuse_loss_wer``."""
    is_fused = self._fuse_loss_wer
    return is_fused
def set_fuse_loss_wer(self, fuse_loss_wer, loss=None, metric=None):
    """
    Set the fused loss/WER flag together with its loss and metric modules.

    All three attributes are overwritten unconditionally, so omitting ``loss``
    or ``metric`` resets the corresponding attribute to None.

    Args:
        fuse_loss_wer: New value for ``self._fuse_loss_wer``.
        loss: New value for ``self._loss``.
        metric: New value for ``self._wer``.
    """
    self._fuse_loss_wer, self._loss, self._wer = fuse_loss_wer, loss, metric
@property
def fused_batch_size(self):
    """Sub-batch size stored in ``self._fused_batch_size`` (may be None)."""
    sub_batch = self._fused_batch_size
    return sub_batch
def set_fused_batch_size(self, fused_batch_size):
    """
    Store the sub-batch size used by the fused joint step.

    Args:
        fused_batch_size: New value for ``self._fused_batch_size``; stored as-is without validation.
    """
    setattr(self, "_fused_batch_size", fused_batch_size)
class RNNTDecoderJoint(torch.nn.Module, Exportable):
    """
    Utility class to export Decoder+Joint as a single module.

    The wrapped decoder is called with the targets and the two input states; its first
    output is fed, together with the encoder outputs, into the wrapped joint.
    """

    def __init__(self, decoder, joint):
        super().__init__()
        self.decoder = decoder
        self.joint = joint

    @property
    def input_types(self):
        """Neural types of the arguments accepted by ``forward``."""
        state_type = NeuralType(('D', 'B', 'D'), ElementType())
        return {
            'encoder_outputs': NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            "targets": NeuralType(('B', 'T'), LabelsType()),
            "target_length": NeuralType(tuple('B'), LengthsType()),
            'input_states_1': state_type,
            'input_states_2': state_type,
        }

    def input_example(self, max_batch=1, max_dim=1):
        """
        Build an example input tuple: the first element of the joint's example, followed by the
        first two elements of the decoder's example and the two decoder states.
        """
        dec_example = self.decoder.input_example(max_batch=max_batch, max_dim=max_dim)
        first_state, second_state = dec_example[-1]
        encoder_example = (self.joint.input_example()[0],)
        return encoder_example + dec_example[:2] + (first_state, second_state)

    @property
    def output_types(self):
        """Neural types of the values returned by ``forward``."""
        return {
            "outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
            "prednet_lengths": NeuralType(tuple('B'), LengthsType()),
            "output_states_1": NeuralType(('D', 'B', 'D'), ElementType()),
            "output_states_2": NeuralType(('D', 'B', 'D'), ElementType()),
        }

    def forward(self, encoder_outputs, targets, target_length, input_states_1, input_states_2):
        """
        Run the decoder on the targets and states, then the joint on the encoder outputs and
        the decoder output.

        Returns:
            Tuple ``(joint_output, decoder_length, output_state_1, output_state_2)``.
        """
        dec_result = self.decoder(targets, target_length, (input_states_1, input_states_2))
        dec_hidden, dec_length, dec_states = dec_result[0], dec_result[1], dec_result[2]
        out_state_1, out_state_2 = dec_states[0], dec_states[1]
        joint_output = self.joint(encoder_outputs, dec_hidden)
        return (joint_output, dec_length, out_state_1, out_state_2)


class RNNTDecoderJointSSL(torch.nn.Module):
    """
    Wrapper that runs a decoder followed by a joint and returns only the joint output.
    Unlike ``RNNTDecoderJoint`` it takes no explicit decoder states.
    """

    def __init__(self, decoder, joint):
        super().__init__()
        self.decoder = decoder
        self.joint = joint

    @property
    def needs_labels(self):
        """Always True: ``forward`` requires targets."""
        return True

    @property
    def input_types(self):
        """Neural types of the arguments accepted by ``forward``."""
        return {
            "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
            "targets": NeuralType(('B', 'T'), LabelsType()),
            "target_lengths": NeuralType(tuple('B'), LengthsType()),
        }

    @property
    def output_types(self):
        """Neural types of the value returned by ``forward``."""
        return {"log_probs": NeuralType(('B', 'T', 'D'), SpectrogramType())}

    def forward(self, encoder_output, targets, target_lengths):
        """
        Run the decoder on the targets, then the joint on the encoder output and decoder output.

        The decoder must return exactly three values; only the first one is used.
        """
        dec_hidden, _, _ = self.decoder(targets=targets, target_length=target_lengths)
        return self.joint(encoder_outputs=encoder_output, decoder_outputs=dec_hidden)
class SampledRNNTJoint(RNNTJoint):
    """
    A Sampled Recurrent Neural Network Transducer Joint Network (RNN-T Joint Network).

    An RNN-T Joint network, comprised of a feedforward model, where the vocabulary is sampled
    instead of computing the joint over the full vocabulary.

    Args:
        jointnet: A dict-like object which contains the following key-value pairs.
            encoder_hidden: int specifying the hidden dimension of the encoder net.
            pred_hidden: int specifying the hidden dimension of the prediction net.
            joint_hidden: int specifying the hidden dimension of the joint net.
            activation: Activation function used in the joint step. Can be one of
                ['relu', 'tanh', 'sigmoid'].

            Optionally, it may also contain the following:
            dropout: float, set to 0.0 by default. Optional dropout applied at the end of the
                joint net.

        num_classes: int, specifying the vocabulary size that the joint network must predict,
            excluding the RNNT blank token.

        n_samples: int, specifies the number of tokens to sample from the vocabulary space,
            excluding the RNNT blank token. If a given value is larger than the entire vocabulary
            size, then the full vocabulary will be used.

        vocabulary: Optional list of strings/tokens that comprise the vocabulary of the joint
            network. Unused and kept only for easy access for character based encoding RNNT models.

        log_softmax: Optional bool, set to None by default. If set as None, will compute the
            log_softmax() based on the value provided.

        preserve_memory: Optional bool, set to False by default. If the model crashes due to the
            memory intensive joint step, one might try this flag to empty the tensor cache in
            pytorch.

            Warning: This will make the forward-backward pass much slower than normal.
            It also might not fix the OOM if the GPU simply does not have enough memory to compute
            the joint.

        fuse_loss_wer: Optional bool, set to False by default.

            Fuses the joint forward, loss forward and wer forward steps. In doing so, it trades
            off speed for memory conservation by creating sub-batches of the provided batch of
            inputs, and performs Joint forward, loss forward and wer forward (optional), all on
            sub-batches, then collates results to be exactly equal to results from the entire
            batch.

            When this flag is set, prior to calling forward, the fields `loss` and `wer`
            (either one) *must* be set using the `RNNTJoint.set_loss()` or `RNNTJoint.set_wer()`
            methods.

            Further, when this flag is set, the following argument `fused_batch_size` *must* be
            provided as a non negative integer. This value refers to the size of the sub-batch.

            When the flag is set, the input and output signature of `forward()` of this method
            changes.

            Input - in addition to `encoder_outputs` (mandatory argument), the following arguments
            can be provided.

                - decoder_outputs (optional). Required if loss computation is required.
                - encoder_lengths (required)
                - transcripts (optional). Required for wer calculation.
                - transcript_lengths (optional). Required for wer calculation.
                - compute_wer (bool, default false). Whether to compute WER or not for the fused
                  batch.

            Output - instead of the usual `joint` log prob tensor, the following results can be
            returned.

                - loss (optional). Returned if decoder_outputs, transcripts and
                  transcript_lengths are not None.
                - wer_numerator + wer_denominator (optional). Returned if transcripts,
                  transcripts_lengths are provided and compute_wer is set.

        fused_batch_size: Optional int, required if `fuse_loss_wer` flag is set. Determines the
            size of the sub-batches. Should be any value below the actual batch size per GPU.
    """

    def __init__(
        self,
        jointnet: Dict[str, Any],
        num_classes: int,
        n_samples: int,
        vocabulary: Optional[List] = None,
        log_softmax: Optional[bool] = None,
        preserve_memory: bool = False,
        fuse_loss_wer: bool = False,
        fused_batch_size: Optional[int] = None,
    ):
        # Everything except `n_samples` is forwarded untouched to the parent joint.
        parent_kwargs = dict(
            jointnet=jointnet,
            num_classes=num_classes,
            vocabulary=vocabulary,
            log_softmax=log_softmax,
            preserve_memory=preserve_memory,
            fuse_loss_wer=fuse_loss_wer,
            fused_batch_size=fused_batch_size,
        )
        super().__init__(**parent_kwargs)

        self.n_samples = n_samples

        # The blank id is the last index of the output space; it is kept as a
        # non-persistent buffer so it is not written to the state dict.
        blank_index = self.num_classes_with_blank - 1
        self.register_buffer('blank_id', torch.tensor([blank_index]), persistent=False)
@typecheck()
def forward(
    self,
    encoder_outputs: torch.Tensor,
    decoder_outputs: Optional[torch.Tensor],
    encoder_lengths: Optional[torch.Tensor] = None,
    transcripts: Optional[torch.Tensor] = None,
    transcript_lengths: Optional[torch.Tensor] = None,
    compute_wer: bool = False,
) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
    """
    Fused sampled-joint / loss / WER forward pass over sub-batches.

    Outside of training (module in eval mode, gradients disabled, or inference mode enabled)
    this simply delegates to the parent class ``forward``.

    Args:
        encoder_outputs: Encoder output, (B, D, T) per the inline comments; transposed to (B, T, D).
        decoder_outputs: Decoder output, (B, D, U) or None. When None no loss is computed.
        encoder_lengths: Lengths of the encoder outputs. Required on the sampled path.
        transcripts: Target token ids. Required on the sampled path.
        transcript_lengths: Lengths of the transcripts. Required on the sampled path.
        compute_wer: Whether to update and compute the WER metric for every sub-batch.

    Returns:
        On the sampled path a tuple ``(losses, wer, wer_num, wer_denom)``; ``losses`` is None when
        ``decoder_outputs`` is None and the WER entries are None when ``compute_wer`` is False.
        Otherwise whatever the parent class ``forward`` returns.

    Raises:
        ValueError: If transcripts / lengths are missing, if the loss or WER module is not set,
            or if ``fused_batch_size`` is None.
    """
    # If in inference mode, revert to basic RNNT Joint behaviour.
    # Sampled RNNT is only used for training.
    # `self.training` is checked as well because `sampled_joint` returns a single full-vocabulary
    # tensor (not a `(joint, remapped_transcript)` pair) when the module is in eval mode,
    # which the sampled path below cannot consume.
    if not self.training or not torch.is_grad_enabled() or torch.is_inference_mode_enabled():
        # Simply call full tensor joint
        return super().forward(
            encoder_outputs=encoder_outputs,
            decoder_outputs=decoder_outputs,
            encoder_lengths=encoder_lengths,
            transcripts=transcripts,
            transcript_lengths=transcript_lengths,
            compute_wer=compute_wer,
        )

    if transcripts is None or transcript_lengths is None:
        logging.warning(
            "Sampled RNNT Joint currently only works with `fuse_loss_wer` set to True, "
            "and when `fused_batch_size` is a positive integer."
        )
        raise ValueError(
            "Sampled RNNT loss only works when the transcripts are provided during training."
            "Please ensure that you correctly pass the `transcripts` and `transcript_lengths`."
        )

    # encoder = (B, D, T)
    # decoder = (B, D, U) if passed, else None
    encoder_outputs = encoder_outputs.transpose(1, 2)  # (B, T, D)

    if decoder_outputs is not None:
        decoder_outputs = decoder_outputs.transpose(1, 2)  # (B, U, D)

    # Both the loss and the WER module must be supplied during fused joint
    if self._loss is None or self._wer is None:
        raise ValueError("`fuse_loss_wer` flag is set, but `loss` and `wer` modules were not provided! ")

    # If fused joint step is required, fused batch size is required as well
    if self._fused_batch_size is None:
        raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!")

    # When using fused joint step, both encoder and transcript lengths must be provided
    if (encoder_lengths is None) or (transcript_lengths is None):
        raise ValueError(
            "`fuse_loss_wer` is set, therefore encoder and target lengths " "must be provided as well!"
        )

    losses = []
    wers, wer_nums, wer_denoms = [], [], []
    target_lengths = []
    batch_size = int(encoder_outputs.size(0))  # actual batch size

    # Iterate over batch using fused_batch_size steps
    for batch_idx in range(0, batch_size, self._fused_batch_size):
        begin = batch_idx
        end = min(begin + self._fused_batch_size, batch_size)
        sub_size = int(end - begin)

        # Extract the sub batch inputs
        sub_enc = encoder_outputs.narrow(dim=0, start=begin, length=sub_size)
        sub_transcripts = transcripts.narrow(dim=0, start=begin, length=sub_size)

        sub_enc_lens = encoder_lengths[begin:end]
        sub_transcript_lens = transcript_lengths[begin:end]

        # Sub transcripts does not need the full padding of the entire batch
        # Therefore reduce the decoder time steps to match
        max_sub_enc_length = sub_enc_lens.max()
        max_sub_transcript_length = sub_transcript_lens.max()

        if decoder_outputs is not None:
            # Reduce encoder length to preserve computation
            # Encoder: [sub-batch, T, D] -> [sub-batch, T', D]; T' < T
            if sub_enc.shape[1] != max_sub_enc_length:
                sub_enc = sub_enc.narrow(dim=1, start=0, length=int(max_sub_enc_length))

            sub_dec = decoder_outputs.narrow(dim=0, start=begin, length=sub_size)  # [sub-batch, U, D]

            # Reduce decoder length to preserve computation
            # Decoder: [sub-batch, U, D] -> [sub-batch, U', D]; U' < U
            if sub_dec.shape[1] != max_sub_transcript_length + 1:
                sub_dec = sub_dec.narrow(dim=1, start=0, length=int(max_sub_transcript_length + 1))

            # Reduce transcript length to correct alignment
            # Transcript: [sub-batch, L] -> [sub-batch, L']; L' <= L
            if sub_transcripts.shape[1] != max_sub_transcript_length:
                sub_transcripts = sub_transcripts.narrow(dim=1, start=0, length=int(max_sub_transcript_length))

            # Perform sampled joint => [sub-batch, T', U', {V' < V} + 1}]
            sub_joint, sub_transcripts_remapped = self.sampled_joint(
                sub_enc, sub_dec, transcript=sub_transcripts, transcript_lengths=sub_transcript_lens
            )

            del sub_dec

            # Compute sub batch loss
            # preserve loss reduction type and the loss' own blank index
            loss_reduction = self.loss.reduction
            cached_blank_id = self.loss._loss.blank

            # override loss reduction so that per-sample losses are returned
            self.loss.reduction = None

            # override blank idx in order to map to new vocabulary space
            # in the new vocabulary space, we set the mapping of the RNNT Blank from index V+1 to 0
            # So the loss here needs to be updated accordingly.
            # TODO: See if we can have some formal API for rnnt loss to update inner blank index.
            self.loss._loss.blank = 0

            try:
                # compute and preserve loss
                loss_batch = self.loss(
                    log_probs=sub_joint,
                    targets=sub_transcripts_remapped,  # Note: We have to use remapped transcripts here !
                    input_lengths=sub_enc_lens,
                    target_lengths=sub_transcript_lens,  # Note: Even after remap, the lengths remain intact.
                )
            finally:
                # Always restore the shared loss module, even if the loss call raised;
                # otherwise later (e.g. validation) calls would run with blank=0 / no reduction.
                self.loss.reduction = loss_reduction
                self.loss._loss.blank = cached_blank_id

            losses.append(loss_batch)
            target_lengths.append(sub_transcript_lens)

        else:
            losses = None

        # Update WER for sub batch
        if compute_wer:
            sub_enc = sub_enc.transpose(1, 2)  # [B, T, D] -> [B, D, T]
            sub_enc = sub_enc.detach()
            sub_transcripts = sub_transcripts.detach()

            # Update WER on each process without syncing
            self.wer.update(
                predictions=sub_enc,
                predictions_lengths=sub_enc_lens,
                targets=sub_transcripts,
                targets_lengths=sub_transcript_lens,
            )

            # Sync and all_reduce on all processes, compute global WER
            wer, wer_num, wer_denom = self.wer.compute()
            self.wer.reset()

            wers.append(wer)
            wer_nums.append(wer_num)
            wer_denoms.append(wer_denom)

        del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens

    # Reduce over sub batches
    if losses is not None:
        losses = self.loss.reduce(losses, target_lengths)

    # Collect sub batch wer results
    if compute_wer:
        wer = sum(wers) / len(wers)
        wer_num = sum(wer_nums)
        wer_denom = sum(wer_denoms)
    else:
        wer = None
        wer_num = None
        wer_denom = None

    return losses, wer, wer_num, wer_denom
def sampled_joint(
    self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    r"""
    Compute the sampled joint step of the network.

    Reference: `Memory-Efficient Training of RNN-Transducer with Sampled Softmax
    <https://arxiv.org/abs/2203.16868>`__.

    Here,
    B = Batch size
    T = Acoustic model timesteps
    U = Target sequence length
    H1, H2 = Hidden dimensions of the Encoder / Decoder respectively
    H = Hidden dimension of the Joint hidden step.
    V = Vocabulary size of the Decoder (excluding the RNNT blank token).
    S = Sample size of vocabulary.

    NOTE:
        The implementation of this joint model is slightly modified from the original paper.
        The original paper proposes the following steps :
        (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden
        [B, T, U, H] -- \*1
        \*1 -> Forward through joint final [B, T, U, V + 1].

        We instead split the joint hidden into joint_hidden_enc and joint_hidden_dec and act as follows:
        enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- \*1
        dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- \*2
        (\*1, \*2) -> Sum [B, T, U, H] -> Sample Vocab V_Pos (for target tokens) and V_Neg ->
        (V_Neg is sampled not uniformly but as a rand permutation of all vocab tokens, then eliminate
        all Intersection(V_Pos, V_Neg) common tokens to avoid duplication of loss) ->
        Concat new Vocab V_Sampled = Union(V_Pos, V_Neg)
        -> Forward partially through the joint final to create [B, T, U, V_Sampled]

    Args:
        f: Output of the Encoder model. A torch.Tensor of shape [B, T, H1]
        g: Output of the Decoder model. A torch.Tensor of shape [B, U, H2]
        transcript: Batch of transcripts. A torch.Tensor of shape [B, U]
        transcript_lengths: Batch of lengths of the transcripts. A torch.Tensor of shape [B].
            Not read by this method.

    Returns:
        In training mode with gradients enabled: a tuple ``(joint, remapped_transcript)`` where
        ``joint`` is the logits / log softmaxed tensor of shape (B, T, U, V_Sampled) with the
        RNNT blank at index 0, and ``remapped_transcript`` has the shape of ``transcript`` with
        every token id replaced by its column index inside ``joint``.

        Otherwise (eval mode, gradients disabled or inference mode): the single tensor returned by
        the parent class ``joint``.
    """
    # If under inference mode, ignore sampled joint and compute full joint.
    if self.training is False or torch.is_grad_enabled() is False or torch.is_inference_mode_enabled():
        # Simply call full tensor joint
        return super().joint(f=f, g=g)

    # Compute sampled softmax
    # f = [B, T, H1]
    f = self.enc(f)
    f.unsqueeze_(dim=2)  # (B, T, 1, H)

    # g = [B, U, H2]
    g = self.pred(g)
    g.unsqueeze_(dim=1)  # (B, 1, U, H)

    inp = f + g  # [B, T, U, H]

    del f, g

    # Forward adapter modules on joint hidden
    if self.is_adapter_available():
        inp = self.forward_enabled_adapters(inp)

    # Do partial forward of joint net (skipping the final linear)
    for module in self.joint_net[:-1]:
        inp = module(inp)  # [B, T, U, H]

    # Begin compute of sampled RNNT joint
    with torch.no_grad():
        # Gather the true labels together with, for every transcript element, the position of
        # its id inside the sorted unique set.
        unique_ids, inverse = torch.unique(transcript, return_inverse=True)

        # augment with blank token id (blank occupies index 0 of the sampled vocabulary)
        transcript_vocab_ids = torch.cat([self.blank_id, unique_ids])

        # Remap the transcript label ids to new positions of label ids (in the transcript_vocab_ids)
        # This is necessary cause the RNNT loss doesnt care about the value, only the position of
        # the ids of the transcript tokens. We can skip this step for noise samples cause those are
        # only used for softmax estimation, not for computing actual label.
        #
        # For example, if original transcript tokens were [2, 1, 4, 5, 4, 1]
        # the unique token set is [1, 2, 4, 5] and the inverse indices are [1, 0, 2, 3, 2, 0].
        # With the blank prepended, transcript_vocab_ids = [blank, 1, 2, 4, 5], so every index is
        # shifted by one: new_transcript = [2, 1, 3, 4, 3, 1].
        #
        # NOTE: a binary search (`torch.bucketize`) over `transcript_vocab_ids` must not be used
        # here: the prepended blank id is the largest value, so that tensor is not sorted and
        # the smallest transcript id would be mapped to index 0, i.e. to the blank.
        transcript = inverse.reshape(transcript.shape) + 1

    # Extract out partial weight tensor and bias tensor of just the V_Pos vocabulary from the full joint.
    true_weights = self.joint_net[-1].weight[transcript_vocab_ids, :]
    true_bias = self.joint_net[-1].bias[transcript_vocab_ids]

    # Compute the transcript joint scores (only of vocab V_Pos)
    transcript_scores = torch.matmul(inp, true_weights.transpose(0, 1)) + true_bias

    # Construct acceptance criteria in vocab space, reject all tokens in Intersection(V_Pos, V_Neg)
    with torch.no_grad():
        # Instead of uniform sample, first we create arange V (ignoring blank), then randomly shuffle
        # this range of ids, then subset `n_samples` amount of vocab tokens out of the permuted tensor.
        # This is good because it guarantees that no token will ever be repeated in V_Neg;
        # which dramatically complicates loss calculation.
        # Further more, with this strategy, given a `n_samples` > V + 1; we are guaranteed to get the
        # V_Samples = V (i.e., full vocabulary will be used in such a case).
        # Useful to debug cases where you expect sampled vocab to get exact same training curve as
        # full vocab.
        sample_ids = torch.randperm(n=self.num_classes_with_blank - 1, device=transcript_scores.device)[
            : self.n_samples
        ]

        # We need to compute the intersection(V_Pos, V_Neg), then eliminate the intersection arguments
        # from inside V_Neg.

        # First, compute the pairwise commonality to find index inside `sample_ids` which match the
        # token id inside transcript_vocab_ids.
        # Note: It is important to ignore the hardcoded RNNT Blank token injected at id 0 of the
        # transcript vocab ids, otherwise the blank may occur twice, once for RNNT blank and once as
        # negative sample, doubling the gradient of the RNNT blank token.
        reject_samples = torch.where(transcript_vocab_ids[1:, None] == sample_ids[None, :])

        # Let accept samples be a set of ids which is a subset of sample_ids
        # such that intersection(V_Pos, accept_samples) is a null set.
        accept_samples = sample_ids.clone()

        # In order to construct such an accept_samples tensor, first we construct a bool map
        # and fill all the indices where there is a match inside of sample_ids.
        # reject_samples is a tuple (transcript_vocab_position, sample_position) which gives a
        # many to many map between N values of transcript and M values of sample_ids.
        # We dont care about transcript side matches, only the ids inside of sample_ids that matched.
        sample_mask = torch.ones_like(accept_samples, dtype=torch.bool)
        sample_mask[reject_samples[1]] = False

        # Finally, compute the subset of tokens by selecting only those sample_ids which had no matches
        accept_samples = accept_samples[sample_mask]

    # Extract out partial weight tensor and bias tensor of just the V_Neg vocabulary from the full joint.
    sample_weights = self.joint_net[-1].weight[accept_samples, :]
    sample_bias = self.joint_net[-1].bias[accept_samples]

    # Compute the noise joint scores (only of vocab V_Neg) to be used for softmax
    # The quality of this sample determines the quality of the softmax gradient.
    # We use naive algo broadcasted over batch, but it is more efficient than sample level computation.
    # One can increase `n_samples` for better estimation of rejection samples and its gradient.
    noise_scores = torch.matmul(inp, sample_weights.transpose(0, 1)) + sample_bias

    # Finally, construct the sampled joint as the V_Sampled = Union(V_Pos, V_Neg)
    # Here, we simply concatenate the two tensors to construct the joint with V_Sampled vocab
    # because before we have properly asserted that Intersection(V_Pos, V_Neg) is a null set.
    res = torch.cat([transcript_scores, noise_scores], dim=-1)

    del inp

    if self.preserve_memory:
        torch.cuda.empty_cache()

    # If log_softmax is automatic
    if self.log_softmax is None:
        if not res.is_cuda:  # Use log softmax only if on CPU
            res = res.log_softmax(dim=-1)
    else:
        if self.log_softmax:
            res = res.log_softmax(dim=-1)

    return res, transcript
# Register every adapter-compatible module that has no adapter entry in the registry yet.
for cls in (RNNTDecoder, RNNTJoint, SampledRNNTJoint):
    if adapter_mixins.get_registered_adapter(cls) is not None:
        continue
    # The base class is adapter compatible itself, so it is registered as its own adapter class.
    adapter_mixins.register_adapter(cls, cls)