#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Class definition and functions related to OFDM transmit precoding"""
import tensorflow as tf
from tensorflow.keras.layers import Layer
import sionna
from sionna.utils import flatten_dims
from sionna.mimo import zero_forcing_precoder
from sionna.ofdm import RemoveNulledSubcarriers
class ZFPrecoder(Layer):
    # pylint: disable=line-too-long
    r"""ZFPrecoder(resource_grid, stream_management, return_effective_channel=False, dtype=tf.complex64, **kwargs)

    Zero-forcing precoding for multi-antenna transmissions.

    This layer precodes a tensor containing OFDM resource grids using
    the :meth:`~sionna.mimo.zero_forcing_precoder`. For every
    transmitter, the channels to all intended receivers are gathered
    into a channel matrix, based on which the precoding matrix
    is computed and the input tensor is precoded. Optionally, the layer
    also outputs the effective channel after precoding for each stream.

    Parameters
    ----------
    resource_grid : ResourceGrid
        An instance of :class:`~sionna.ofdm.ResourceGrid`.

    stream_management : StreamManagement
        An instance of :class:`~sionna.mimo.StreamManagement`.

    return_effective_channel : bool
        Indicates if the effective channel after precoding should be returned.
        Defaults to `False`.

    dtype : tf.Dtype
        Datatype for internal calculations and the output dtype.
        Defaults to `tf.complex64`.

    Input
    -----
    (x, h) :
        Tuple:

    x : [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size], tf.complex
        Tensor containing the resource grid to be precoded.

    h : [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm, fft_size], tf.complex
        Tensor containing the channel knowledge based on which the precoding
        is computed.

    Output
    ------
    x_precoded : [batch_size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size], tf.complex
        The precoded resource grids.

    h_eff : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm, num_effective_subcarriers], tf.complex
        Only returned if ``return_effective_channel=True``.
        The effective channels for all streams after precoding. Can be used to
        simulate perfect channel state information (CSI) at the receivers.
        Nulled subcarriers are automatically removed to be compliant with the
        behavior of a channel estimator.

    Note
    ----
    If you want to use this layer in Graph mode with XLA, i.e., within
    a function that is decorated with ``@tf.function(jit_compile=True)``,
    you must set ``sionna.Config.xla_compat=True``.
    See :py:attr:`~sionna.Config.xla_compat`.
    """
    def __init__(self,
                 resource_grid,
                 stream_management,
                 return_effective_channel=False,
                 dtype=tf.complex64,
                 **kwargs):
        super().__init__(dtype=dtype, **kwargs)

        # Validate the configuration objects early so that misconfiguration
        # surfaces at construction time rather than inside ``call``.
        assert isinstance(resource_grid, sionna.ofdm.ResourceGrid), \
            "`resource_grid` must be an instance of sionna.ofdm.ResourceGrid"
        assert isinstance(stream_management, sionna.mimo.StreamManagement), \
            "`stream_management` must be an instance of sionna.mimo.StreamManagement"

        self._resource_grid = resource_grid
        self._stream_management = stream_management
        self._return_effective_channel = return_effective_channel

        # Used to strip nulled subcarriers from the effective channel so that
        # its last dimension matches what a channel estimator would output.
        self._remove_nulled_scs = RemoveNulledSubcarriers(self._resource_grid)

    def _compute_effective_channel(self, h, g):
        """Compute the effective channel after precoding.

        Input
        -----
        h : [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm, fft_size], tf.complex
            Channel tensor as passed to :meth:`call`.

        g : [batch_size, num_tx, num_ofdm_symbols, fft_size, num_tx_ant, num_streams_per_tx], tf.complex
            Precoding matrices returned by the zero-forcing precoder.

        Output
        ------
        h_eff : [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx, num_ofdm, num_effective_subcarriers], tf.complex
            Effective channel ``h @ g`` with nulled subcarriers removed.
        """
        # Move the antenna dimensions to the end so they form the matrix
        # dimensions for tf.matmul:
        # [batch_size, num_rx, num_tx, num_ofdm, fft_size, num_rx_ant,...
        #  ..., num_tx_ant]
        h = tf.transpose(h, [0, 1, 3, 5, 6, 2, 4])
        h = tf.cast(h, g.dtype)

        # Add a dummy num_rx dimension to g so that it broadcasts against h:
        # [batch_size, 1, num_tx, num_ofdm_symbols, fft_size, num_tx_ant,...
        #  ..., num_streams_per_tx]
        g = tf.expand_dims(g, 1)

        # Post-precoding channel:
        # [batch_size, num_rx, num_tx, num_ofdm, fft_size, num_rx_ant,...
        #  ..., num_streams_per_tx]
        h_eff = tf.matmul(h, g)

        # Permute back to the common layout of channel tensors:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,...
        #  ..., num_ofdm, fft_size]
        h_eff = tf.transpose(h_eff, [0, 1, 5, 2, 6, 3, 4])

        # Remove nulled subcarriers:
        # [batch_size, num_rx, num_rx_ant, num_tx, num_streams_per_tx,...
        #  ..., num_ofdm, num_effective_subcarriers]
        return self._remove_nulled_scs(h_eff)

    def call(self, inputs):
        """Precode ``x`` with zero-forcing precoders computed from ``h``.

        See the class docstring for the shapes of ``inputs`` and the outputs.
        """
        x, h = inputs
        # x has shape
        # [batch_size, num_tx, num_streams_per_tx, num_ofdm_symbols, fft_size]
        #
        # h has shape
        # [batch_size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm,...
        #  ..., fft_size]

        ###
        ### Transformations to bring h and x in the desired shapes
        ###

        # Transpose x so that streams are the innermost (vector) dimension:
        # [batch_size, num_tx, num_ofdm_symbols, fft_size, num_streams_per_tx]
        x_precoded = tf.transpose(x, [0, 1, 3, 4, 2])
        # Cast to the layer dtype so x and h match for the precoder matmuls.
        x_precoded = tf.cast(x_precoded, self._dtype)

        # Transpose h so that num_tx leads (for gather with batch_dims=1):
        # [num_tx, num_rx, num_rx_ant, num_tx_ant, num_ofdm_symbols,...
        #  ..., fft_size, batch_size]
        h_pc = tf.transpose(h, [3, 1, 2, 4, 5, 6, 0])

        # For each transmitter, keep only the channels to its intended
        # receivers as given by the stream management:
        # [num_tx, num_rx_per_tx, num_rx_ant, num_tx_ant, num_ofdm_symbols,...
        #  ..., fft_size, batch_size]
        h_pc_desired = tf.gather(h_pc, self._stream_management.precoding_ind,
                                 axis=1, batch_dims=1)

        # Flatten dims 2,3 into one receive-antenna dimension:
        # [num_tx, num_rx_per_tx * num_rx_ant, num_tx_ant, num_ofdm_symbols,...
        #  ..., fft_size, batch_size]
        # NOTE(review): the flattened dimension is used as num_streams_per_tx
        # below, i.e. this assumes one stream per receive antenna of every
        # intended receiver — confirm against StreamManagement.
        h_pc_desired = flatten_dims(h_pc_desired, 2, axis=1)

        # Transpose to the layout expected by the precoder:
        # [batch_size, num_tx, num_ofdm_symbols, fft_size,...
        #  ..., num_streams_per_tx, num_tx_ant]
        h_pc_desired = tf.transpose(h_pc_desired, [5, 0, 3, 4, 1, 2])
        h_pc_desired = tf.cast(h_pc_desired, self._dtype)

        ###
        ### ZF precoding
        ###
        x_precoded, g = zero_forcing_precoder(x_precoded,
                                              h_pc_desired,
                                              return_precoding_matrix=True)

        # Transpose output to desired shape:
        # [batch_size, num_tx, num_tx_ant, num_ofdm_symbols, fft_size]
        x_precoded = tf.transpose(x_precoded, [0, 1, 4, 2, 3])

        if not self._return_effective_channel:
            return x_precoded

        # The original (un-precoded) h is used here; g carries the precoder.
        h_eff = self._compute_effective_channel(h, g)
        return (x_precoded, h_eff)