#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Transport block encoding functions for the 5g NR sub-package of Sionna.
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.crc import CRCEncoder
from sionna.fec.scrambling import TB5GScrambler
from sionna.fec.ldpc import LDPC5GEncoder
from sionna.nr.utils import calculate_tb_size
class TBEncoder(Layer):
# pylint: disable=line-too-long
r"""TBEncoder(target_tb_size,num_coded_bits,target_coderate,num_bits_per_symbol,num_layers=1,n_rnti=1,n_id=1,channel_type="PUSCH",codeword_index=0,use_scrambler=True,verbose=False,output_dtype=tf.float32,, **kwargs)
5G NR transport block (TB) encoder as defined in TS 38.214
[3GPP38214]_ and TS 38.211 [3GPP38211]_.

The transport block (TB) encoder takes as input a `transport block` of
information bits and generates a sequence of codewords for transmission.
For this, the information bit sequence is segmented into multiple codewords,
protected by additional CRC checks and FEC encoded. Further, interleaving
and scrambling is applied before a codeword concatenation generates the
final bit sequence. Fig. 1 provides an overview of the TB encoding
procedure and we refer the interested reader to [3GPP38214]_ and
[3GPP38211]_ for further details.

.. figure:: ../figures/tb_encoding.png

Fig. 1: Overview of the TB encoding procedure (the CB CRC is not always applied).

If ``n_rnti`` and ``n_id`` are given as lists, the TBEncoder encodes
`num_tx = len(` ``n_rnti`` `)` parallel input streams with a different
scrambling sequence per user.

The class inherits from the Keras layer class and can be used as a layer
in a Keras model.

Parameters
----------
target_tb_size: int
Target transport block size, i.e., how many information bits are
encoded into the TB. Note that the effective TB size can be
slightly different due to quantization. If required, zero padding
is internally applied.
num_coded_bits: int
Number of coded bits after TB encoding.
target_coderate : float
Target coderate.
num_bits_per_symbol: int
Modulation order, i.e., number of bits per QAM symbol.
num_layers: int, 1 (default) | [1,...,8]
Number of transmission layers.
n_rnti: int or list of ints, 1 (default) | [0,...,65535]
RNTI identifier provided by higher layer. Defaults to 1 and must be
in range `[0, 65535]`. Defines a part of the random seed of the
scrambler. If provided as list, every list entry defines the RNTI
of an independent input stream.
n_id: int or list of ints, 1 (default) | [0,...,1023]
Data scrambling ID :math:`n_\text{ID}` related to cell id and
provided by higher layer.
Defaults to 1 and must be in range `[0, 1023]`. If provided as
list, every list entry defines the scrambling id of an independent
input stream.
channel_type: str, "PUSCH" (default) | "PDSCH"
Can be either "PUSCH" or "PDSCH".
codeword_index: int, 0 (default) | 1
Scrambler can be configured for two codeword transmission.
``codeword_index`` can be either 0 or 1. Must be 0 for
``channel_type`` = "PUSCH".
use_scrambler: bool, True (default)
If False, no data scrambling is applied (non standard-compliant).
verbose: bool, False (default)
If `True`, additional parameters are printed during initialization.
output_dtype: tf.DType, tf.float32 (default)
Defines the datatype for internal calculations and the output dtype.

Input
-----
inputs: [...,target_tb_size] or [...,num_tx,target_tb_size], tf.float
2+D tensor containing the information bits to be encoded. If
``n_rnti`` and ``n_id`` are a list of size `num_tx`, the input must
be of shape `[...,num_tx,target_tb_size]`.

Output
------
: [...,num_coded_bits], tf.float
2+D tensor containing the sequence of the encoded codeword bits of
the transport block.

Note
----
The parameters ``tb_size`` and ``num_coded_bits`` can be derived by the
:meth:`~sionna.nr.calculate_tb_size` function or
by accessing the corresponding :class:`~sionna.nr.PUSCHConfig` attributes.
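
Example
-------
A minimal usage sketch; the parameter values below are illustrative only
and would typically be taken from the corresponding
:class:`~sionna.nr.PUSCHConfig` attributes:

>>> encoder = TBEncoder(target_tb_size=2856,
...                     num_coded_bits=4224,
...                     target_coderate=2856/4224,
...                     num_bits_per_symbol=4)
>>> u = tf.zeros([16, encoder.k]) # batch of 16 all-zero TBs
>>> c = encoder(u) # shape [16, encoder.n]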
"""
def __init__(self,
target_tb_size,
num_coded_bits,
target_coderate,
num_bits_per_symbol,
num_layers=1,
n_rnti=1,
n_id=1,
channel_type="PUSCH",
codeword_index=0,
use_scrambler=True,
verbose=False,
output_dtype=tf.float32,
**kwargs):
super().__init__(dtype=output_dtype, **kwargs)
assert isinstance(use_scrambler, bool), \
"use_scrambler must be bool."
self._use_scrambler = use_scrambler
assert isinstance(verbose, bool), \
"verbose must be bool."
self._verbose = verbose
# check input for consistency
assert channel_type in ("PDSCH", "PUSCH"), \
"Unsupported channel_type."
self._channel_type = channel_type
assert(target_tb_size%1==0), "target_tb_size must be int."
self._target_tb_size = int(target_tb_size)
assert(num_coded_bits%1==0), "num_coded_bits must be int."
self._num_coded_bits = int(num_coded_bits)
assert(0.<target_coderate<=948/1024), \
"target_coderate must be in range (0, 948/1024]."
self._target_coderate = target_coderate
assert(num_bits_per_symbol%1==0), "num_bits_per_symbol must be int."
self._num_bits_per_symbol = int(num_bits_per_symbol)
assert(num_layers%1==0), "num_layers must be int."
self._num_layers = int(num_layers)
if channel_type=="PDSCH":
assert(codeword_index in (0,1)), "codeword_index must be 0 or 1."
else:
assert codeword_index==0, 'codeword_index must be 0 for "PUSCH".'
self._codeword_index = int(codeword_index)
if isinstance(n_rnti, (list, tuple)):
assert isinstance(n_id, (list, tuple)), "n_id must also be a list."
assert (len(n_rnti)==len(n_id)), \
"n_id and n_rnti must be of same length."
# cast to list so that the entries can be updated in-place below
self._n_rnti = list(n_rnti)
self._n_id = list(n_id)
else:
self._n_rnti = [n_rnti]
self._n_id = [n_id]
for idx, n in enumerate(self._n_rnti):
assert(n%1==0), "n_rnti must be int."
self._n_rnti[idx] = int(n)
for idx, n in enumerate(self._n_id):
assert(n%1==0), "n_id must be int."
self._n_id[idx] = int(n)
self._num_tx = len(self._n_id)
tbconfig = calculate_tb_size(target_tb_size=self._target_tb_size,
num_coded_bits=self._num_coded_bits,
target_coderate=self._target_coderate,
modulation_order=self._num_bits_per_symbol,
num_layers=self._num_layers,
verbose=verbose)
self._tb_size = tbconfig[0]
self._cb_size = tbconfig[1]
self._num_cbs = tbconfig[2]
self._cw_lengths = tbconfig[3]
self._tb_crc_length = tbconfig[4]
self._cb_crc_length = tbconfig[5]
assert self._tb_size <= self._tb_crc_length + np.sum(self._cw_lengths),\
"Invalid TB parameters."
# due to quantization, the tb_size can slightly differ from the
# target tb_size.
self._k_padding = self._tb_size - self._target_tb_size
if self._tb_size != self._target_tb_size:
print(f"Note: actual tb_size={self._tb_size} is slightly "\
f"different than requested " \
f"target_tb_size={self._target_tb_size} due to "\
f"quantization. Internal zero padding will be applied.")
# calculate effective coderate (incl. CRC)
self._coderate = self._tb_size / self._num_coded_bits
# Remark: CRC-16 is only used for tb_size <= 3824 bits (otherwise CRC-24A)
if self._tb_crc_length==16:
self._tb_crc_encoder = CRCEncoder("CRC16")
else:
# CRC24A as defined in 7.2.1
self._tb_crc_encoder = CRCEncoder("CRC24A")
# CB CRC only if more than one CB is used
if self._cb_crc_length==24:
self._cb_crc_encoder = CRCEncoder("CRC24B")
else:
self._cb_crc_encoder = None
# scrambler can be deactivated (non-standard compliant)
if self._use_scrambler:
self._scrambler = TB5GScrambler(n_rnti=self._n_rnti,
n_id=self._n_id,
binary=True,
channel_type=channel_type,
codeword_index=codeword_index,
dtype=tf.float32,)
else: # required for TBDecoder
self._scrambler = None
# ---- Init LDPC encoder ----
# remark: as the codeword length can be (slightly) different
# within a TB due to rounding, we initialize the encoder
# with the max length and apply puncturing if required.
# Thus, the output interleaver cannot be applied inside the encoder
# either. The procedure is defined in Sec. 5.4.2.1, TS 38.212.
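# Illustrative example (hypothetical numbers): for
# cw_lengths=[5504, 5504, 5502], the encoder is initialized with n=5504;
# the two unused positions of the short codeword are moved to the end of
# the TB by the permutation below and truncated in call().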
self._encoder = LDPC5GEncoder(self._cb_size,
np.max(self._cw_lengths),
num_bits_per_symbol=1) # deactivates the output interleaver
# ---- Init interleaver ----
# remark: an explicit interleaver is required, as the rate matching from
# Sec. 5.4.2.1, TS 38.212 could otherwise not be applied here
perm_seq_short, _ = self._encoder.generate_out_int(
np.min(self._cw_lengths),
num_bits_per_symbol)
perm_seq_long, _ = self._encoder.generate_out_int(
np.max(self._cw_lengths),
num_bits_per_symbol)
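# perm_seq_short/perm_seq_long are the 38.212 output interleaver patterns
# for the two possible codeword lengths occurring within this TB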
perm_seq = []
perm_seq_punc = []
# define one big interleaver that moves the punctured positions to the
# end of the TB
payload_bit_pos = 0 # points to current pos of payload bits
for l in self._cw_lengths:
if np.min(self._cw_lengths)==l:
perm_seq = np.concatenate([perm_seq,
perm_seq_short + payload_bit_pos])
# move unused bit positions to the end of TB
# this simplifies the inverse permutation
r = np.arange(payload_bit_pos+np.min(self._cw_lengths),
payload_bit_pos+np.max(self._cw_lengths))
perm_seq_punc = np.concatenate([perm_seq_punc, r])
# update pointer
payload_bit_pos += np.max(self._cw_lengths)
elif np.max(self._cw_lengths)==l:
perm_seq = np.concatenate([perm_seq,
perm_seq_long + payload_bit_pos])
# update pointer
payload_bit_pos += l
else:
raise ValueError("Invalid cw_lengths.")
# add punctured positions to end of sequence (only relevant for
# deinterleaving)
perm_seq = np.concatenate([perm_seq, perm_seq_punc])
self._output_perm = tf.constant(perm_seq, tf.int32)
self._output_perm_inv = tf.argsort(perm_seq, axis=-1)
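# After gathering with perm_seq, the TB is laid out as
# [cw_0 (interleaved) | cw_1 (interleaved) | ... | unused positions],
# such that truncating to sum(cw_lengths) in call() removes exactly the
# punctured bits.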
#########################################
# Public methods and properties
#########################################
@property
def tb_size(self):
r"""Effective number of information bits per TB.
Note that internal zero padding can be applied (if required) to match
the exact ``target_tb_size``."""
return self._tb_size
@property
def k(self):
r"""Number of input information bits. Equals `tb_size` except for zero
padding of the last positions if the ``target_tb_size`` is quantized."""
return self._target_tb_size
@property
def k_padding(self):
"""Number of zero padded bits at the end of the TB."""
return self._k_padding
@property
def n(self):
"Total number of output bits."
return self._num_coded_bits
@property
def num_cbs(self):
"Number code blocks."
return self._num_cbs
@property
def coderate(self):
"""Effective coderate of the TB after rate-matching including overhead
for the CRC."""
return self._coderate
@property
def ldpc_encoder(self):
"""LDPC encoder used for TB encoding."""
return self._encoder
@property
def scrambler(self):
"""Scrambler used for TB scrambling. `None` if no scrambler is used."""
return self._scrambler
@property
def tb_crc_encoder(self):
"""TB CRC encoder"""
return self._tb_crc_encoder
@property
def cb_crc_encoder(self):
"""CB CRC encoder. `None` if no CB CRC is applied."""
return self._cb_crc_encoder
@property
def num_tx(self):
"""Number of independent streams"""
return self._num_tx
@property
def cw_lengths(self):
r"""Each list element defines the codeword length of each of the
codewords after LDPC encoding and rate-matching. The total number of
coded bits is :math:`\sum` `cw_lengths`."""
return self._cw_lengths
@property
def output_perm_inv(self):
r"""Inverse interleaver pattern for output bit interleaver."""
return self._output_perm_inv
#########################
# Keras layer functions
#########################
def build(self, input_shapes):
"""Test input shapes for consistency."""
assert input_shapes[-1]==self.k, \
f"Invalid input shape. Expected TB length is {self.k}."
def call(self, inputs):
"""Apply transport block encoding procedure."""
# store shapes
input_shape = inputs.shape.as_list()
u = tf.cast(inputs, tf.float32)
# apply zero padding if tb_size differs slightly from target_tb_size
if self._k_padding>0:
s = tf.shape(u)
s = tf.concat((s[:-1], [self._k_padding]), axis=0)
u = tf.concat((u, tf.zeros(s, u.dtype)), axis=-1)
# apply TB CRC
u_crc = self._tb_crc_encoder(u)
# CB segmentation
u_cb = tf.reshape(u_crc,
(-1, self._num_tx, self._num_cbs,
self._cb_size-self._cb_crc_length))
# if relevant apply CB CRC
if self._cb_crc_length==24:
u_cb_crc = self._cb_crc_encoder(u_cb)
else:
u_cb_crc = u_cb # no CRC applied if only one CB exists
c_cb = self._encoder(u_cb_crc)
# CB concatenation
c = tf.reshape(c_cb,
(-1, self._num_tx,
self._num_cbs*np.max(self._cw_lengths)))
# apply interleaver (done after CB concatenation)
c = tf.gather(c, self._output_perm, axis=-1)
# puncture last bits
c = c[:, :, :np.sum(self._cw_lengths)]
# scrambler
if self._use_scrambler:
c_scr = self._scrambler(c)
else: # disable scrambler (non-standard compliant)
c_scr = c
# cast to output dtype
c_scr = tf.cast(c_scr, self.dtype)
# ensure output shapes
output_shape = input_shape
output_shape[0] = -1
output_shape[-1] = np.sum(self._cw_lengths)
c_tb = tf.reshape(c_scr, output_shape)
return c_tb
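

if __name__ == "__main__":
    # Minimal smoke-test sketch; all parameter values are illustrative and
    # would normally be derived from a PUSCHConfig (cf. class docstring).
    # Two parallel streams are configured via lists of (hypothetical)
    # n_rnti/n_id values.
    encoder = TBEncoder(target_tb_size=2856,
                        num_coded_bits=4224,
                        target_coderate=2856/4224,
                        num_bits_per_symbol=4,
                        n_rnti=[1, 2],
                        n_id=[10, 20])
    # random information bits of shape [batch, num_tx, target_tb_size]
    u = tf.cast(tf.random.uniform([4, encoder.num_tx, encoder.k],
                                  minval=0, maxval=2, dtype=tf.int32),
                tf.float32)
    c = encoder(u) # shape [4, num_tx, num_coded_bits]
    print("output shape:", c.shape)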