#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer mapping for the 5G NR sub-package of the Sionna library.
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.utils import flatten_last_dims, split_dim
[docs]
class LayerMapper(Layer):
    # pylint: disable=line-too-long
    r"""LayerMapper(num_layers=1, verbose=False, **kwargs)

    Performs MIMO layer mapping of modulated symbols to layers as defined in
    [3GPP38211]_.

    The LayerMapper supports PUSCH and PDSCH channels and follows the procedure
    as defined in Sec. 6.3.1.3 and Sec. 7.3.1.3 in [3GPP38211]_, respectively.

    As specified in Tab. 7.3.1.3.-1 [3GPP38211]_, the LayerMapper expects two
    input streams for multiplexing if more than 4 layers are active (only
    relevant for PDSCH).

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
        num_layers: int, 1 (default) | [1,...,8]
            Number of MIMO layers. If
            ``num_layers`` >4, a list of two inputs is expected.

        verbose: bool, False (default)
            If True, additional parameters are printed.

    Input
    -----
        inputs: [...,n], or [[...,n1], [...,n2]], tf.complex
            2+D tensor containing the sequence of symbols to be mapped. If
            ``num_layers`` >4, a list of two inputs is expected and `n1`/`n2`
            must be chosen as defined in Tab. 7.3.1.3.-1 [3GPP38211]_, i.e.,
            `n1` must be a multiple of ``num_layers0``, `n2` must be a
            multiple of ``num_layers1``, and
            `n1`/``num_layers0`` must equal `n2`/``num_layers1``.

    Output
    ------
        : [...,num_layers, n/num_layers], tf.complex
            2+D tensor containing the sequence of symbols mapped to the MIMO
            layers. In dual codeword mode, the first ``num_layers0`` layers
            carry the first input and the remaining ``num_layers1`` layers
            carry the second input; the last dimension is then
            `n1`/``num_layers0``.
    """

    # Codeword-to-layer split for dual codeword transmission following
    # Tab. 7.3.1.3-1 from 38.211: num_layers -> (num_layers0, num_layers1)
    _DUAL_CW_SPLIT = {5: (2, 3),
                      6: (3, 3),
                      7: (3, 4),
                      8: (4, 4)}

    def __init__(self,
                 num_layers=1,
                 verbose=False,
                 **kwargs):

        super().__init__(**kwargs)

        assert isinstance(verbose, bool), "verbose must be bool"
        self._verbose = verbose

        assert num_layers in range(1,9), \
            'num_layers must be between 1 and 8.'
        self._num_layers = num_layers

        # follow Tab. 7.3.1.3-1 from 38.211 for CW multiplexing
        if self._num_layers<5:
            # up to 4 layers are always carried by a single codeword
            self._num_codewords = 1
        elif self._num_layers in self._DUAL_CW_SPLIT:
            self._num_codewords = 2
            self._num_layers0, self._num_layers1 = \
                self._DUAL_CW_SPLIT[self._num_layers]
        else:
            # only reachable if assertions are disabled (python -O)
            raise ValueError("Invalid number of layers.")

        if self._verbose: # provide information about layer configuration
            print("Number of layers: ", self._num_layers)
            if self._num_codewords==2:
                print("Dual codeword mode active and cw multiplexing as " \
                      "defined in Tab. 7.3.1.3-1 from 38.211 applied.")
                print(f"Length of cw1/cw2: {self._num_layers0}/"\
                      f"{self._num_layers1} ")

    #########################################
    # Public methods and properties
    #########################################

    @property
    def num_codewords(self):
        """Number of input codewords for layer mapping. Can be either 1 or 2."""
        return self._num_codewords

    @property
    def num_layers(self):
        """ Number of MIMO layers"""
        return self._num_layers

    @property
    def num_layers0(self):
        r"""Number of layers for first codeword (only relevant for
        `num_codewords` =2)"""
        if self._num_codewords==1:
            # a single codeword occupies all layers
            return self._num_layers
        return self._num_layers0

    @property
    def num_layers1(self):
        r"""Number of layers for second codeword (only relevant for
        `num_codewords` =2)"""
        if self._num_codewords==1:
            return 0 # no second stream
        return self._num_layers1

    def build(self, input_shapes):
        """Test input shapes for consistency.

        ``input_shapes`` is the shape of a single input in single codeword
        mode and a list of two shapes in dual codeword mode. An
        ``AssertionError`` is raised if the shapes do not fit the configured
        number of layers.
        """

        if self._num_codewords==1: # single cw mode
            assert not isinstance(input_shapes[0], tf.TensorShape),\
                "Only single input codeword expected."
            assert input_shapes[-1]%self._num_layers==0,\
                "Invalid input dimensions: last dimension must be a " \
                "multiple of num_layers."
        else: # dual cw mode
            # inputs must be a list of two streams; verify the structure
            # *before* accessing the individual shapes such that a single
            # tensor input triggers this assertion instead of an
            # AttributeError in as_list()
            assert isinstance(input_shapes, (list, tuple)) \
                and len(input_shapes)==2, \
                "List of two inputs streams is expected."
            s0 = input_shapes[0].as_list()
            s1 = input_shapes[1].as_list()

            assert s0[-1]%self._num_layers0==0,\
                "Invalid input dimensions: last dimension of first input "\
                "must be a multiple of num_layers0."
            assert s1[-1]%self._num_layers1==0,\
                "Invalid input dimensions: last dimension of second input "\
                "must be a multiple of num_layers1."

            # verify that length of tb1 and tb2 fit together, i.e., both
            # codewords yield the same number of symbols per layer
            # (exact integer arithmetic; divisibility was checked above)
            assert s0[-1]//self._num_layers0 == s1[-1]//self._num_layers1, \
                f"Invalid input dimensions: length of first input must be "\
                f"{self._num_layers0/self._num_layers1:.2f} of the length "\
                f"of the second input."

    def call(self, inputs):
        """Applies MIMO Layer mapping as defined in Sec. 6.3.1.3 and Sec.
        7.3.1.3 38.211."""

        if self._num_codewords==1:
            s = inputs.shape[-1]
            # split last dimension into [symbols per layer, num_layers]
            y = split_dim(inputs,
                          (s//self._num_layers, self._num_layers),
                          axis=len(inputs.shape)-1)
        else:
            # for PDSCH only: support dual stream multiplexing
            x0 = inputs[0]
            x1 = inputs[1]
            s0 = x0.shape[-1]
            s1 = x1.shape[-1]

            y0 = split_dim(x0,
                           (s0//self._num_layers0, self._num_layers0),
                           axis=len(x0.shape)-1)
            y1 = split_dim(x1,
                           (s1//self._num_layers1, self._num_layers1),
                           axis=len(x1.shape)-1)

            # stack the layers of both codewords along the layer axis
            y = tf.concat([y0, y1], axis=-1)

        # swap last two dimensions such that the layer axis is second last
        y = tf.experimental.numpy.swapaxes(y, axis1=-1, axis2=-2)

        return y
[docs]
class LayerDemapper(Layer):
    # pylint: disable=line-too-long
    r"""LayerDemapper(layer_mapper, num_bits_per_symbol=1, **kwargs)

    Demaps MIMO layers to coded transport block(s) by following Sec. 6.3.1.3
    and Sec. 7.3.1.3 in [3GPP38211]_.

    This layer must be associated to a :class:`~sionna.nr.LayerMapper` and
    performs the inverse operation.

    It is assumed that ``num_bits_per_symbol`` consecutive LLRs belong to
    a single symbol position. This allows applying the LayerDemapper after
    demapping symbols to LLR values.

    If the layer mapper is configured for dual codeword transmission, a list of
    both transport block streams is returned.

    The class inherits from the Keras layer class and can be used as layer in a
    Keras model.

    Parameters
    ----------
        layer_mapper: :class:`~sionna.nr.LayerMapper`
            Associated LayerMapper.

        num_bits_per_symbol: int, 1 (default)
            Modulation order. Defines how many consecutive LLRs are associated
            to the same symbol position. Must be a positive integer.

    Input
    -----
        inputs : [...,num_layers, n/num_layers], tf.float
            2+D tensor containing MIMO layer data sequences. The last
            dimension must be a multiple of ``num_bits_per_symbol``.

    Output
    ------
        : [...,n], or [[...,n1], [...,n2]], tf.float
            2+D tensor containing the sequence of bits after layer demapping.
            If ``num_codewords`` =2, a list of two transport blocks is returned.

    Note
    ----
    As it is more convenient to apply the layer demapper after demapping
    symbols to LLRs, this layer groups the input sequence into groups of
    ``num_bits_per_symbol`` LLRs before restoring the original symbol sequence.
    This behavior can be deactivated by setting ``num_bits_per_symbol`` =1.
    """

    def __init__(self,
                 layer_mapper,
                 num_bits_per_symbol=1,
                 **kwargs):

        super().__init__(**kwargs)

        assert isinstance(layer_mapper, LayerMapper), \
            "layer_mapper must be LayerMapper."
        self._mapper = layer_mapper

        assert num_bits_per_symbol%1==0, \
            "num_bits_per_symbol must be int."
        # zero would lead to a division by zero in build() and negative
        # values to an invalid target shape in call()
        assert num_bits_per_symbol>0, \
            "num_bits_per_symbol must be positive."
        # store as int such that integer-valued floats (e.g., 2.0) do not
        # end up as a float entry of the target shape in call()
        self._num_bits_per_symbol = int(num_bits_per_symbol)

    def build(self, input_shapes):
        """Test input shapes for consistency.

        An ``AssertionError`` is raised if the second last dimension does not
        equal the number of layers of the associated LayerMapper or if the
        last dimension is not a multiple of ``num_bits_per_symbol``.
        """

        shape = input_shapes.as_list()

        # check that second last dimension equals number of expected streams
        num_layers = self._mapper.num_layers
        assert shape[-2]==num_layers, \
            "Invalid input dimension: input shape must be [...,num_layers,n]."

        assert shape[-1]%self._num_bits_per_symbol==0, \
            "Invalid input dimension: last dimension must be a multiple of " \
            "num_bits_per_symbol."

    def call(self, inputs):
        """Demaps multiple layers back to transport block stream(s)."""

        # group llrs into blocks of num_bits_per_symbol values
        s = inputs.shape[-1]
        x = split_dim(inputs,
                      (s//self._num_bits_per_symbol,
                       self._num_bits_per_symbol),
                      axis=len(inputs.shape)-1)

        # swap layer axis and symbol axis; the LLRs of each symbol stay
        # together in the last dimension
        x = tf.experimental.numpy.swapaxes(x, axis1=-2, axis2=-3)

        if self._mapper.num_codewords==1:
            return flatten_last_dims(x, num_dims=3)

        # multiplex into two codewords/streams
        # only relevant for PDSCH with dual codeword transmission
        num_layers0 = self._mapper.num_layers0
        # the first num_layers0 layers belong to the first codeword, all
        # remaining layers to the second codeword
        y0 = flatten_last_dims(x[...,:num_layers0,:], num_dims=3)
        y1 = flatten_last_dims(x[...,num_layers0:,:], num_dims=3)
        return [y0, y1]