# Source code for sionna.nr.pusch_pilot_pattern

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""PUSCH pilot pattern for the nr (5G) sub-package of the Sionna library.
"""
import warnings
from collections.abc import Sequence
import tensorflow as tf
import numpy as np
from sionna.ofdm import PilotPattern
from .pusch_config import PUSCHConfig

class PUSCHPilotPattern(PilotPattern):
    # pylint: disable=line-too-long
    r"""Class defining a pilot pattern for NR PUSCH.

    This class defines a :class:`~sionna.ofdm.PilotPattern`
    that is used to configure an OFDM :class:`~sionna.ofdm.ResourceGrid`.

    For every transmitter, a separate :class:`~sionna.nr.PUSCHConfig`
    needs to be provided from which the pilot pattern will be created.

    Parameters
    ----------
    pusch_configs : instance or list of :class:`~sionna.nr.PUSCHConfig`
        PUSCH Configurations according to which the pilot pattern
        will be created. One configuration is needed for each transmitter.

    dtype : tf.DType
        Defines the datatype for internal calculations and the output
        dtype. Defaults to `tf.complex64`.

    Raises
    ------
    ValueError
        If ``pusch_configs`` is neither a :class:`~sionna.nr.PUSCHConfig`
        nor a sequence, or if it is an empty sequence.

    AssertionError
        If an element of ``pusch_configs`` is not a
        :class:`~sionna.nr.PUSCHConfig`, or if the configurations differ in
        number of layers, number of subcarriers, number of OFDM symbols,
        precoding method, or number of masked (DMRS) resource elements.
    """
    def __init__(self, pusch_configs, dtype=tf.complex64):

        # Normalize pusch_configs to a sequence and check element types
        if isinstance(pusch_configs, PUSCHConfig):
            pusch_configs = [pusch_configs]
        elif isinstance(pusch_configs, Sequence):
            for c in pusch_configs:
                assert isinstance(c, PUSCHConfig), \
                    "Each element of pusch_configs must be a valide PUSCHConfig"
        else:
            raise ValueError("Invalid value for pusch_configs")

        # Without this check an empty sequence fails below with an opaque
        # IndexError on pusch_configs[0]
        if len(pusch_configs) == 0:
            raise ValueError(
                "pusch_configs must contain at least one PUSCHConfig")

        # Every configuration must agree with the first one on the
        # quantities that fix the shapes of the shared mask/pilots tensors
        ref_config = pusch_configs[0]
        num_tx = len(pusch_configs)
        num_streams_per_tx = ref_config.num_layers
        num_subcarriers = ref_config.dmrs_grid[0].shape[0]
        num_ofdm_symbols = ref_config.l_d
        precoding = ref_config.precoding
        num_pilots = np.sum(ref_config.dmrs_mask)
        dmrs_ports = []  # DMRS ports of all previously checked configs
        for pusch_config in pusch_configs:
            assert pusch_config.num_layers==num_streams_per_tx, \
                "All pusch_configs must have the same number of layers"
            assert pusch_config.dmrs_grid[0].shape[0]==num_subcarriers, \
                "All pusch_configs must have the same number of subcarriers"
            assert pusch_config.l_d==num_ofdm_symbols, \
                "All pusch_configs must have the same number of OFDM symbols"
            assert pusch_config.precoding==precoding, \
                "All pusch_configs must have a the same precoding method"
            assert np.sum(pusch_config.dmrs_mask)==num_pilots, \
                "All pusch_configs must have a the same number of masked REs"

            # A DMRS port shared between transmitters only triggers a
            # warning, not an error. The 'always' filter ensures the warning
            # is emitted on every occurrence instead of being deduplicated.
            with warnings.catch_warnings():
                warnings.simplefilter('always')
                for port in pusch_config.dmrs.dmrs_port_set:
                    if port in dmrs_ports:
                        msg = f"DMRS port {port} used by multiple transmitters"
                        warnings.warn(msg)
            dmrs_ports += pusch_config.dmrs.dmrs_port_set

        # Create mask and pilots tensors
        mask = np.zeros([num_tx, num_streams_per_tx, num_ofdm_symbols,
                         num_subcarriers], bool)
        pilots = np.zeros([num_tx, num_streams_per_tx, num_pilots], complex)
        for i, pusch_config in enumerate(pusch_configs):
            # Keep only the allocated OFDM symbols [start, start + length)
            start, length = pusch_config.symbol_allocation
            symbols = slice(start, start + length)

            # dmrs_mask / dmrs_grid[j] are sliced along their second axis by
            # OFDM symbol; transposing gives the [symbol, subcarrier] layout
            # of `mask`. The mask does not depend on the stream index, so it
            # is computed once per transmitter and broadcast over streams.
            config_mask = np.transpose(
                pusch_config.dmrs_mask[:, symbols]).astype(bool)
            mask[i] = config_mask

            for j in range(num_streams_per_tx):
                stream_grid = np.transpose(
                    pusch_config.dmrs_grid[j, :, symbols])
                # Boolean indexing returns the pilots in row-major
                # (symbol, subcarrier) order, i.e. the same order as
                # indexing with np.where(config_mask)
                pilots[i, j] = stream_grid[config_mask]

        # Init PilotPattern class
        super().__init__(mask, pilots, trainable=False, normalize=False,
                         dtype=dtype)