Source code for sionna.nr.pusch_receiver

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""PUSCH Receiver for the nr (5G) sub-package of the Sionna library.
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
import sionna
from sionna.mimo import StreamManagement
from sionna.ofdm import OFDMDemodulator, LinearDetector
from sionna.utils import insert_dims
from sionna.channel import time_to_ofdm_channel

class PUSCHReceiver(Layer):
    # pylint: disable=line-too-long
    r"""PUSCHReceiver(pusch_transmitter, channel_estimator=None, mimo_detector=None, tb_decoder=None, return_tb_crc_status=False, stream_management=None, input_domain="freq", l_min=None, dtype=tf.complex64, **kwargs)

    This layer implements a full receiver for batches of 5G NR PUSCH slots sent
    by multiple transmitters. Inputs can be in the time or frequency domain.
    Perfect channel state information can be optionally provided.
    Different channel estimators, MIMO detectors, and transport decoders
    can be configured.

    The layer combines multiple processing blocks into a single layer
    as shown in the following figure. Blocks with dashed lines are
    optional and depend on the configuration.

    .. figure:: ../figures/pusch_receiver_block_diagram.png
        :scale: 30%
        :align: center

    If the ``input_domain`` equals "time", the inputs :math:`\mathbf{y}` are first
    transformed to resource grids with the :class:`~sionna.ofdm.OFDMDemodulator`.
    Then channel estimation is performed, e.g., with the help of the
    :class:`~sionna.nr.PUSCHLSChannelEstimator`. If ``channel_estimator``
    is chosen to be "perfect", this step is skipped and the input :math:`\mathbf{h}`
    is used instead.
    Next, MIMO detection is carried out with an arbitrary :class:`~sionna.ofdm.OFDMDetector`.
    The resulting LLRs for each layer are then combined to transport blocks
    with the help of the :class:`~sionna.nr.LayerDemapper`.
    Finally, the transport blocks are decoded with the :class:`~sionna.nr.TBDecoder`.

    Parameters
    ----------
    pusch_transmitter : :class:`~sionna.nr.PUSCHTransmitter`
        Transmitter used for the generation of the transmit signals

    channel_estimator : :class:`~sionna.ofdm.BaseChannelEstimator`, "perfect", or `None`
        Channel estimator to be used.
        If `None`, the :class:`~sionna.nr.PUSCHLSChannelEstimator` with
        linear interpolation is used.
        If "perfect", no channel estimation is performed and the channel state information
        ``h`` must be provided as additional input.
        Defaults to `None`.

    mimo_detector : :class:`~sionna.ofdm.OFDMDetector` or `None`
        MIMO Detector to be used.
        If `None`, the :class:`~sionna.ofdm.LinearDetector` with
        LMMSE detection is used.
        Defaults to `None`.

    tb_decoder : :class:`~sionna.nr.TBDecoder` or `None`
        Transport block decoder to be used.
        If `None`, the :class:`~sionna.nr.TBDecoder` with its
        default settings is used.
        Defaults to `None`.

    return_tb_crc_status : bool
        If `True`, the status of the transport block CRC is returned
        as additional output.
        Defaults to `False`.

    stream_management : :class:`~sionna.mimo.StreamManagement` or `None`
        Stream management configuration to be used.
        If `None`, it is assumed that there is a single receiver
        which decodes all streams of all transmitters.
        Defaults to `None`.

    input_domain : str, one of ["freq", "time"]
        Domain of the input signal.
        Defaults to "freq".

    l_min : int or `None`
        Smallest time-lag for the discrete complex baseband channel.
        Only needed if ``input_domain`` equals "time".
        Defaults to `None`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h, no) :
        Tuple:

    y : [batch size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex or [batch size, num_rx, num_rx_ant, num_time_samples + l_max - l_min], tf.complex
        Frequency- or time-domain input signal

    h : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm_symbols, num_subcarriers], tf.complex or [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_max - l_min, l_max - l_min + 1], tf.complex
        Perfect channel state information in either frequency or time domain
        (depending on ``input_domain``) to be used for detection.
        Only required if ``channel_estimator`` equals "perfect".

    no : [batch_size, num_rx, num_rx_ant] or only the first n>=0 dims, tf.float
        Variance of the AWGN

    Output
    ------
    b_hat : [batch_size, num_tx, tb_size], tf.float
        Decoded information bits

    tb_crc_status : [batch_size, num_tx], tf.bool
        Transport block CRC status.
        Only returned if ``return_tb_crc_status`` is `True`.

    Example
    -------
    >>> pusch_config = PUSCHConfig()
    >>> pusch_transmitter = PUSCHTransmitter(pusch_config)
    >>> pusch_receiver = PUSCHReceiver(pusch_transmitter)
    >>> channel = AWGN()
    >>> x, b = pusch_transmitter(16)
    >>> no = 0.1
    >>> y = channel([x, no])
    >>> b_hat = pusch_receiver([y, no])
    >>> compute_ber(b, b_hat)
    <tf.Tensor: shape=(), dtype=float64, numpy=0.0>
    """

    def __init__(self,
                 pusch_transmitter,
                 channel_estimator=None,
                 mimo_detector=None,
                 tb_decoder=None,
                 return_tb_crc_status=False,
                 stream_management=None,
                 input_domain="freq",
                 l_min=None,
                 dtype=tf.complex64,
                 **kwargs):

        assert dtype in [tf.complex64, tf.complex128], \
            "dtype must be tf.complex64 or tf.complex128"
        super().__init__(dtype=dtype, **kwargs)

        assert input_domain in ["time", "freq"], \
            "input_domain must be 'time' or 'freq'"
        self._input_domain = input_domain

        self._return_tb_crc_status = return_tb_crc_status

        self._resource_grid = pusch_transmitter.resource_grid

        # (Optionally) Create OFDMDemodulator
        # Only time-domain inputs need to be converted to resource grids,
        # hence l_min is only mandatory (and only stored) in that case.
        if self._input_domain=="time":
            assert l_min is not None, \
                "l_min must be provided for input_domain==time"
            self._l_min = l_min
            self._ofdm_demodulator = OFDMDemodulator(
                fft_size=pusch_transmitter._num_subcarriers,
                l_min=self._l_min,
                cyclic_prefix_length=pusch_transmitter._cyclic_prefix_length)

        # Use or create default ChannelEstimator
        self._perfect_csi = False
        # Precoding matrices; only set for perfect CSI with codebook precoding
        self._w = None
        if channel_estimator is None:
            # Default channel estimator
            self._channel_estimator = sionna.nr.PUSCHLSChannelEstimator(
                self.resource_grid,
                pusch_transmitter._dmrs_length,
                pusch_transmitter._dmrs_additional_position,
                pusch_transmitter._num_cdm_groups_without_data,
                interpolation_type='lin',
                dtype=dtype)
        elif channel_estimator=="perfect":
            # Perfect channel estimation: no estimator is created, the
            # channel is taken from the inputs of call() instead.
            self._perfect_csi = True
            if pusch_transmitter._precoding=="codebook":
                self._w = pusch_transmitter._precoder._w
                # Insert two dimensions at axis 1 so that the precoding
                # matrices can be matrix-multiplied with the channel in call()
                self._w = insert_dims(self._w, 2, 1)
        else:
            # User-provided channel estimator
            self._channel_estimator = channel_estimator

        # Use or create default StreamManagement
        if stream_management is None:
            # Default StreamManagement:
            # a single receiver associated with all transmitters
            rx_tx_association = np.ones([1, pusch_transmitter._num_tx], bool)
            self._stream_management = StreamManagement(
                rx_tx_association,
                pusch_transmitter._num_layers)
        else:
            # User-provided StreamManagement
            self._stream_management = stream_management

        # Use or create default MIMODetector
        if mimo_detector is None:
            # Default MIMO detector
            self._mimo_detector = LinearDetector(
                "lmmse", "bit", "maxlog",
                pusch_transmitter.resource_grid,
                self._stream_management,
                "qam",
                pusch_transmitter._num_bits_per_symbol,
                dtype=dtype)
        else:
            # User-provided MIMO detector
            self._mimo_detector = mimo_detector

        # Create LayerDemapper
        self._layer_demapper = sionna.nr.LayerDemapper(
            pusch_transmitter._layer_mapper,
            num_bits_per_symbol=pusch_transmitter._num_bits_per_symbol)

        # Use or create default TBDecoder
        if tb_decoder is None:
            # Default TBDecoder, built from the transmitter's TBEncoder
            self._tb_decoder = sionna.nr.TBDecoder(
                pusch_transmitter._tb_encoder,
                output_dtype=dtype.real_dtype)
        else:
            # User-provided TBDecoder
            self._tb_decoder = tb_decoder

    #########################################
    # Public methods and properties
    #########################################

    @property
    def resource_grid(self):
        """OFDM resource grid underlying the PUSCH transmissions"""
        return self._resource_grid

    def call(self, inputs):
        """Run the receiver chain on ``inputs``.

        ``inputs`` is ``(y, h, no)`` if the receiver was configured with
        ``channel_estimator="perfect"`` and ``(y, no)`` otherwise.
        Returns ``b_hat``, or ``(b_hat, tb_crc_status)`` if
        ``return_tb_crc_status`` was set to `True`.
        """
        if self._perfect_csi:
            y, h, no = inputs
        else:
            y, no = inputs

        # (Optional) OFDM Demodulation
        if self._input_domain=="time":
            y = self._ofdm_demodulator(y)

        # Channel estimation
        if self._perfect_csi:

            # Transform time-domain to frequency-domain channel
            if self._input_domain=="time":
                h = time_to_ofdm_channel(h, self.resource_grid, self._l_min)

            if self._w is not None:
                # Reshape h to put channel matrix dimensions last
                # [batch size, num_rx, num_tx, num_ofdm_symbols,...
                #  ...fft_size, num_rx_ant, num_tx_ant]
                h = tf.transpose(h, perm=[0,1,3,5,6,2,4])

                # Multiply by precoding matrices to compute effective channels
                # [batch size, num_rx, num_tx, num_ofdm_symbols,...
                #  ...fft_size, num_rx_ant, num_streams]
                h = tf.matmul(h, self._w)

                # Reshape
                # [batch size, num_rx, num_rx_ant, num_tx, num_streams,...
                #  ...num_ofdm_symbols, fft_size]
                h = tf.transpose(h, perm=[0,1,5,2,6,3,4])

            h_hat = h
            # The provided channel is used as is: zero estimation error variance
            err_var = tf.cast(0, dtype=h_hat.dtype.real_dtype)
        else:
            h_hat,err_var = self._channel_estimator([y, no])

        # MIMO Detection
        llr = self._mimo_detector([y, h_hat, err_var, no])

        # Layer demapping
        llr = self._layer_demapper(llr)

        # TB Decoding
        b_hat, tb_crc_status = self._tb_decoder(llr)

        if self._return_tb_crc_status:
            return b_hat, tb_crc_status
        else:
            return b_hat