#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to OFDM channel equalization"""
import tensorflow as tf
from tensorflow.keras.layers import Layer
import sionna
from sionna.utils import flatten_dims, split_dim, flatten_last_dims, expand_to_rank
from sionna.mimo import lmmse_equalizer, zf_equalizer, mf_equalizer
from sionna.ofdm import RemoveNulledSubcarriers
[docs]
class OFDMEqualizer(Layer):
    # pylint: disable=line-too-long
    r"""OFDMEqualizer(equalizer, resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Layer that wraps a MIMO equalizer for use with the OFDM waveform.

    The parameter ``equalizer`` is a callable (e.g., a function) that
    implements a MIMO equalization algorithm for arbitrary batch dimensions.

    This class pre-processes the received resource grid ``y`` and channel
    estimate ``h_hat``, and computes for each receiver the
    noise-plus-interference covariance matrix according to the OFDM and stream
    configuration provided by the ``resource_grid`` and
    ``stream_management``, which also accounts for the channel
    estimation error variance ``err_var``. These quantities serve as input
    to the equalization algorithm that is implemented by the callable
    ``equalizer``.

    This layer computes soft-symbol estimates together with effective noise
    variances for all streams which can, e.g., be used by a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Note
    -----
    The callable ``equalizer`` must take three inputs:

    * **y** ([...,num_rx_ant], tf.complex) -- 1+D tensor containing the received signals.
    * **h** ([...,num_rx_ant,num_streams_per_rx], tf.complex) -- 2+D tensor containing the channel matrices.
    * **s** ([...,num_rx_ant,num_rx_ant], tf.complex) -- 2+D tensor containing the noise-plus-interference covariance matrices.

    It must generate two outputs:

    * **x_hat** ([...,num_streams_per_rx], tf.complex) -- 1+D tensor representing the estimated symbol vectors.
    * **no_eff** (tf.float) -- Tensor of the same shape as ``x_hat`` containing the effective noise variance estimates.

    Parameters
    ----------
    equalizer : Callable
        Callable object (e.g., a function) that implements a MIMO equalization
        algorithm for arbitrary batch dimensions

    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol
    """

    def __init__(self,
                 equalizer,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        assert callable(equalizer), "`equalizer` must be callable"
        assert isinstance(resource_grid, sionna.ofdm.ResourceGrid), \
            "`resource_grid` must be a sionna.ofdm.ResourceGrid"
        assert isinstance(stream_management, sionna.mimo.StreamManagement), \
            "`stream_management` must be a sionna.mimo.StreamManagement"
        self._equalizer = equalizer
        self._resource_grid = resource_grid
        self._stream_management = stream_management
        self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)

        # Precompute indices to extract data symbols.
        # An ascending sort of the flattened mask places the entries with the
        # smallest mask value first; the first ``num_data_symbols`` of them
        # are used as data positions. ``stable=True`` guarantees that tied
        # entries keep their original (resource-grid) order, which the
        # extraction in ``call`` relies on.
        mask = resource_grid.pilot_pattern.mask
        num_data_symbols = resource_grid.pilot_pattern.num_data_symbols
        data_ind = tf.argsort(flatten_last_dims(mask),
                              direction="ASCENDING",
                              stable=True)
        self._data_ind = data_ind[..., :num_data_symbols]

    def call(self, inputs):
        """Equalize the resource grid and return data-symbol estimates.

        ``inputs`` is the tuple ``(y, h_hat, err_var, no)`` described in the
        class docstring. Returns the tuple ``(x_hat, no_eff)``.
        """
        y, h_hat, err_var, no = inputs
        # y has shape:
        # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size]

        # h_hat has shape:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]

        # err_var has a shape that is broadcastable to h_hat

        # no has shape [batch_size, num_rx, num_rx_ant]
        # or just the first n dimensions of this

        # Remove nulled subcarriers from y (guards, dc). New shape:
        # [batch_size, num_rx, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        y_eff = self._removed_nulled_scs(y)

        ####################################################
        ### Prepare the observation y for MIMO detection ###
        ####################################################
        # Transpose y_eff to put num_rx_ant last. New shape:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
        y_dt = tf.cast(y_dt, self._dtype)

        ##############################################
        ### Prepare the err_var for MIMO detection ###
        ##############################################
        # New shape is:
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant, num_tx*num_streams]
        err_var_dt = tf.broadcast_to(err_var, tf.shape(h_hat))
        err_var_dt = tf.transpose(err_var_dt, [0, 1, 5, 6, 2, 3, 4])
        err_var_dt = flatten_last_dims(err_var_dt, 2)
        err_var_dt = tf.cast(err_var_dt, self._dtype)

        ###############################
        ### Construct MIMO channels ###
        ###############################

        # Reshape h_hat for the construction of desired/interfering channels:
        # [num_rx, num_tx, num_streams_per_tx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        perm = [1, 3, 4, 0, 2, 5, 6]
        h_dt = tf.transpose(h_hat, perm)

        # Flatten first three dimensions:
        # [num_rx*num_tx*num_streams_per_tx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt = flatten_dims(h_dt, 3, 0)

        # Gather desired and undesired channels
        ind_desired = self._stream_management.detection_desired_ind
        ind_undesired = self._stream_management.detection_undesired_ind
        h_dt_desired = tf.gather(h_dt, ind_desired, axis=0)
        h_dt_undesired = tf.gather(h_dt, ind_undesired, axis=0)

        # Split first dimension to separate RX and TX:
        # [num_rx, num_streams_per_rx, batch_size, num_rx_ant, ...
        #  ..., num_ofdm_symbols, num_effective_subcarriers]
        h_dt_desired = split_dim(h_dt_desired,
                                 [self._stream_management.num_rx,
                                  self._stream_management.num_streams_per_rx],
                                 0)
        h_dt_undesired = split_dim(h_dt_undesired,
                                   [self._stream_management.num_rx, -1], 0)

        # Permute dims to
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
        #  ..., num_rx_ant, num_streams_per_rx(num_interfering_streams_per_rx)]
        perm = [2, 0, 4, 5, 3, 1]
        h_dt_desired = tf.transpose(h_dt_desired, perm)
        h_dt_desired = tf.cast(h_dt_desired, self._dtype)
        h_dt_undesired = tf.transpose(h_dt_undesired, perm)
        # Cast the interfering channels as well so that every term of the
        # covariance sum below has the layer dtype, even if h_hat was
        # provided with a different precision.
        h_dt_undesired = tf.cast(h_dt_undesired, self._dtype)

        ##################################
        ### Prepare the noise variance ###
        ##################################
        # no is first broadcast to [batch_size, num_rx, num_rx_ant]
        # then the rank is expanded to that of y
        # then it is transposed like y to the final shape
        # [batch_size, num_rx, 1, 1, num_rx_ant]
        # which broadcasts over
        # [batch_size, num_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, num_rx_ant]
        no_dt = expand_to_rank(no, 3, -1)
        no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
        no_dt = expand_to_rank(no_dt, tf.rank(y), -1)
        no_dt = tf.transpose(no_dt, [0, 1, 3, 4, 2])
        no_dt = tf.cast(no_dt, self._dtype)

        ##################################################
        ### Compute the interference covariance matrix ###
        ##################################################
        # Covariance of undesired transmitters
        s_inf = tf.matmul(h_dt_undesired, h_dt_undesired, adjoint_b=True)

        # Thermal noise
        s_no = tf.linalg.diag(no_dt)

        # Channel estimation errors
        # As we have only error variance information for each element,
        # we simply sum them across transmitters and build a
        # diagonal covariance matrix from this
        s_csi = tf.linalg.diag(tf.reduce_sum(err_var_dt, -1))

        # Final covariance matrix
        s = s_inf + s_no + s_csi
        s = tf.cast(s, self._dtype)

        ############################################################
        ### Compute symbol estimate and effective noise variance ###
        ############################################################
        # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., num_stream_per_rx]
        x_hat, no_eff = self._equalizer(y_dt, h_dt_desired, s)

        ################################################
        ### Extract data symbols for all detected TX ###
        ################################################
        # Transpose tensor to shape
        # [num_rx, num_streams_per_rx, num_ofdm_symbols,...
        #  ..., num_effective_subcarriers, batch_size]
        x_hat = tf.transpose(x_hat, [1, 4, 2, 3, 0])
        no_eff = tf.transpose(no_eff, [1, 4, 2, 3, 0])

        # Merge num_rx and num_streams_per_rx
        # [num_rx * num_streams_per_rx, num_ofdm_symbols,...
        #  ...,num_effective_subcarriers, batch_size]
        x_hat = flatten_dims(x_hat, 2, 0)
        no_eff = flatten_dims(no_eff, 2, 0)

        # Put first dimension into the right ordering
        stream_ind = self._stream_management.stream_ind
        x_hat = tf.gather(x_hat, stream_ind, axis=0)
        no_eff = tf.gather(no_eff, stream_ind, axis=0)

        # Reshape first dimensions to [num_tx, num_streams] so that
        # we can compare to the way the streams were created.
        # [num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,...
        #  ..., batch_size]
        num_streams = self._stream_management.num_streams_per_tx
        num_tx = self._stream_management.num_tx
        x_hat = split_dim(x_hat, [num_tx, num_streams], 0)
        no_eff = split_dim(no_eff, [num_tx, num_streams], 0)

        # Flatten resource grid dimensions
        # [num_tx, num_streams, num_ofdm_symbols*num_effective_subcarriers,...
        #  ..., batch_size]
        x_hat = flatten_dims(x_hat, 2, 2)
        no_eff = flatten_dims(no_eff, 2, 2)

        # Broadcast no_eff to the shape of x_hat
        no_eff = tf.broadcast_to(no_eff, tf.shape(x_hat))

        # Gather data symbols
        # [num_tx, num_streams, num_data_symbols, batch_size]
        x_hat = tf.gather(x_hat, self._data_ind, batch_dims=2, axis=2)
        no_eff = tf.gather(no_eff, self._data_ind, batch_dims=2, axis=2)

        # Put batch_dim first
        # [batch_size, num_tx, num_streams, num_data_symbols]
        x_hat = tf.transpose(x_hat, [3, 0, 1, 2])
        no_eff = tf.transpose(no_eff, [3, 0, 1, 2])

        return (x_hat, no_eff)
[docs]
class LMMSEEqualizer(OFDMEqualizer):
    # pylint: disable=line-too-long
    """LMMSEEqualizer(resource_grid, stream_management, whiten_interference=True, dtype=tf.complex64, **kwargs)

    Linear minimum mean squared error (LMMSE) equalizer for OFDM MIMO links.

    The OFDM configuration is taken from a
    :class:`~sionna.ofdm.ResourceGrid` and the stream configuration from a
    :class:`~sionna.mimo.StreamManagement` instance. Equalization itself is
    delegated to :meth:`~sionna.mimo.lmmse_equalizer`. For every stream, the
    layer returns soft-symbol estimates and the matching effective noise
    variances, which can, e.g., be fed to a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Parameters
    ----------
    resource_grid : ResourceGrid
        Instance of :class:`~sionna.ofdm.ResourceGrid`

    stream_management : StreamManagement
        Instance of :class:`~sionna.mimo.StreamManagement`

    whiten_interference : bool
        If `True` (default), the interference is whitened before
        equalization. An alternative expression for the receive filter is
        then used, which can be numerically more stable.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol

    Note
    ----
    To use this layer in Graph mode with XLA, i.e., inside a function
    decorated with ``@tf.function(jit_compile=True)``, you must set
    ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 resource_grid,
                 stream_management,
                 whiten_interference=True,
                 dtype=tf.complex64,
                 **kwargs):

        # The parent layer invokes its equalizer with exactly three
        # positional arguments, so the whitening option is bound here
        # through a closure.
        def _lmmse(rx_signal, channel, covariance):
            return lmmse_equalizer(rx_signal,
                                   channel,
                                   covariance,
                                   whiten_interference)

        super().__init__(_lmmse,
                         resource_grid,
                         stream_management,
                         dtype=dtype,
                         **kwargs)
[docs]
class ZFEqualizer(OFDMEqualizer):
    # pylint: disable=line-too-long
    """ZFEqualizer(resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Zero-forcing (ZF) equalizer for OFDM MIMO links.

    The OFDM configuration is taken from a
    :class:`~sionna.ofdm.ResourceGrid` and the stream configuration from a
    :class:`~sionna.mimo.StreamManagement` instance. Equalization itself is
    delegated to :meth:`~sionna.mimo.zf_equalizer`. For every stream, the
    layer returns soft-symbol estimates and the matching effective noise
    variances, which can, e.g., be fed to a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol

    Note
    ----
    To use this layer in Graph mode with XLA, i.e., inside a function
    decorated with ``@tf.function(jit_compile=True)``, you must set
    ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        # zf_equalizer is handed to the parent unchanged as the equalizer
        # callable; no extra options need to be bound.
        super().__init__(zf_equalizer,
                         resource_grid,
                         stream_management,
                         dtype=dtype,
                         **kwargs)
[docs]
class MFEqualizer(OFDMEqualizer):
    # pylint: disable=line-too-long
    """MFEqualizer(resource_grid, stream_management, dtype=tf.complex64, **kwargs)

    Matched filter (MF) equalizer for OFDM MIMO links.

    The OFDM configuration is taken from a
    :class:`~sionna.ofdm.ResourceGrid` and the stream configuration from a
    :class:`~sionna.mimo.StreamManagement` instance. Equalization itself is
    delegated to :meth:`~sionna.mimo.mf_equalizer`. For every stream, the
    layer returns soft-symbol estimates and the matching effective noise
    variances, which can, e.g., be fed to a
    :class:`~sionna.mapping.Demapper` to obtain LLRs.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (y, h_hat, err_var, no) :
        Tuple:

    y : [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size], tf.complex
        Received OFDM resource grid after cyclic prefix removal and FFT

    h_hat : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm_symbols, num_effective_subcarriers], tf.complex
        Channel estimates for all streams from all transmitters

    err_var : [Broadcastable to shape of ``h_hat``], tf.float
        Variance of the channel estimation error

    no : [batch_size, num_rx, num_rx_ant] (or only the first n dims), tf.float
        Variance of the AWGN

    Output
    ------
    x_hat : [batch_size, num_tx, num_streams, num_data_symbols], tf.complex
        Estimated symbols

    no_eff : [batch_size, num_tx, num_streams, num_data_symbols], tf.float
        Effective noise variance for each estimated symbol

    Note
    ----
    To use this layer in Graph mode with XLA, i.e., inside a function
    decorated with ``@tf.function(jit_compile=True)``, you must set
    ``sionna.Config.xla_compat=true``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """

    def __init__(self,
                 resource_grid,
                 stream_management,
                 dtype=tf.complex64,
                 **kwargs):
        # mf_equalizer is handed to the parent unchanged as the equalizer
        # callable; no extra options need to be bound.
        super().__init__(mf_equalizer,
                         resource_grid,
                         stream_management,
                         dtype=dtype,
                         **kwargs)