15
15
class OFDMDemodulator (Layer ):
16
16
# pylint: disable=line-too-long
17
17
r"""
18
- OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, cyclic_prefix_length_first_symbol=None, symbols_per_block=1, **kwargs)
18
+ OFDMDemodulator(fft_size, l_min, cyclic_prefix_length=0, **kwargs)
19
19
20
20
Computes the frequency-domain representation of an OFDM waveform
21
21
with cyclic prefix removal.
22
22
23
- When only `cyclic_prefix_length` is given then a cyclic prefix of length
24
- `cyclic_prefix_length` is removed from each symbol. When additionally
25
- `cyclic_prefix_length_first_symbol` and `symbols_per_block` are given then
26
- the length of the cyclic prefix is `cyclic_prefix_length_first_symbol` for
27
- the first symbol of each block and `cyclic_prefix_length` for the
28
- remaining symbols. For LTE one block corresponds to one slot (i.e., 7
29
- symbols). For 5G NR one block corresponds to one half subframe and the
30
- number of symbols depends on the numerology.
31
-
32
23
The demodulator assumes that the input sequence is generated by the
33
24
:class:`~sionna.channel.TimeChannel`. For a single pair of antennas,
34
25
the received signal sequence is given as:
@@ -72,17 +63,9 @@ class OFDMDemodulator(Layer):
72
63
impulse response. It should be the same value as that used by the
73
64
`cir_to_time_channel` function.
74
65
75
- cyclic_prefix_length : int
76
- Integer indicating the length of the cyclic prefix that it prepended
77
- to each OFDM symbol (except for the first symbol of each block if
78
- `cyclic_prefix_length_first_symbol` and `symbols per block` is given).
79
-
80
- cyclic_prefix_length_first_symbol : int
81
- Integer indicating the length of the cyclic prefix that it prepended
82
- to the first OFDM symbol of each block.
83
-
84
- symbols_per_block : int
85
- Integer indicating the number of symbols per block.
66
+ cyclic_prefix_length : int or list[int] or np.ndarray[int]
67
+ Integer or vector of integers indicating the length of the cyclic
68
+ prefix that it prepended to each OFDM symbol.
86
69
87
70
Input
88
71
-----
@@ -98,15 +81,11 @@ class OFDMDemodulator(Layer):
98
81
"""
99
82
100
83
def __init__ (self , fft_size , l_min , cyclic_prefix_length = 0 ,
101
- cyclic_prefix_length_first_symbol = None , symbols_per_block = 1 ,
102
84
** kwargs ):
103
85
super ().__init__ (** kwargs )
104
86
self .fft_size = fft_size
105
87
self .l_min = l_min
106
88
self .cyclic_prefix_length = cyclic_prefix_length
107
- self .cyclic_prefix_length_first_symbol = (
108
- cyclic_prefix_length_first_symbol )
109
- self .symbols_per_block = symbols_per_block
110
89
111
90
@property
112
91
def fft_size (self ):
@@ -133,34 +112,18 @@ def cyclic_prefix_length(self):
133
112
134
113
@cyclic_prefix_length .setter
135
114
def cyclic_prefix_length (self , value ):
136
- assert isinstance (value , int ) and value >= 0 ,\
137
- "`cyclic_prefix_length` must be a nonnegative integer."
138
- self ._cyclic_prefix_length = int (value )
139
-
140
- @property
141
- def cyclic_prefix_length_first_symbol (self ):
142
- if self ._cyclic_prefix_length_first_symbol is None :
143
- return self ._cyclic_prefix_length
115
+ if isinstance (value , list ):
116
+ value = np .array (value )
117
+ if isinstance (value , np .ndarray ):
118
+ assert (np .issubdtype (value .dtype , np .integer ) and
119
+ value .ndim == 1 and np .all (value >= 0 )),\
120
+ ("`cyclic_prefix_length` must be a 1D array with"
121
+ " only nonnegative integers." )
144
122
else :
145
- return self ._cyclic_prefix_length_first_symbol
123
+ assert isinstance (value , int ) and value >= 0 ,\
124
+ "`cyclic_prefix_length` must be a nonnegative integer."
125
+ self ._cyclic_prefix_length = value
146
126
147
- @cyclic_prefix_length_first_symbol .setter
148
- def cyclic_prefix_length_first_symbol (self , value ):
149
- assert (value is None or isinstance (value , int ) and
150
- value >= self ._cyclic_prefix_length ),\
151
- ("`cyclic_prefix_length_first_symbol` must be integer and " +
152
- "larger or equal to `cyclic_prefix_length`." )
153
- self ._cyclic_prefix_length_first_symbol = value
154
-
155
- @property
156
- def symbols_per_block (self ):
157
- return self ._symbols_per_block
158
-
159
- @symbols_per_block .setter
160
- def symbols_per_block (self , value ):
161
- assert isinstance (value , int ) and value >= 1 ,\
162
- "`symbols_per_block` must be a positive integer."
163
- self ._symbols_per_block = value
164
127
165
128
def build (self , input_shape ):
166
129
num_samples = input_shape [- 1 ]
@@ -170,24 +133,23 @@ def build(self, input_shape):
170
133
* tf .range (self .fft_size , dtype = tf .float32 )
171
134
self ._phase_compensation = tf .exp (tf .complex (0. , tmp ))
172
135
173
- self ._samples_per_block = (self .cyclic_prefix_length_first_symbol +
174
- (self .symbols_per_block - 1 ) * self .cyclic_prefix_length +
175
- self .symbols_per_block * self .fft_size )
176
-
177
- # Compute number of elements that will be truncated and number of
178
- # symbols for padding
179
- self ._rest = num_samples % self ._samples_per_block
180
- samples_first_symbol = (self .cyclic_prefix_length_first_symbol +
181
- self .fft_size )
182
- samples_other_symbols = (self .cyclic_prefix_length + self .fft_size )
183
- if self ._rest > samples_first_symbol :
184
- self ._rest -= samples_first_symbol
185
- excess_symbols = self ._rest // samples_other_symbols
186
- self ._rest -= excess_symbols * samples_other_symbols
187
- excess_symbols += 1 # Because of first symbol in block
188
- self ._num_pad_symbols = self .symbols_per_block - excess_symbols
136
+ if isinstance (self .cyclic_prefix_length , int ):
137
+ self ._num_ofdm_symbols = (input_shape [- 1 ] //
138
+ (self .fft_size + self .cyclic_prefix_length ))
139
+ self .cyclic_prefix_length = np .full (self ._num_ofdm_symbols ,
140
+ self .cyclic_prefix_length )
189
141
else :
190
- self ._num_pad_symbols = 0
142
+ self ._num_ofdm_symbols = self .cyclic_prefix_length .shape [0 ]
143
+
144
+ symbol_ends = tf .math .cumsum (self .cyclic_prefix_length + self .fft_size )
145
+ assert num_samples >= symbol_ends [- 1 ],\
146
+ "shape(inputs)[-1] must be larger or equal than samples per slot"
147
+
148
+ gather_idx = []
149
+ for i in range (self ._num_ofdm_symbols ):
150
+ gather_idx .append (tf .range (symbol_ends [i ] - self .fft_size ,
151
+ symbol_ends [i ]))
152
+ self ._gather_idx = tf .concat (gather_idx , 0 )
191
153
192
154
def call (self , inputs ):
193
155
"""Demodulate OFDM waveform onto a resource grid.
@@ -202,39 +164,14 @@ def call(self, inputs):
202
164
"""
203
165
batch_dims = tf .shape (inputs )[:- 1 ]
204
166
205
- # Cut last samples that do not fit into an OFDM symbol
206
- x = inputs if self ._rest == 0 else inputs [..., :- self ._rest ]
207
-
208
- if self ._num_pad_symbols > 0 :
209
- pad_samples = self ._num_pad_symbols * (self .fft_size +
210
- self .cyclic_prefix_length )
211
- padding_shape = tf .concat ([batch_dims , [pad_samples ]], axis = 0 )
212
- padding = tf .zeros (padding_shape , dtype = x .dtype )
213
- x = tf .concat ([x , padding ], axis = - 1 )
214
-
215
- # Reshape input to blocks
216
- num_blocks = tf .shape (x )[- 1 ] // self ._samples_per_block
217
- new_shape = tf .concat ([batch_dims ,
218
- [num_blocks , self ._samples_per_block ]], 0 )
219
- x = tf .reshape (x , new_shape )
220
-
221
- # Remove extra cyclic prefix from first symbol
222
- x = x [...,(self .cyclic_prefix_length_first_symbol -
223
- self .cyclic_prefix_length ):]
167
+ # Gather samples that do not belong to cyclic prefix
168
+ x = tf .gather (inputs , self ._gather_idx , axis = - 1 )
224
169
225
170
# Reshape input to separate OFDM symbols
226
171
new_shape = tf .concat ([batch_dims ,
227
- [num_blocks * self .symbols_per_block ],
228
- [self .fft_size + self .cyclic_prefix_length ]], 0 )
172
+ [self ._num_ofdm_symbols , self .fft_size ]], 0 )
229
173
x = tf .reshape (x , new_shape )
230
174
231
- # Remove padding
232
- if self ._num_pad_symbols > 0 :
233
- x = x [..., :- self ._num_pad_symbols , :]
234
-
235
- # Remove cyclic prefix
236
- x = x [...,self .cyclic_prefix_length :]
237
-
238
175
# Compute FFT
239
176
x = fft (x )
240
177
0 commit comments