Source code for sionna.ofdm.demodulator

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition for the OFDM Demodulator"""

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.signal import fftshift
from sionna.constants import PI
from sionna.utils import expand_to_rank
from sionna.signal import fft
import numpy as np

[docs] class OFDMDemodulator(Layer): # pylint: disable=line-too-long r""" OFDMDemodulator(fft_size, l_min, cyclic_prefix_length, **kwargs) Computes the frequency-domain representation of an OFDM waveform with cyclic prefix removal. The demodulator assumes that the input sequence is generated by the :class:`~sionna.channel.TimeChannel`. For a single pair of antennas, the received signal sequence is given as: .. math:: y_b = \sum_{\ell =L_\text{min}}^{L_\text{max}} \bar{h}_\ell x_{b-\ell} + w_b, \quad b \in[L_\text{min}, N_B+L_\text{max}-1] where :math:`\bar{h}_\ell` are the discrete-time channel taps, :math:`x_{b}` is the the transmitted signal, and :math:`w_\ell` Gaussian noise. Starting from the first symbol, the demodulator cuts the input sequence into pieces of size ``cyclic_prefix_length + fft_size``, and throws away any trailing symbols. For each piece, the cyclic prefix is removed and the ``fft_size``-point discrete Fourier transform is computed. It is also possible that every OFDM symbol has a cyclic prefix of different length. Since the input sequence starts at time :math:`L_\text{min}`, the FFT-window has a timing offset of :math:`L_\text{min}` symbols, which leads to a subcarrier-dependent phase shift of :math:`e^{\frac{j2\pi k L_\text{min}}{N}}`, where :math:`k` is the subcarrier index, :math:`N` is the FFT size, and :math:`L_\text{min} \le 0` is the largest negative time lag of the discrete-time channel impulse response. This phase shift is removed in this layer, by explicitly multiplying each subcarrier by :math:`e^{\frac{-j2\pi k L_\text{min}}{N}}`. This is a very important step to enable channel estimation with sparse pilot patterns that needs to interpolate the channel frequency response accross subcarriers. It also ensures that the channel frequency response `seen` by the time-domain channel is close to the :class:`~sionna.channel.OFDMChannel`. Parameters ---------- fft_size : int FFT size (, i.e., the number of subcarriers). l_min : int The largest negative time lag of the discrete-time channel impulse response. It should be the same value as that used by the `cir_to_time_channel` function. cyclic_prefix_length : scalar or [num_ofdm_symbols], int Integer or vector of integers indicating the length of the cyclic prefix that is prepended to each OFDM symbol. None of its elements can be larger than the FFT size. Defaults to 0. Input ----- :[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)+n] or [...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)+n], tf.complex Tensor containing the time-domain signal along the last dimension. `n` is a nonnegative integer. Output ------ :[...,num_ofdm_symbols,fft_size], tf.complex Tensor containing the OFDM resource grid along the last two dimension. """ def __init__(self, fft_size, l_min, cyclic_prefix_length=0, **kwargs): super().__init__(**kwargs) self._fft_size = None self._l_min = None self._cyclic_prefix_length = None self.fft_size = fft_size self.l_min = l_min self.cyclic_prefix_length = cyclic_prefix_length @property def fft_size(self): return self._fft_size @fft_size.setter def fft_size(self, value): assert value>0, "`fft_size` must be positive." self._fft_size = int(value) @property def l_min(self): return self._l_min @l_min.setter def l_min(self, value): assert value<=0, "l_min must be nonpositive." self._l_min = int(value) @property def cyclic_prefix_length(self): return self._cyclic_prefix_length @cyclic_prefix_length.setter def cyclic_prefix_length(self, value): value = tf.cast(value, tf.int32) if not tf.reduce_all(value>=0): msg = "`cyclic_prefix_length` must be nonnegative." raise ValueError(msg) if not 0<= tf.rank(value)<=1: msg = "`cyclic_prefix_length` must be of rank 0 or 1" raise ValueError(msg) self._cyclic_prefix_length = value def build(self, input_shape): # pylint: disable=unused-argument # Compute phase correction terms to to channel tmp = -2 * PI * tf.cast(self.l_min, tf.float32) \ / tf.cast(self.fft_size, tf.float32) \ * tf.range(self.fft_size, dtype=tf.float32) self._phase_compensation = tf.exp(tf.complex(0., tmp)) if len(self.cyclic_prefix_length.shape)==0: # Compute number of elements that will be truncated self._rest = np.mod(input_shape[-1], self.fft_size + self.cyclic_prefix_length) # Compute number of full OFDM symbols to be demodulated self._num_ofdm_symbols = np.floor_divide( input_shape[-1]-self._rest, self.fft_size + self.cyclic_prefix_length) else: # Deal with individual cp lengths for OFDM symbols # Compute the relevant indices to gather for # every OFDM symbol from the time domain input num_ofdm_symbols = self.cyclic_prefix_length.shape[0] row_lengths = self.cyclic_prefix_length + self.fft_size offsets = tf.math.cumsum(tf.concat([[0], row_lengths], axis=0)[:-1]) offsets = tf.expand_dims(offsets, 1) ind = tf.repeat(tf.range(start=0, limit=self.fft_size)[tf.newaxis,:], repeats=num_ofdm_symbols, axis=0) ind += self.cyclic_prefix_length[:, tf.newaxis] ind += offsets # [num_ofdm_symbols, fft_size] self._ind = ind def call(self, inputs): """Demodulate OFDM waveform onto a resource grid. Args: inputs (tf.complex64): `[...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)]`. Returns: `tf.complex64` : The demodulated inputs of shape `[...,num_ofdm_symbols, fft_size]`. """ if len(self.cyclic_prefix_length.shape)==0: # Same CP length for all OFDM symbols # Cut last samples that do not fit into an OFDM symbol inputs = inputs if self._rest==0 else inputs[...,:-self._rest] # Reshape input to separate OFDM symbols new_shape = tf.concat( [tf.shape(inputs)[:-1], [self._num_ofdm_symbols], [self.fft_size + self.cyclic_prefix_length]], 0) x = tf.reshape(inputs, new_shape) # Remove cyclic prefix x = x[...,self.cyclic_prefix_length:] else: # Individual CP length for OFDM symbols x = tf.gather(inputs, self._ind, axis=-1) # Compute FFT x = fft(x) # Apply phase shift compensation to all subcarriers rot = tf.cast(self._phase_compensation, x.dtype) rot = expand_to_rank(rot, tf.rank(x), 0) x = x * rot # Shift DC subcarrier to the middle x = fftshift(x, axes=-1) return x