#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layer for Convolutional Code Viterbi Decoding."""
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Layer
from sionna.fec.utils import int2bin
from sionna.fec.conv.utils import polynomial_selector, Trellis
class ViterbiDecoder(Layer):
# pylint: disable=line-too-long
r"""ViterbiDecoder(encoder=None, gen_poly=None, rate=1/2, constraint_length=3, rsc=False, terminate=False, method='soft_llr', output_dtype=tf.float32, **kwargs)
Implements the Viterbi decoding algorithm [Viterbi]_ that returns an
estimate of the information bits for a noisy convolutional codeword.
Takes as input either LLR values (`method` = `soft_llr`) or hard bit values
(`method` = `hard`) and returns a hard decided estimation of the information
bits.
The class inherits from the Keras layer class and can be used as a layer
in a Keras model.
Parameters
----------
encoder: :class:`~sionna.fec.conv.encoding.ConvEncoder`
If ``encoder`` is provided as input, the following input parameters
are not required and will be ignored: ``gen_poly``, ``rate``,
``constraint_length``, ``rsc``, ``terminate``. They will be inferred
from the ``encoder`` object itself. If ``encoder`` is `None`, the
above parameters must be provided explicitly.
gen_poly: tuple
tuple of strings with each string being a 0, 1 sequence. If `None`,
``rate`` and ``constraint_length`` must be provided.
rate: float
Valid values are 1/3 and 1/2. Only required if ``gen_poly`` is `None`.
constraint_length: int
Valid values are between 3 and 8 inclusive. Only required if
``gen_poly`` is `None`.
rsc: boolean
Boolean flag indicating whether the encoder is recursive-systematic for
given generator polynomials.
`True` indicates encoder is recursive-systematic.
`False` indicates encoder is feed-forward non-systematic.
terminate: boolean
Boolean flag indicating whether the codeword is terminated.
`True` indicates codeword is terminated to all-zero state.
`False` indicates codeword is not terminated.
method: str
Valid values are `soft_llr` or `hard` and define how the path
metrics are computed:
`soft_llr` expects channel LLRs as inputs;
`hard` assumes a `binary symmetric channel` (BSC) with 0/1 values as
inputs. In case of `hard`, `inputs` will be quantized to 0/1 values.
return_info_bits: boolean
Defaults to `True`. If `True`, only the estimates of the `k`
information bits are returned. If `False`, the estimated codeword
bits of length `n` are returned instead.
output_dtype: tf.DType
Defaults to tf.float32. Defines the output datatype of the layer.
Input
-----
inputs: [...,n], tf.float32
2+D tensor containing the (noisy) channel output symbols where `n`
denotes the codeword length
Output
------
: [...,coderate*n], tf.float32
2+D tensor containing the hard-decided estimates of the information
bits. If ``return_info_bits`` is `False`, the estimated codeword bits
of length `n` are returned instead.
Note
----
A full implementation of the decoder rather than a windowed approach
is used. For a given codeword of duration `T`, the path metric is
computed from time `0` to `T` and the path with optimal metric at time
`T` is selected. The optimal path is then traced back from `T` to `0`
to output the estimate of the information bit vector used to encode.
For larger codewords, note that the current method is sub-optimal
in terms of memory utilization and latency.
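Example
-------
A minimal usage sketch; the tensor ``llr`` is assumed to contain
channel LLRs of shape `[batch_size, n]`.

>>> decoder = ViterbiDecoder(gen_poly=('101', '111'))
>>> # hard-decided estimates of the information bits
>>> u_hat = decoder(llr)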
"""
def __init__(self,
encoder=None,
gen_poly=None,
rate=1/2,
constraint_length=3,
rsc=False,
terminate=False,
method='soft_llr',
return_info_bits=True,
output_dtype=tf.float32,
**kwargs):
super().__init__(**kwargs)
if encoder is not None:
self._gen_poly = encoder.gen_poly
self._trellis = encoder.trellis
self._terminate = encoder.terminate
else:
if gen_poly is not None:
assert all(isinstance(poly, str) for poly in gen_poly), \
"Each polynomial must be a string."
assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
"Each polynomial must be of same length."
assert all(all(
char in ['0','1'] for char in poly) for poly in gen_poly),\
"Each polynomial must be a string of 0's and 1's."
self._gen_poly = gen_poly
else:
valid_rates = (1/2, 1/3)
valid_constraint_length = (3, 4, 5, 6, 7, 8)
assert constraint_length in valid_constraint_length, \
"Constraint length must be between 3 and 8."
assert rate in valid_rates, \
"Rate must be 1/3 or 1/2."
self._gen_poly = polynomial_selector(rate, constraint_length)
# init Trellis parameters
self._trellis = Trellis(self.gen_poly, rsc=rsc)
self._terminate = terminate
self._coderate_desired = 1/len(self.gen_poly)
self._mu = len(self._gen_poly[0])-1
assert method in ('soft_llr', 'hard'), \
"method must be `soft_llr` or `hard`."
# conv_k denotes number of input bit streams
# can only be 1 in current implementation
self._conv_k = self._trellis.conv_k
# conv_n denotes number of output bits for conv_k input bits
self._conv_n = self._trellis.conv_n
self._k = None
self._n = None
# num_syms denotes the number of encoding periods (state transitions).
self._num_syms = None
self._ni = 2**self._conv_k
self._no = 2**self._conv_n
self._ns = self._trellis.ns
self._method = method
self._return_info_bits = return_info_bits
self.output_dtype = output_dtype
# If the state transition i->j emits symbol k, tf.gather_nd with
# ipst_op_idx gathers element (i,k) of the input into row j.
self.ipst_op_idx = self._mask_by_tonode()
#########################################
# Public methods and properties
#########################################
@property
def gen_poly(self):
"""Generator polynomial used by the encoder"""
return self._gen_poly
@property
def coderate(self):
"""Rate of the code used in the encoder"""
if self.terminate and self._n is None:
print("Note that, due to termination, the true coderate is lower "\
"than the returned design rate. "\
"The exact true rate depends on the value of n and "\
"hence cannot be computed before the first call().")
self._coderate = self._coderate_desired
elif self.terminate and self._n is not None:
k = self._coderate_desired*self._n - self._mu
self._coderate = k/self._n
else: # not terminated: the true rate equals the design rate
self._coderate = self._coderate_desired
return self._coderate
@property
def trellis(self):
"""Trellis object used during encoding"""
return self._trellis
@property
def terminate(self):
"""Indicates if the encoder is terminated during codeword generation"""
return self._terminate
@property
def k(self):
"""Number of information bits per codeword"""
if self._k is None:
print("Note: The value of k cannot be computed before the first " \
"call().")
return self._k
@property
def n(self):
"""Number of codeword bits"""
if self._n is None:
print("Note: The value of n cannot be computed before the first " \
"call().")
return self._n
#########################
# Utility functions
#########################
def _mask_by_tonode(self):
r"""
_Ns x _No index matrix, each element of shape (2,)
where num_ops = 2**conv_n
When applied as tf.gather index on a Ns x num_ops matrix
((i,j) denoting metric for prev_st=i and output=j)
the output is matrix sorted by next_state. Row i in output
denotes the 2 possible metrics for transition to state i.
"""
cnst = self._ns * self._ni
from_nodes_vec = tf.reshape(self._trellis.from_nodes,(cnst,))
op_idx = tf.reshape(self._trellis.op_by_tonode, (cnst,))
st_op_idx = tf.transpose(tf.stack([from_nodes_vec, op_idx]))
st_op_idx = tf.reshape(st_op_idx[None,:,:],(self._ns, self._ni, 2))
return st_op_idx
def _update_fwd(self, init_cm, bm_mat):
state_vec = tf.tile(tf.range(self._ns, dtype=tf.int32)[None,:],
[tf.shape(init_cm)[0], 1])
ipst_op_mask = tf.tile(self.ipst_op_idx[None,:], [tf.shape(init_cm)[0], 1, 1, 1])
cm_ta = tf.TensorArray(tf.float32, size=self._num_syms,
dynamic_size=False, clear_after_read=False)
tb_ta = tf.TensorArray(tf.int32, size=self._num_syms,
dynamic_size=False, clear_after_read=False)
prev_cm = init_cm
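# Add-compare-select recursion: for every trellis section, add the
# branch metrics to the cumulative metrics of the predecessor states,
# keep the minimum per to-state and store the surviving predecessor
# for the later traceback.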
for idx in tf.range(0, self._n, self._conv_n):
sym = idx//self._conv_n
metrics_t = bm_mat[..., sym]
# (bs, Ns, No) tensor: entry (s,j) is the candidate path metric for
# previous state s and transition output symbol j
sum_metric = prev_cm[:,:,None] + metrics_t[:,None,:]
sum_metric_bytonode = tf.gather_nd(sum_metric, ipst_op_mask,
batch_dims=1)
tb_state_idx = tf.math.argmin(sum_metric_bytonode, axis=2)
tb_state_idx = tf.cast(tb_state_idx, tf.int32)
# For each to-state, gather the surviving (argmin) predecessor state
from_st_idx = tf.transpose(tf.stack([state_vec, tb_state_idx]),
perm=[1, 2, 0])
tb_states = tf.gather_nd(self._trellis.from_nodes, from_st_idx)
cum_t = tf.math.reduce_min(sum_metric_bytonode,axis=2)
cm_ta = cm_ta.write(sym, cum_t)
tb_ta = tb_ta.write(sym, tb_states)
prev_cm = cum_t
return cm_ta, tb_ta
def _op_bits_path(self, paths):
r"""
Given a path, compute the input bit stream that results in the path.
Used in call() where the input is optimal path (seq of states) such
as the path returned by _return_optimal.
"""
paths = tf.cast(paths, tf.int32)
ip_bits = tf.TensorArray(tf.int32,
size=paths.shape[-1]-1,
dynamic_size=False,
clear_after_read=False)
dec_syms = tf.TensorArray(tf.int32,
size=paths.shape[-1]-1,
dynamic_size=False,
clear_after_read=False)
ni = self._trellis.ni
ip_sym_mask = tf.range(ni)[None, :]
for sym in tf.range(1, paths.shape[-1]):
# gather index from paths to enable XLA
# replaces p_idx = paths[:,sym-1:sym+1]
p_idx = tf.gather(paths, [sym-1, sym], axis=-1)
dec_ = tf.gather_nd(self._trellis.op_mat, p_idx)
dec_syms = dec_syms.write(sym-1, value=dec_)
# bs x ni boolean tensor. Each row has a True and False. True
# corresponds to input_bit which produced the next state (t=sym)
match_st = tf.math.equal(
tf.gather(self._trellis.to_nodes,paths[:, sym-1]),
tf.tile(paths[:, sym][:, None], [1, 2])
)
# tf.boolean_mask throws error in XLA mode
#ip_bit = tf.boolean_mask(ip_sym_mask, match_st)
ip_bit_ = tf.where(match_st,
ip_sym_mask,
tf.zeros_like(ip_sym_mask))
ip_bit = tf.reduce_sum(ip_bit_, axis=-1)
ip_bits = ip_bits.write(sym-1, ip_bit)
ip_bit_vec_est = tf.transpose(ip_bits.stack())
ip_sym_vec_est = tf.transpose(dec_syms.stack())
return ip_bit_vec_est, ip_sym_vec_est
def _optimal_path(self, cm_, tb_):
r"""
Compute optimal path (state at each time t) given tensors cm_ & tb_
of shapes (None, Ns, T). Output is of shape (None, T)
cm_: cumulative metrics for each state at time t(0 to T)
tb_: traceback state for each state at time t(0 to T)
"""
# tb_ and cm_ are of shape (batch, self._ns, num_syms)
assert tb_.get_shape()[1] == self._ns, "Invalid shape."
optst_ta = tf.TensorArray(tf.int32, size=tb_.shape[-1],
dynamic_size=False,
clear_after_read=False)
if self._terminate:
opt_term_state = tf.zeros((tf.shape(cm_)[0],), tf.int32)
else:
opt_term_state =tf.cast(tf.argmin(cm_[:, :, -1], axis=1), tf.int32)
optst_ta = optst_ta.write(tb_.shape[-1]-1,opt_term_state)
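# Trace back from the terminal state: at every step, replace the
# currently selected state by its stored surviving predecessor.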
for sym in tf.range(tb_.shape[-1]-1, 0, -1):
opt_st = optst_ta.read(sym)[:,None]
idx_ = tf.concat([tf.range(tf.shape(cm_)[0])[:,None], opt_st],
axis=1)
opt_st_tminus1 = tf.gather_nd(tb_[:, :, sym], idx_)
optst_ta = optst_ta.write(sym-1, opt_st_tminus1)
return tf.transpose(optst_ta.stack())
def _bmcalc(self, y):
"""
Calculate branch metrics for a given noisy codeword tensor.
For each time period t, _bmcalc computes the distance of symbol
vector y[t] from each possible output symbol.
The distance metric is L2 distance if decoder parameter `method` is
"soft".
The distance metric is L1 distance if parameter `method` is "hard".
"""
op_bits = np.stack(
[int2bin(op, self._conv_n) for op in range(self._no)])
op_mat = tf.cast(tf.tile(op_bits, [1,self._num_syms]), tf.float32)
op_mat = tf.expand_dims(op_mat, axis=0)
y = tf.expand_dims(y, axis=1)
if self._method=='soft_llr':
op_mat_sign = 1 - 2.*op_mat
llr_sign = -1. * tf.math.multiply(y, op_mat_sign)
llr_sign = tf.reshape(llr_sign,
(-1, self._no, self._num_syms, self._conv_n))
# Sum of LLR*(sign of bit) for each symbol
bm = tf.math.reduce_sum(llr_sign, axis=-1)
else: # method == 'hard'
diffabs = tf.math.abs(y-op_mat)
diffabs = tf.reshape(diffabs,
(-1, self._no, self._num_syms, self._conv_n))
# Manhattan distance of symbols
bm = tf.math.reduce_sum(diffabs, axis=-1)
return bm
#########################
# Keras layer functions
#########################
def build(self, input_shape):
"""Build layer and check dimensions."""
# the input must have a rank of at least two
tf.debugging.assert_greater_equal(len(input_shape), 2)
self._n = input_shape[-1]
divisible = tf.math.floormod(self._n, self._conv_n)
assert divisible==0, 'length of codeword should be divisible by \
number of output bits per symbol.'
self._num_syms = int(self._n*self._coderate_desired)
self._num_term_syms = self._mu if self.terminate else 0
self._k = self._num_syms - self._num_term_syms
def call(self, inputs):
"""
Viterbi decoding function.
inputs is the (noisy) codeword tensor where the last dimension should
equal n. All the leading dimensions are assumed as batch dimensions.
"""
LARGEDIST = 2.**20 # pylint: disable=invalid-name
tf.debugging.assert_type(inputs, tf.float32,
message="input must be tf.float32.")
if self._method == 'hard':
inputs = tf.math.floormod(tf.cast(inputs, tf.int32),2)
elif self._method == 'soft_llr':
inputs = -1. * inputs
inputs = tf.cast(inputs, tf.float32)
output_shape = inputs.get_shape().as_list()
y_resh = tf.reshape(inputs, [-1, self._n])
output_shape[0] = -1
if self._return_info_bits:
output_shape[-1] = self._k # assign k to the last dimension
else:
output_shape[-1] = self._n
# Branch metrics matrix for a given y
bm_mat = self._bmcalc(y_resh)
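# The encoder starts in the all-zero state, hence state 0 gets an
# initial cumulative metric of 0 and all other states a large penalty.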
init_cm_np = np.full((self._ns,), LARGEDIST)
init_cm_np[0] = 0.0
prev_cm_ = tf.convert_to_tensor(init_cm_np, dtype=tf.float32)
prev_cm = tf.tile(prev_cm_[None,:], [tf.shape(y_resh)[0], 1])
cm_ta, tb_ta = self._update_fwd(prev_cm, bm_mat)
cm = tf.transpose(cm_ta.stack(), perm=[1,2,0])
tb = tf.transpose(tb_ta.stack(),perm=[1,2,0])
del cm_ta, tb_ta
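# Prepend the known initial (all-zero) state so that consecutive state
# pairs of the optimal path can be mapped to decoded bits and symbols.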
zero_st = tf.zeros((tf.shape(y_resh)[0], 1), tf.int32)
opt_path = self._optimal_path(cm, tb)
opt_path = tf.concat((zero_st, opt_path), axis=1)
del cm, tb
msghat, cwhat = self._op_bits_path(opt_path)
if self._return_info_bits:
msghat = msghat[...,:self._k]
output = tf.cast(msghat, self.output_dtype)
else:
output = tf.cast(cwhat, self.output_dtype)
output_reshaped = tf.reshape(output, output_shape)
return output_reshaped
class BCJRDecoder(Layer):
# pylint: disable=line-too-long
r"""BCJRDecoder(encoder=None, gen_poly=None, rate=1/2, constraint_length=3, rsc=False, terminate=False, hard_out=True, algorithm='map', output_dtype=tf.float32, **kwargs)
Implements the BCJR decoding algorithm [BCJR]_ that returns an
estimate of the information bits for a noisy convolutional codeword.
Takes as input either channel LLRs or a tuple
(channel LLRs, a priori LLRs) and returns an estimate of the
information bits, either as output LLRs ( ``hard_out`` = `False`) or as
hard-decided bits ( ``hard_out`` = `True`).
The class inherits from the Keras layer class and can be used as a layer
in a Keras model.
Parameters
----------
encoder: :class:`~sionna.fec.conv.encoding.ConvEncoder`
If ``encoder`` is provided as input, the following input parameters
are not required and will be ignored: ``gen_poly``, ``rate``,
``constraint_length``, ``rsc``, ``terminate``. They will be inferred
from the ``encoder`` object itself. If ``encoder`` is `None`, the
above parameters must be provided explicitly.
gen_poly: tuple
tuple of strings with each string being a 0, 1 sequence. If `None`,
``rate`` and ``constraint_length`` must be provided.
rate: float
Valid values are 1/3 and 1/2. Only required if ``gen_poly`` is `None`.
constraint_length: int
Valid values are between 3 and 8 inclusive. Only required if
``gen_poly`` is `None`.
rsc: boolean
Boolean flag indicating whether the encoder is recursive-systematic for
given generator polynomials. `True` indicates encoder is
recursive-systematic. `False` indicates encoder is feed-forward non-systematic.
terminate: boolean
Boolean flag indicating whether the codeword is terminated.
`True` indicates codeword is terminated to all-zero state.
`False` indicates codeword is not terminated.
hard_out: boolean
Boolean flag indicating whether to output hard or soft decisions on
the decoded information vector.
`True` implies a hard-decoded information vector of 0/1's as output.
`False` implies output is decoded LLR's of the information.
algorithm: str
Defaults to `map`. Indicates the implemented BCJR algorithm,
where `map` denotes the exact MAP algorithm, `log` indicates the
exact MAP implementation, but in log-domain, and
`maxlog` indicates the approximated MAP implementation in log-domain,
where :math:`\log(e^{a}+e^{b}) \sim \max(a,b)`.
output_dtype: tf.DType
Defaults to tf.float32. Defines the output datatype of the layer.
Input
-----
llr_ch or (llr_ch, llr_a) :
Tensor or Tuple:
llr_ch: [...,n], tf.float32
2+D tensor containing the (noisy) channel
LLRs, where `n` denotes the codeword length
llr_a: [...,k], tf.float32
2+D tensor containing the a priori information of each information bit.
Implicitly assumed to be 0 if only ``llr_ch`` is provided.
Output
------
: [...,coderate*n], tf.float32
2+D tensor containing either the LLRs ( ``hard_out`` = `False`) or the
hard-decided estimates ( ``hard_out`` = `True`) of the information bits
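Example
-------
A minimal usage sketch; ``llr_ch`` is assumed to contain channel LLRs
of shape `[batch_size, n]`.

>>> decoder = BCJRDecoder(gen_poly=('101', '111'), algorithm='maxlog')
>>> # hard-decided information bits (set hard_out=False for LLRs)
>>> u_hat = decoder(llr_ch)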
"""
def __init__(self,
encoder=None,
gen_poly=None,
rate=1/2,
constraint_length=3,
rsc=False,
terminate=False,
hard_out=True,
algorithm='map',
output_dtype=tf.float32,
**kwargs):
super().__init__(**kwargs)
if encoder is not None:
self._gen_poly = encoder.gen_poly
self._trellis = encoder.trellis
self._terminate = encoder.terminate
else:
if gen_poly is not None:
assert all(isinstance(poly, str) for poly in gen_poly), \
"Each polynomial must be a string."
assert all(len(poly)==len(gen_poly[0]) for poly in gen_poly), \
"Each polynomial must be of same length."
assert all(all(
char in ['0','1'] for char in poly) for poly in gen_poly),\
"Each polynomial must be a string of 0's and 1's."
self._gen_poly = gen_poly
else:
valid_rates = (1/2, 1/3)
valid_constraint_length = (3, 4, 5, 6, 7, 8)
assert constraint_length in valid_constraint_length, \
"Constraint length must be between 3 and 8."
assert rate in valid_rates, \
"Rate must be 1/3 or 1/2."
self._gen_poly = polynomial_selector(rate, constraint_length)
# init Trellis parameters
self._trellis = Trellis(self.gen_poly, rsc=rsc)
self._terminate = terminate
valid_algorithms = ['map', 'log', 'maxlog']
assert algorithm in valid_algorithms, \
"algorithm must be one of map, log or maxlog"
self._coderate_desired = 1/len(self._gen_poly)
self._mu = len(self._gen_poly[0])-1
self._num_term_bits = None
self._num_term_syms = None
# conv_k denotes number of input bit streams
# can only be 1 in current implementation
self._conv_k = self._trellis.conv_k
assert self._conv_k == 1
self._mu = self._trellis._mu
# conv_n denotes number of output bits for conv_k input bits
self._conv_n = self._trellis.conv_n
# Length of Info-bit vector. Equal to _num_syms if terminate=False,
# else < _num_syms
self._k = None
# Length of the convolutional codeword, including termination bits
self._n = None
# num_syms denotes the number of encoding periods (state transitions).
self._num_syms = None
self._ni = 2**self._conv_k
self._no = 2**self._conv_n
self._ns = self._trellis.ns
self._hard_out = hard_out
self._algorithm = algorithm
self._output_dtype = output_dtype
self.ipst_op_idx, self.ipst_ip_idx = self._mask_by_tonode()
#########################################
# Public methods and properties
#########################################
@property
def gen_poly(self):
"""Generator polynomial used by the encoder"""
return self._gen_poly
@property
def coderate(self):
"""Rate of the code used in the encoder"""
if self.terminate and self._n is None:
print("Note that, due to termination, the true coderate is lower "\
"than the returned design rate. "\
"The exact true rate depends on the value of n and "\
"hence cannot be computed before the first call().")
self._coderate = self._coderate_desired
elif self.terminate and self._n is not None:
k = self._coderate_desired*self._n - self._mu
self._coderate = k/self._n
else: # not terminated: the true rate equals the design rate
self._coderate = self._coderate_desired
return self._coderate
@property
def trellis(self):
"""Trellis object used during encoding"""
return self._trellis
@property
def terminate(self):
"""Indicates if the encoder is terminated during codeword generation"""
return self._terminate
@property
def k(self):
"""Number of information bits per codeword"""
if self._k is None:
print("Note: The value of k cannot be computed before the first " \
"call().")
return self._k
@property
def n(self):
"""Number of codeword bits"""
if self._n is None:
print("Note: The value of n cannot be computed before the first " \
"call().")
return self._n
#########################
# Utility functions
#########################
def _mask_by_tonode(self):
"""
Assume i->j a valid state transition given info-bit b & emits symbol k
returns following two _ns x _no matrices, each element of shape (2,).
- st_op_idx: jth row contains (i,k) tuples
- st_ip_idx: jth row contains (i,b) tuples
When applied as tf.gather on a _ns x _no matrix, the output is
matrix sorted by next_state.
For e.g., tf.gather when applied on "input" (shape _ns x _no), with mask
- st_op_idx: gathers input[i][k] in row j,
- st_ip_idx: gathers input[i][b] in row j.
"""
cnst = self._ns * self._ni
from_nodes_vec = tf.reshape(self._trellis.from_nodes,(cnst,))
op_idx = tf.reshape(self._trellis.op_by_tonode, (cnst,))
st_op_idx = tf.transpose(tf.stack([from_nodes_vec, op_idx]))
st_op_idx = tf.reshape(st_op_idx[None,:,:],(self._ns, self._ni, 2))
ip_idx = tf.reshape(self._trellis.ip_by_tonode, (cnst,))
st_ip_idx = tf.transpose(tf.stack([from_nodes_vec, ip_idx]))
st_ip_idx = tf.reshape(st_ip_idx[None,:,:],(self._ns, self._ni, 2))
return st_op_idx, st_ip_idx
def _bmcalc(self, llr_in):
"""
Calculate branch gamma metrics for a given noisy codeword tensor.
For each time period t, _bmcalc computes the "distance" of symbol
vector y[t] from each possible output symbol i.e.,
(2*Eb/N0)* sum_i x_y*y_i for i=1,2,...,conv_n
The above metric is used in calculation of gamma.
If the input is llr, which is nothing but 2*Eb*y/N0.
"""
op_bits = np.stack(
[int2bin(op, self._conv_n) for op in range(self._no)])
op_mat = tf.cast(tf.tile(op_bits, [1, self._num_syms]), tf.float32)
op_mat = tf.expand_dims(op_mat, axis=0)
llr_in = tf.expand_dims(llr_in, axis=1)
op_mat_sign = 1. - 2. * op_mat
llr_sign = tf.math.multiply(llr_in, op_mat_sign)
half_llr_sign = tf.reshape(0.5 * llr_sign,
(-1, self._no, self._num_syms, self._conv_n))
if self._algorithm in ['log', 'maxlog']:
bm = tf.math.reduce_sum(half_llr_sign, axis=-1)
else:
bm = tf.math.exp(tf.math.reduce_sum(half_llr_sign, axis=-1))
return bm
def _initialize(self, llr_ch):
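# Boundary conditions: the encoder starts in the all-zero state. If the
# code is terminated, it also ends there; otherwise, all terminal
# states are assumed equally likely.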
if self._algorithm in ['log', 'maxlog']:
init_vals = -np.inf, 0.0
else:
init_vals = 0.0, 1.0
alpha_init_np = np.full((self._ns,), init_vals[0])
alpha_init_np[0] = init_vals[1]
beta_init_np = alpha_init_np
if not self._terminate:
eq_prob = 1./self._ns
if self._algorithm in ['log', 'maxlog']:
eq_prob = np.log(eq_prob)
beta_init_np = np.full((self._ns,), eq_prob)
alpha_init = tf.convert_to_tensor(alpha_init_np, dtype=tf.float32)
alpha_init = tf.tile(alpha_init[None,:], [tf.shape(llr_ch)[0], 1])
beta_init = tf.convert_to_tensor(beta_init_np, dtype=tf.float32)
beta_init = tf.tile(beta_init[None,:], [tf.shape(llr_ch)[0], 1])
return alpha_init, beta_init
def _update_fwd(self, alph_init, bm_mat, llr):
"""
Run forward update from time t=0 to t=k-1.
At each time t, computes alpha_t using alpha_t-1 and gamma_t.
Returns tensor array of alpha_t, t-0,1,2...,k-1
"""
alph_ta = tf.TensorArray(tf.float32, size=self._num_syms+1,
dynamic_size=False, clear_after_read=False)
alph_prev = tf.cast(alph_init, tf.float32)
# (bs, _Ns, _ni, 2) matrix
ipst_ip_mask = tf.tile(
self.ipst_ip_idx[None,:],[tf.shape(alph_init)[0],1,1,1])
# (bs, _Ns, _ni) matrix, by from state
op_mask = tf.tile(self.trellis.op_by_fromnode[None,:,:],
[tf.shape(alph_init)[0],1,1])
ipbit_mat = tf.tile(tf.range(self._ni)[None, None, :],
[tf.shape(alph_init)[0], self._ns, 1])
ipbitsign_mat = 1. - 2. * tf.cast(ipbit_mat, tf.float32)
alph_ta = alph_ta.write(0, alph_prev)
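# Forward (alpha) recursion: combine alpha_t of the predecessor states
# with the branch metrics gamma_t and reduce per to-state (sum for
# 'map', log-sum-exp for 'log', max for 'maxlog'). For 'map', the
# alphas are additionally normalized to sum to one for stability.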
for t in tf.range(self._num_syms):
bm_t = bm_mat[..., t]
llr_t = 0.5 * llr[...,t][:, None,None]
bm_byfromst = tf.gather(bm_t, op_mask, batch_dims=1)
signed_half_llr = tf.math.multiply(
tf.tile(llr_t,[1, self._ns, self._ni]), ipbitsign_mat)
if self._algorithm in ['log', 'maxlog']:
llr_byfromst = signed_half_llr
gamma_byfromst = llr_byfromst + bm_byfromst
alph_gam_prod = gamma_byfromst + alph_prev[:,:,None]
else:
llr_byfromst = tf.math.exp(signed_half_llr)
gamma_byfromst = tf.multiply(llr_byfromst, bm_byfromst)
alph_gam_prod = tf.math.multiply(gamma_byfromst,
alph_prev[:,:,None])
alphgam_bytost = tf.gather_nd(alph_gam_prod,
ipst_ip_mask,
batch_dims=1)
if self._algorithm =='map':
alph_t = tf.math.reduce_sum(alphgam_bytost, axis=-1)
alph_t_sum = tf.reduce_sum(alph_t, axis=-1)
alph_t = tf.divide(alph_t, tf.tile(alph_t_sum[:,None],[1,self._ns]))
elif self._algorithm == 'log':
alph_t = tf.math.reduce_logsumexp(alphgam_bytost, axis=-1)
else: # self._algorithm = 'maxlog'
alph_t = tf.math.reduce_max(alphgam_bytost, axis=-1)
alph_prev = alph_t
alph_ta = alph_ta.write(t+1, alph_t)
return alph_ta
def _update_bwd(self, beta_init, bm_mat, llr, alpha_ta):
"""
Run backward update from time t=k-1 to t=0.
At each time t, computes beta_t-1 using beta_t and gamma_t.
Returns llr for information bits for t=0,1,...,k-1
"""
beta_next = beta_init
llr_op_ta = tf.TensorArray(tf.float32,
size=self._num_syms,
dynamic_size=False,
clear_after_read=False)
beta_next = tf.cast(beta_next, tf.float32)
# (bs, _Ns, _ni) matrix, by from state
op_mask = tf.tile(self.trellis.op_by_fromnode[None,:,:],
[tf.shape(beta_init)[0],1,1])
tonode_mask = tf.tile(self.trellis.to_nodes[None,:,:],
[tf.shape(beta_init)[0], 1, 1])
ipbit_mat = tf.tile(tf.range(self._ni)[None, None, :],
[tf.shape(beta_init)[0], self._ns, 1])
ipbitsign_mat = 1.0 - 2.0 * tf.cast(ipbit_mat, tf.float32)
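# Backward (beta) recursion: combine beta_{t+1} of the successor states
# with gamma_t and reduce per from-state. The output LLR at time t
# contrasts the branches with input bit 0 against those with input
# bit 1 using alpha_t, gamma_t and beta_{t+1}.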
for t in tf.range(self._num_syms-1, -1, -1):
bm_t = bm_mat[..., t]
llr_t = 0.5 * llr[...,t][:, None,None]
signed_half_llr = tf.math.multiply(
tf.tile(llr_t,[1, self._ns, self._ni]), ipbitsign_mat)
bm_byfromst = tf.gather(bm_t, op_mask, batch_dims=1)
if self._algorithm in ['log', 'maxlog']:
llr_byfromst = signed_half_llr
gamma_byfromst = tf.math.add(llr_byfromst, bm_byfromst)
else:
llr_byfromst = tf.math.exp(signed_half_llr)
gamma_byfromst = tf.multiply(llr_byfromst, bm_byfromst)
beta_bytonode = tf.gather(beta_next, tonode_mask, batch_dims=1)
if self._algorithm not in ['log', 'maxlog']:
beta_gam_prod = tf.math.multiply(gamma_byfromst, beta_bytonode)
beta_t = tf.math.reduce_sum(beta_gam_prod, axis=-1)
beta_t_sum = tf.reduce_sum(beta_t, axis=-1)
beta_t = tf.divide(beta_t, tf.tile(beta_t_sum[:,None],[1,self._ns]))
elif self._algorithm == 'log':
beta_gam_prod = gamma_byfromst + beta_bytonode
beta_t = tf.math.reduce_logsumexp(beta_gam_prod, axis=-1, keepdims=False)
else: #self._algorithm = 'maxlog'
beta_gam_prod = gamma_byfromst + beta_bytonode
beta_t = tf.math.reduce_max(beta_gam_prod, axis=-1)
alph_t = alpha_ta.read(t)
if self._algorithm not in ['log', 'maxlog']:
llr_op_t0 = tf.math.multiply(
tf.math.multiply(alph_t, gamma_byfromst[...,0]),
beta_bytonode[...,0])
llr_op_t1 = tf.math.multiply(
tf.math.multiply(alph_t,gamma_byfromst[...,1]),
beta_bytonode[...,1])
llr_op_t = tf.math.log(tf.divide(tf.reduce_sum(llr_op_t0, axis=-1),
tf.reduce_sum(llr_op_t1,axis=-1)))
else:
llr_op_t0 = alph_t + gamma_byfromst[...,0] + beta_bytonode[...,0]
llr_op_t1 = alph_t + gamma_byfromst[...,1] + beta_bytonode[...,1]
if self._algorithm == 'log':
llr_op_t = tf.math.subtract(
tf.math.reduce_logsumexp(llr_op_t0, axis=-1),
tf.math.reduce_logsumexp(llr_op_t1, axis=-1))
else:
llr_op_t = tf.math.subtract(
tf.math.reduce_max(llr_op_t0, axis=-1),
tf.math.reduce_max(llr_op_t1, axis=-1))
llr_op_ta = llr_op_ta.write(t, llr_op_t)
beta_next = beta_t
llr_op = tf.transpose(llr_op_ta.stack())
return llr_op
#########################
# Keras layer functions
#########################
def build(self, input_shape):
"""Build layer and check dimensions."""
# the input must have a rank of at least two
tf.debugging.assert_greater_equal(len(input_shape), 2)
if isinstance(input_shape, tf.TensorShape):
self._n = input_shape[-1]
else:
self._n = input_shape[0][-1]
self._num_syms = int(self._n*self._coderate_desired)
self._num_term_syms = self._mu if self._terminate else 0
self._num_term_bits = int(self._num_term_syms/self._coderate_desired)
self._k = self._num_syms - self._num_term_syms
def call(self, inputs):
"""
BCJR decoding function.
inputs is the (noisy) codeword tensor where the last dimension should
equal n. All the leading dimensions are assumed as batch dimensions.
"""
if isinstance(inputs, (tuple, list)):
assert len(inputs) == 2, "inputs must be a tuple (llr_ch, llr_a)."
llr_ch, llr_apr = inputs
else:
tf.debugging.assert_greater(tf.rank(inputs), 1)
llr_ch = inputs
llr_apr = None
tf.debugging.assert_type(llr_ch,
tf.float32,
message="input must be tf.float32.")
output_shape = llr_ch.get_shape().as_list()
# allow different codeword lengths in eager mode
if output_shape[-1] != self._n:
if isinstance(inputs, (tuple, list)):
self.build((inputs[0].get_shape(),
inputs[1].get_shape()))
else:
self.build(llr_ch.get_shape().as_list())
output_shape[0] = -1
output_shape[-1] = self._k # assign k to the last dimension
llr_ch = tf.reshape(llr_ch, [-1, self._n])
if llr_apr is None:
llr_apr = tf.zeros((tf.shape(llr_ch)[0], self._num_syms),
dtype=tf.float32)
llr_ch = -1. * llr_ch
llr_apr = -1. * llr_apr
# Branch metrics matrix for a given y
bm_mat = self._bmcalc(llr_ch)
alpha_init, beta_init = self._initialize(llr_ch)
alph_ta = self._update_fwd(alpha_init, bm_mat, llr_apr)
llr_op = self._update_bwd(beta_init, bm_mat, llr_apr, alph_ta)
msghat = -1. * llr_op[...,:self._k]
if self._hard_out: # hard decide decoder output if required
msghat = tf.less(0.0, msghat)
msghat = tf.cast(msghat, self._output_dtype)
msghat_reshaped = tf.reshape(msghat, output_shape)
return msghat_reshaped