# Source code for sionna.nr.tb_decoder

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Transport block decoding functions for the 5G NR sub-package of Sionna.
"""

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.crc import CRCDecoder
from sionna.fec.scrambling import  Descrambler
from sionna.fec.ldpc import LDPC5GDecoder
from sionna.nr import TBEncoder

class TBDecoder(Layer):
    # pylint: disable=line-too-long
    r"""TBDecoder(encoder, num_bp_iter=20, cn_type="boxplus-phi", output_dtype=tf.float32, **kwargs)

    5G NR transport block (TB) decoder as defined in TS 38.214 [3GPP38214]_.

    The transport block decoder takes as input a sequence of noisy channel
    observations and reconstructs the corresponding `transport block` of
    information bits. The detailed procedure is described in TS 38.214
    [3GPP38214]_ and TS 38.211 [3GPP38211]_.

    The class inherits from the Keras layer class and can be used as layer
    in a Keras model.

    Parameters
    ----------
    encoder : :class:`~sionna.nr.TBEncoder`
        Associated transport block encoder used for encoding of the signal.

    num_bp_iter : int, 20 (default)
        Number of BP decoder iterations.

    cn_type : str, "boxplus-phi" (default) | "boxplus" | "minsum"
        The check node processing function of the LDPC BP decoder.
        One of {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where
        '"boxplus"' implements the single-parity-check APP decoding rule.
        '"boxplus-phi"' implements the numerically more stable version of
        boxplus [Ryan]_. '"minsum"' implements the min-approximation of the
        CN update rule [Ryan]_.

    output_dtype : tf.float32 (default)
        Defines the datatype for internal calculations and the output dtype.

    Input
    -----
    inputs : [...,num_coded_bits], tf.float
        2+D tensor containing channel logits/llr values of the (noisy)
        channel observations.

    Output
    ------
    b_hat : [...,target_tb_size], tf.float
        2+D tensor containing hard decided bit estimates of all information
        bits of the transport block.

    tb_crc_status : [...], tf.bool
        Transport block CRC status indicating if a transport block was
        (most likely) correctly recovered. Note that false positives are
        possible.
    """

    def __init__(self,
                 encoder,
                 num_bp_iter=20,
                 cn_type="boxplus-phi",
                 output_dtype=tf.float32,
                 **kwargs):

        super().__init__(dtype=output_dtype, **kwargs)

        assert output_dtype in (tf.float16, tf.float32, tf.float64), \
            "output_dtype must be (tf.float16, tf.float32, tf.float64)."

        assert isinstance(encoder, TBEncoder), "encoder must be TBEncoder."
        self._tb_encoder = encoder

        self._num_cbs = encoder.num_cbs

        # init BP decoder; return_infobits=True strips parity bits so the
        # decoder output contains only (CRC-protected) information bits
        self._decoder = LDPC5GDecoder(encoder=encoder.ldpc_encoder,
                                      num_iter=num_bp_iter,
                                      cn_type=cn_type,
                                      hard_out=True,  # TB operates on bit-level
                                      return_infobits=True,
                                      output_dtype=output_dtype)

        # init descrambler (only if the encoder applied scrambling)
        if encoder.scrambler is not None:
            self._descrambler = Descrambler(encoder.scrambler, binary=False)
        else:
            self._descrambler = None

        # init CRC decoders for CB and TB levels
        self._tb_crc_decoder = CRCDecoder(encoder.tb_crc_encoder)

        if encoder.cb_crc_encoder is not None:
            self._cb_crc_decoder = CRCDecoder(encoder.cb_crc_encoder)
        else:
            self._cb_crc_decoder = None

    #########################################
    # Public methods and properties
    #########################################

    @property
    def tb_size(self):
        """Number of information bits per TB."""
        return self._tb_encoder.tb_size

    @property
    def k(self):
        """Number of input information bits. Equals TB size."""
        return self._tb_encoder.tb_size

    @property
    def n(self):
        """Total number of output codeword bits."""
        return self._tb_encoder.n

    #########################
    # Keras layer functions
    #########################

    def build(self, input_shapes):
        """Test input shapes for consistency."""
        assert input_shapes[-1] == self.n, \
            f"Invalid input shape. Expected input length is {self.n}."

    def call(self, inputs):
        """Apply transport block decoding.

        Undoes the encoder chain in reverse order: descrambling,
        de-interleaving/de-puncturing, CB de-concatenation, LDPC decoding,
        CB CRC removal, CB de-segmentation and TB CRC removal.
        """

        # store shapes to restore the caller's leading dimensions at the end
        input_shape = inputs.shape.as_list()
        llr_ch = tf.cast(inputs, tf.float32)
        llr_ch = tf.reshape(llr_ch,
                            (-1, self._tb_encoder.num_tx, self._tb_encoder.n))

        # undo scrambling (only if scrambler was used)
        if self._descrambler is not None:
            llr_scr = self._descrambler(llr_ch)
        else:
            llr_scr = llr_ch

        # undo CB interleaving and puncturing; punctured positions receive
        # zero LLRs (i.e. "no information") before inverting the permutation
        num_fillers = (self._tb_encoder.ldpc_encoder.n
                       * self._tb_encoder.num_cbs
                       - np.sum(self._tb_encoder.cw_lengths))
        llr_int = tf.concat(
            [llr_scr,
             tf.zeros([tf.shape(llr_scr)[0],
                       self._tb_encoder.num_tx,
                       num_fillers])],
            axis=-1)
        llr_int = tf.gather(llr_int,
                            self._tb_encoder.output_perm_inv,
                            axis=-1)

        # undo CB concatenation: one codeword per code block
        llr_cb = tf.reshape(llr_int,
                            (-1, self._tb_encoder.num_tx, self._num_cbs,
                             self._tb_encoder.ldpc_encoder.n))

        # LDPC decoding (hard output, info bits only)
        u_hat_cb = self._decoder(llr_cb)

        # CB CRC removal (if relevant)
        if self._cb_crc_decoder is not None:
            # we are ignoring the CB CRC status for the moment;
            # could be combined with the TB CRC for even better estimates
            u_hat_cb_crc, _ = self._cb_crc_decoder(u_hat_cb)
        else:
            u_hat_cb_crc = u_hat_cb

        # undo CB segmentation: merge code blocks back into one TB
        u_hat_tb = tf.reshape(
            u_hat_cb_crc,
            (-1, self._tb_encoder.num_tx,
             self.tb_size + self._tb_encoder.tb_crc_encoder.crc_length))

        # TB CRC removal; tb_crc_status flags (likely) correct recovery
        u_hat, tb_crc_status = self._tb_crc_decoder(u_hat_tb)

        # restore input shape
        output_shape = input_shape
        output_shape[0] = -1
        output_shape[-1] = self.tb_size
        u_hat = tf.reshape(u_hat, output_shape)

        # also apply to tb_crc_status, but last dim is 1
        output_shape[-1] = 1
        tb_crc_status = tf.reshape(tb_crc_status, output_shape)

        # remove if zero-padding was applied
        if self._tb_encoder.k_padding > 0:
            u_hat = u_hat[..., :-self._tb_encoder.k_padding]

        # cast to output dtype
        u_hat = tf.cast(u_hat, self.dtype)
        tb_crc_status = tf.squeeze(tf.cast(tb_crc_status, tf.bool), axis=-1)

        return u_hat, tb_crc_status