Skip to content

Commit d2e03a8

Browse files
Implement transform precoding (DFT-s-OFDM)
Signed-off-by: Daniel Schäufele <[email protected]>
1 parent 2cb12fd commit d2e03a8

18 files changed

+666
-34
lines changed

sionna/mimo/detection.py

+9
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ class LinearDetector(Layer):
5252
constellation point indices instead of soft-values.
5353
Defaults to `False`.
5454
55+
post_equalizer_transformation: None or Layer
56+
Optional layer that applies a transformation after the equalizer and
57+
before the demapper. This can be used to apply transform precoding
58+
when DFT-s-OFDM is enabled in NR PUSCH.
59+
5560
dtype : One of [tf.complex64, tf.complex128] tf.DType (dtype)
5661
The dtype of ``y``. Defaults to tf.complex64.
5762
The output dtype is the corresponding real dtype (tf.float32 or tf.float64).
@@ -96,11 +101,13 @@ def __init__(self,
96101
num_bits_per_symbol=None,
97102
constellation=None,
98103
hard_out=False,
104+
post_equalizer_transformation=None,
99105
dtype=tf.complex64,
100106
**kwargs):
101107
super().__init__(dtype=dtype, **kwargs)
102108
self._output = output
103109
self._hard_out = hard_out
110+
self._post_equalizer_transformation = post_equalizer_transformation
104111

105112
# Determine the equalizer to use
106113
if isinstance(equalizer, str):
@@ -137,6 +144,8 @@ def __init__(self,
137144

138145
def call(self, inputs):
139146
x_hat, no_eff = self._equalizer(*inputs)
147+
if self._post_equalizer_transformation is not None:
148+
x_hat = self._post_equalizer_transformation(x_hat)
140149
z = self._demapper([x_hat, no_eff])
141150

142151
# Reshape to the expected output shape

sionna/nr/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
from .pusch_dmrs_config import PUSCHDMRSConfig
1212
from .pusch_pilot_pattern import PUSCHPilotPattern
1313
from .pusch_precoder import PUSCHPrecoder
14+
from .pusch_transform_precoder import PUSCHTransformPrecoder, PUSCHTransformDeprecoder
1415
from .pusch_transmitter import PUSCHTransmitter
1516
from .pusch_receiver import PUSCHReceiver
1617
from .pusch_channel_estimation import PUSCHLSChannelEstimator
1718
from .tb_config import TBConfig
18-
from .utils import generate_prng_seq, select_mcs, calculate_tb_size
19+
from .utils import generate_prng_seq, generate_low_papr_seq_type_1, select_mcs, calculate_tb_size
1920
from .tb_encoder import TBEncoder
2021
from .tb_decoder import TBDecoder
2122
from .layer_mapping import LayerMapper, LayerDemapper

sionna/nr/pusch_config.py

+102-17
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
"""
77
# pylint: disable=line-too-long
88

9+
import functools
910
import numpy as np
10-
from .utils import generate_prng_seq
11+
from .utils import generate_prng_seq, generate_low_papr_seq_type_1
1112
from .config import Config
1213
from sionna import nr
1314
from .utils import calculate_tb_size
@@ -233,7 +234,7 @@ def n_rnti(self, value):
233234
assert value in range(65536), "n_rnti must be in [0, 65535]"
234235
self._n_rnti = value
235236

236-
#---transform_precoding---#
237+
#---precoding---#
237238
@property
238239
def precoding(self):
239240
"""
@@ -427,9 +428,9 @@ def n(self):
427428
used for DMRS generation
428429
"""
429430
if self.dmrs.config_type==1:
430-
n_max = self.num_resource_blocks*12//4 -1
431+
n_max = self.num_effective_subcarriers//4 -1
431432
elif self.dmrs.config_type==2:
432-
n_max = self.num_resource_blocks*12//6 -1
433+
n_max = self.num_effective_subcarriers//6 -1
433434
return list(range(n_max+1))
434435

435436
@property
@@ -450,6 +451,31 @@ def num_resource_blocks(self):
450451
else:
451452
return self.n_size_bwp
452453

454+
@property
455+
def num_effective_resource_blocks(self):
456+
"""
457+
int, read-only : Number of allocated resource blocks for the
458+
PUSCH transmissions, that are actually used (can differ from
459+
num_subcarriers when transform precoding is enabled,
460+
because of constraints on the largest prime factor of the
461+
subcarrier count)
462+
"""
463+
@functools.lru_cache
464+
def adjust_prbs_to_prime_factor_constraints(prbs):
465+
# Decreases the number of PRBs until the largest prime factor is at most 5
466+
for eff_prbs in range(prbs, 1, -1):
467+
n = eff_prbs
468+
for p in [2, 3, 5]:
469+
while n % p == 0:
470+
n /= p
471+
if n == 1:
472+
return eff_prbs
473+
474+
if self.transform_precoding:
475+
return adjust_prbs_to_prime_factor_constraints(self.num_resource_blocks)
476+
else:
477+
return self.num_resource_blocks
478+
453479
@property
454480
def num_subcarriers(self):
455481
"""
@@ -458,6 +484,17 @@ def num_subcarriers(self):
458484
"""
459485
return 12*self.num_resource_blocks
460486

487+
@property
488+
def num_effective_subcarriers(self):
489+
"""
490+
int, read-only : Number of allocated subcarriers for the
491+
PUSCH transmissions, that are actually used (can differ from
492+
num_subcarriers when transform precoding is enabled,
493+
because of constraints on the largest prime factor of the
494+
subcarrier count)
495+
"""
496+
return 12 * self.num_effective_resource_blocks
497+
461498
@property
462499
def num_res_per_prb(self):
463500
"""
@@ -488,7 +525,7 @@ def dmrs_mask(self):
488525
resource elements in the resource grid. `True` corresponds to
489526
resource elements on which no data is transmitted.
490527
"""
491-
mask = np.zeros([self.num_subcarriers,
528+
mask = np.zeros([self.num_effective_subcarriers,
492529
self.carrier.num_symbols_per_slot],
493530
dtype=bool)
494531

@@ -503,7 +540,7 @@ def dmrs_mask(self):
503540
cdm_ind[:,i] = np.array([0,1, 6, 7])+2*i
504541

505542
for i in self.dmrs_symbol_indices:
506-
for j in range(self.num_resource_blocks):
543+
for j in range(self.num_effective_resource_blocks):
507544
for k in range(num_cdm_groups):
508545
mask[cdm_ind[:, k] + 12*j, i] = True
509546
return mask
@@ -518,7 +555,7 @@ def dmrs_grid(self):
518555
This property returns for each configured DMRS port an empty
519556
resource grid filled with DMRS signals as defined in
520557
Section 6.4.1.1 [3GPP38211]. Not all possible options are implemented,
521-
e.g., frequency hopping and transform precoding are not available.
558+
e.g., frequency hopping is not available.
522559
523560
This property provides the *unprecoded* DMRS for each configured DMRS port.
524561
Precoding might be applied to map the DMRS to the antenna ports. However,
@@ -536,7 +573,7 @@ def dmrs_grid(self):
536573

537574
# Generate empty resource grid for each port
538575
a_tilde = np.zeros([len(self.dmrs.dmrs_port_set),
539-
self.num_subcarriers,
576+
self.num_effective_subcarriers,
540577
self.carrier.num_symbols_per_slot],
541578
dtype=complex)
542579

@@ -546,15 +583,23 @@ def dmrs_grid(self):
546583
# For every l_prime
547584
for l_prime in self.l_prime:
548585

549-
# Compute c_init
550586
l = l_bar + l_prime
551-
c_init = self.c_init(l)
552587

553-
# Generate RNG
554-
c = generate_prng_seq(2*self.num_subcarriers, c_init=c_init)
588+
if self.transform_precoding:
589+
if self.dmrs.n_sid is None:
590+
n_id = self.carrier.n_cell_id
591+
else:
592+
n_id = self.dmrs.n_sid
593+
r = generate_low_papr_seq_type_1(self.num_effective_subcarriers // 2, n_id % 30, 0, 0)
594+
else:
595+
# Compute c_init
596+
c_init = self.c_init(l)
597+
598+
# Generate RNG
599+
c = generate_prng_seq(2*self.num_effective_subcarriers, c_init=c_init)
555600

556-
# Map to QAM
557-
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))
601+
# Map to QAM
602+
r = 1/np.sqrt(2)*((1-2*c[::2]) + 1j*(1-2*c[1::2]))
558603

559604
# For every port in the dmrs port set
560605
for j_ind, _ in enumerate(self.dmrs.dmrs_port_set):
@@ -625,8 +670,38 @@ def precoding_matrix(self):
625670

626671
w /= np.sqrt(2)
627672

673+
# Table 6.3.1.5-2
674+
elif self.transform_precoding and self.num_antenna_ports == 4:
675+
w = np.zeros([28, 4, 1], complex)
676+
677+
# TPMI index 0-7
678+
w[:8,0,0] = [ 1, 0, 0, 0, 1, 1, 1, 1]
679+
w[:8,1,0] = [ 0, 1, 0, 0, 0, 0, 0, 0]
680+
w[:8,2,0] = [ 0, 0, 1, 0, 1, -1, 1j,-1j]
681+
w[:8,3,0] = [ 0, 0, 0, 1, 0, 0, 0, 0]
682+
683+
# TPMI index 8-15
684+
w[8:16,0,0] = [ 0, 0, 0, 0, 1, 1, 1, 1]
685+
w[8:16,1,0] = [ 1, 1, 1, 1, 1, 1, 1, 1]
686+
w[8:16,2,0] = [ 0, 0, 0, 0, 1, 1j, -1,-1j]
687+
w[8:16,3,0] = [ 1, -1, 1j,-1j, -1, 1j, 1,-1j]
688+
689+
# TPMI index 16-23
690+
w[16:24,0,0] = [ 1, 1, 1, 1, 1, 1, 1, 1]
691+
w[16:24,1,0] = [ 1j, 1j, 1j, 1j, -1, -1, -1, -1]
692+
w[16:24,2,0] = [ 1, 1j, -1,-1j, 1, 1j, -1,-1j]
693+
w[16:24,3,0] = [ 1j, 1,-1j, -1, 1,-1j, -1, 1j]
694+
695+
# TPMI index 24-27
696+
w[24:28,0,0] = [ 1, 1, 1, 1]
697+
w[24:28,1,0] = [-1j,-1j,-1j,-1j]
698+
w[24:28,2,0] = [ 1, 1j, -1,-1j]
699+
w[24:28,3,0] = [-1j, -1, 1j, 1]
700+
701+
w /= 2
702+
628703
# Table 6.3.1.5-3
629-
elif self.num_antenna_ports==4:
704+
elif not self.transform_precoding and self.num_antenna_ports==4:
630705
w = np.zeros([28,4,1], complex)
631706

632707
# TPMI index 0-7
@@ -825,7 +900,7 @@ def num_coded_bits(self):
825900
n_re_per_prb = self.num_res_per_prb - self.num_ov
826901

827902
# number of allocated REs
828-
n_re = n_re_per_prb * self.num_resource_blocks
903+
n_re = n_re_per_prb * self.num_effective_resource_blocks
829904

830905
# total number of bits per slot
831906
num_coded_bits = int(self.tb.tb_scaling * self.tb.num_bits_per_symbol \
@@ -842,7 +917,7 @@ def tb_size(self):
842917

843918
# number of allocated REs
844919
# the max. number of REs per PRB is limited to 156 in 38.214
845-
n_re = min(156, n_re_per_prb) * self.num_resource_blocks
920+
n_re = min(156, n_re_per_prb) * self.num_effective_resource_blocks
846921

847922
# include tb_scaling as defined in Tab. 5.1.3.2-2 38.214
848923
target_tb_size = int(self.tb.target_coderate * self.tb.tb_scaling \
@@ -924,6 +999,14 @@ def check_config(self):
924999
assert self.num_layers == self.num_antenna_ports,\
9251000
"num_layers must be == num_antenna_ports"
9261001

1002+
if self.transform_precoding:
1003+
assert self.num_layers == 1,\
1004+
"When transform precoding is used, only a single MIMO layer is supported"
1005+
assert self.dmrs.config_type == 1, \
1006+
"When transform precoding is used, DMRS config type must be 1"
1007+
assert self.dmrs.num_cdm_groups_without_data == 2, \
1008+
"When transform precoding is used, num_cdm_groups_without_data must be 2"
1009+
9271010
# Check Tables 6.4.1.1.3-3/4 are valid
9281011
if self.dmrs.length==1:
9291012
if self.mapping_type=="A":
@@ -1033,11 +1116,13 @@ def check_pusch_configs(pusch_configs):
10331116
"num_tx" : len(pusch_configs),
10341117
"num_layers" : pc.num_layers,
10351118
"num_subcarriers" : pc.num_subcarriers,
1119+
"num_effective_subcarriers": pc.num_effective_subcarriers,
10361120
"num_ofdm_symbols" : pc.symbol_allocation[1],
10371121
"subcarrier_spacing" : pc.carrier.subcarrier_spacing*1e3,
10381122
"num_antenna_ports" : pc.num_antenna_ports,
10391123
"precoding" : pc.precoding,
10401124
"precoding_matrices" : [],
1125+
"transform_precoding" : pc.transform_precoding,
10411126
"pusch_config" : pc,
10421127
"carrier_config" : pc.carrier,
10431128
"num_coded_bits" : pc.num_coded_bits,

sionna/nr/pusch_dmrs_config.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,29 @@ def n_id(self, value):
151151
if value is None:
152152
self._n_id = None
153153
elif isinstance(value, int):
154-
assert value in list(range(65536)), "n_id must be in [0, 65535]"
154+
assert value in range(65536), "n_id must be in [0, 65535]"
155155
self._n_id = [value, value]
156156
else:
157157
assert len(value)==2, "n_id must be either [] or a two-tuple"
158158
for e in value:
159-
assert e in list(range(65536)), "Each element of n_id must be in [0, 65535]"
159+
assert e in range(65536), "Each element of n_id must be in [0, 65535]"
160160
self._n_id = value
161161

162+
#---n_sid---#
163+
@property
164+
def n_sid(self):
165+
r"""
166+
None (default), [0,...,1007] : DMRS scrambling identity for DFT-s-OFDM
167+
:math:`n_\text{ID}^\text{PUSCH}`
168+
"""
169+
self._ifndef("n_sid", None)
170+
return self._n_sid
171+
172+
@n_sid.setter
173+
def n_sid(self, value):
174+
assert value is None or (isinstance(value, int) and value in range(1008)), "n_sid must None or in [0, 1007]"
175+
self._n_sid = value
176+
162177
#---n_scid---#
163178
@property
164179
def n_scid(self):

sionna/nr/pusch_receiver.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sionna.ofdm import OFDMDemodulator, LinearDetector
1313
from sionna.utils import insert_dims
1414
from sionna.channel import time_to_ofdm_channel
15+
from .pusch_transform_precoder import PUSCHTransformDeprecoder
1516

1617
class PUSCHReceiver(Layer):
1718
# pylint: disable=line-too-long
@@ -197,14 +198,19 @@ def __init__(self,
197198
# Use or create default MIMODetector
198199
if mimo_detector is None:
199200
# Default MIMO detector
201+
transformation = PUSCHTransformDeprecoder(pusch_transmitter.resource_grid.num_effective_subcarriers,
202+
dtype=dtype) if pusch_transmitter._transform_precoding else None
200203
self._mimo_detector = LinearDetector("lmmse", "bit", "maxlog",
201-
pusch_transmitter.resource_grid,
202-
self._stream_management,
203-
"qam",
204-
pusch_transmitter._num_bits_per_symbol,
205-
dtype=dtype)
204+
pusch_transmitter.resource_grid,
205+
self._stream_management,
206+
"qam",
207+
pusch_transmitter._num_bits_per_symbol,
208+
post_equalizer_transformation=transformation,
209+
dtype=dtype)
206210
else:
207211
# User-provided MIMO detector
212+
if pusch_transmitter._transform_precoding:
213+
print("WARNING: Using custom mimo detector which might not support transform precoding.")
208214
self._mimo_detector = mimo_detector
209215

210216
# Create LayerDemapper
@@ -248,7 +254,6 @@ def call(self, inputs):
248254
if self._input_domain=="time":
249255
h = time_to_ofdm_channel(h, self.resource_grid, self._l_min)
250256

251-
252257
if self._w is not None:
253258
# Reshape h to put channel matrix dimensions last
254259
# [batch size, num_rx, num_tx, num_ofdm_symbols,...

0 commit comments

Comments
 (0)