#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition for the OFDM Modulator"""
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.signal import ifftshift
from sionna.utils import flatten_last_dims
from sionna.signal import ifft
[docs]
class OFDMModulator(Layer):
# pylint: disable=line-too-long
"""
OFDMModulator(cyclic_prefix_length=0, **kwargs)
Computes the time-domain representation of an OFDM resource grid
with (optional) cyclic prefix
Parameters
----------
cyclic_prefix_length : scalar or [num_ofdm_symbols], int
Integer or vector of integers indicating the length of the
cyclic prefix that is prepended to each OFDM symbol. None of its
elements can be larger than the FFT size.
Defaults to 0.
Input
-----
: [...,num_ofdm_symbols,fft_size], tf.complex
Resource grid in the frequency domain
Output
------
: [...,num_ofdm_symbols*(fft_size+cyclic_prefix_length)] or [...,num_ofdm_symbols*fft_size+sum(cyclic_prefix_length)], tf.complex
Time-domain OFDM signal
"""
def __init__(self, cyclic_prefix_length=0, **kwargs):
super().__init__(**kwargs)
self._cyclic_prefix_length = None
self.cyclic_prefix_length = cyclic_prefix_length
@property
def cyclic_prefix_length(self):
"""
scalar or [num_ofdm_symbols], int : Get/set the cyclic prefix length
"""
return self._cyclic_prefix_length
@cyclic_prefix_length.setter
def cyclic_prefix_length(self, value):
value = tf.cast(value, tf.int32)
if not tf.reduce_all(value>=0):
msg = "`cyclic_prefix_length` must be nonnegative."
raise ValueError(msg)
if not 0<= tf.rank(value)<=1:
msg = "`cyclic_prefix_length` must be of rank 0 or 1"
raise ValueError(msg)
self._cyclic_prefix_length = value
def build(self, input_shape):
num_ofdm_symbols, fft_size = input_shape[-2:]
if not tf.reduce_all(self.cyclic_prefix_length<=fft_size):
msg = "`cyclic_prefix_length` cannot be larger than `fft_size`."
raise ValueError(msg)
if len(self.cyclic_prefix_length.shape)==1:
if not self.cyclic_prefix_length.shape[0]==num_ofdm_symbols:
msg = "`cyclic_prefix_length` must be of size [num_ofdm_symbols]"
raise ValueError(msg)
# Compute indices of CP symbols
# These are offset by the number of the OFDM symbol
# [num_ofdm_symbols, 1]
offsets = tf.expand_dims(tf.range(1, num_ofdm_symbols+1)*fft_size,
1)
# [num_ofdm_symbols, None] (ragged tensor)
cp_ind = tf.ragged.range(starts=-self.cyclic_prefix_length,
limits=0) + offsets
# Compute indices of symbols containing the actual sequence
# [num_ofdm_symbols, fft_size]
data_ind = tf.repeat(tf.expand_dims(tf.range(0, fft_size), 0),
num_ofdm_symbols, 0) + offsets - fft_size
# Concat CP and sequence indices
# [num_ofdm_symbols, None]
ind = tf.concat([cp_ind, data_ind], axis=-1)
# Flatten in time domain
# [num_ofdm_symbols *fft_size + sum(cyclic_prefix_length)]
self._ind = ind.flat_values
def call(self, inputs):
# Shift DC subcarrier to first position
x_freq = ifftshift(inputs, axes=-1)
# Compute IFFT along the last dimension
x_time = ifft(x_freq)
if len(self.cyclic_prefix_length.shape)==1:
# Individual CP length per OFDM symbol
# Flatten last two dimensions
x_time = flatten_last_dims(x_time, 2)
# Gather full time-domain signal
return tf.gather(x_time, self._ind, axis=-1)
else:
# Same CP length for all OFDM symbols
# Obtain cyclic prefix
cp = x_time[...,tf.shape(x_time)[-1]-self._cyclic_prefix_length:]
# Prepend cyclic prefix
x_time = tf.concat([cp, x_time], -1)
# Serialize last two dimensions
return flatten_last_dims(x_time, 2)