@@ -172,10 +172,10 @@ defmodule NxSignal do
172172
173173 * `:window_length` - the number of samples in a window
174174 * `:stride` - The number of samples to skip between windows. Defaults to `1`.
175- * `:padding` - A can be `:reflect` or a valid padding as per `Nx.pad/3` over the
176- input tensor's shape. Defaults to `:valid`. If `:reflect` or `:zeros `, the first window will be centered
177- at the start of the signal. For `:reflect`, each incomplete window will be reflected as if it was
178- periodic (see examples for `as_windowed/2`) . For `:zeros`, each incomplete window will be zero-padded.
175+ * `:padding` - Padding mode, can be `:reflect` or a valid padding as per `Nx.pad/3` over the
176+ input tensor's shape. Defaults to `:valid`. If `:reflect` or `:same`, the first window will be centered
177+ at the start of the signal. The padding is applied for the whole input, rather than individual
178+ windows. For `:zeros`, effectively each incomplete window will be zero-padded.
179179
180180 ## Examples
181181
@@ -219,27 +219,29 @@ defmodule NxSignal do
219219 iex> t = Nx.iota({7});
220220 iex> NxSignal.as_windowed(t, window_length: 6, padding: :reflect, stride: 1)
221221 #Nx.Tensor<
222- s64[7 ][6]
222+ s64[8][6]
223223 [
224- [1 , 2, 1, 0, 1, 2],
224+ [3, 2, 1, 0, 1, 2],
225225 [2, 1, 0, 1, 2, 3],
226226 [1, 0, 1, 2, 3, 4],
227227 [0, 1, 2, 3, 4, 5],
228228 [1, 2, 3, 4, 5, 6],
229229 [2, 3, 4, 5, 6, 5],
230- [3, 4, 5, 6, 5, 4]
230+ [3, 4, 5, 6, 5, 4],
231+ [4, 5, 6, 5, 4, 3]
231232 ]
232233 >
233234
234235 iex> NxSignal.as_windowed(Nx.iota({10}), window_length: 6, padding: :reflect, stride: 2)
235236 #Nx.Tensor<
236- s64[5 ][6]
237+ s64[6][6]
237238 [
238- [1 , 2, 1, 0, 1, 2],
239+ [3, 2, 1, 0, 1, 2],
239240 [1, 0, 1, 2, 3, 4],
240241 [1, 2, 3, 4, 5, 6],
241242 [3, 4, 5, 6, 7, 8],
242- [5, 6, 7, 8, 9, 8]
243+ [5, 6, 7, 8, 9, 8],
244+ [7, 8, 9, 8, 7, 6]
243245 ]
244246 >
245247 """
@@ -257,7 +259,7 @@ defmodule NxSignal do
257259
258260 as_windowed_parse_non_reflect_opts (
259261 shape ,
260- Keyword . put ( opts , :padding , [ { div ( window_length , 2 ) , div ( window_length , 2 ) - 1 } ] )
262+ Keyword . put ( opts , :padding , [ { div ( window_length , 2 ) , div ( window_length , 2 ) } ] )
261263 )
262264 end
263265
@@ -333,114 +335,34 @@ defmodule NxSignal do
333335 { window_length , stride , padding , output_shape } =
334336 as_windowed_parse_non_reflect_opts ( Nx . shape ( tensor ) , opts )
335337
336- output = Nx . broadcast ( Nx . tensor ( 0 , type: tensor . type ) , output_shape )
337- { num_windows , _ } = Nx . shape ( output )
338-
339- index_template =
340- Nx . concatenate ( [ Nx . broadcast ( 0 , { window_length , 1 } ) , Nx . iota ( { window_length , 1 } ) ] , axis: 1 )
341-
342- { output , _ , _ , _ , _ } =
343- while { output , i = 0 , current_window = 0 , t = Nx . pad ( tensor , 0 , padding ) , index_template } ,
344- current_window < num_windows do
345- indices = index_template + Nx . stack ( [ current_window , 0 ] )
346- updates = t |> Nx . slice ( [ i ] , [ window_length ] ) |> Nx . flatten ( )
347-
348- updated = Nx . indexed_add ( output , indices , updates )
338+ tensor = Nx . pad ( tensor , 0 , padding )
349339
350- { updated , i + stride , current_window + 1 , t , index_template }
351- end
352-
353- output
340+ as_windowed_apply ( tensor , stride , output_shape , window_length )
354341 end
355342
356343 defnp as_windowed_reflect_padding ( tensor , opts \\ [ ] ) do
357344 # current implementation only supports windowing 1D tensors
358345 { window_length , stride , _padding , output_shape } =
359346 as_windowed_parse_reflect_opts ( Nx . shape ( tensor ) , opts )
360347
361- output = Nx . broadcast ( Nx . tensor ( 0 , type: tensor . type ) , output_shape )
362- { num_windows , _ } = Nx . shape ( output )
363-
364- index_template =
365- Nx . concatenate ( [ Nx . broadcast ( 0 , { window_length , 1 } ) , Nx . iota ( { window_length , 1 } ) ] , axis: 1 )
366-
367- leading_window_indices = generate_leading_window_indices ( window_length , stride )
368-
369- trailing_window_indices =
370- generate_trailing_window_indices ( Nx . size ( tensor ) , window_length , stride )
371-
372- half_window = div ( window_length - 1 , 2 ) + 1
373-
374- { output , _ , _ , _ , _ } =
375- while { output , i = 0 , current_window = 0 , t = tensor , index_template } ,
376- current_window < num_windows do
377- # Here windows are centered at the current index
378-
379- cond do
380- i < half_window ->
381- # We're indexing before we have a full window on the left
382-
383- window = Nx . take ( t , leading_window_indices [ i ] )
384-
385- indices = index_template + Nx . stack ( [ current_window , 0 ] )
386- updated = Nx . indexed_add ( output , indices , window )
387-
388- { updated , i + stride , current_window + 1 , t , index_template }
389-
390- i > Nx . size ( t ) - half_window ->
391- # We're indexing after the last full window on the right
392- window = Nx . take ( t , trailing_window_indices [ i - ( Nx . size ( t ) - half_window + 1 ) ] )
393-
394- indices = index_template + Nx . stack ( [ current_window , 0 ] )
395- updated = Nx . indexed_add ( output , indices , window )
396-
397- { updated , i + stride , current_window + 1 , t , index_template }
398-
399- true ->
400- # Case where we can index a full window
401- indices = index_template + Nx . stack ( [ current_window , 0 ] )
402- updates = t |> Nx . slice ( [ i - half_window ] , [ window_length ] ) |> Nx . flatten ( )
403-
404- updated = Nx . indexed_add ( output , indices , updates )
405-
406- { updated , i + stride , current_window + 1 , t , index_template }
407- end
408- end
409-
410- # Now we need to handle the tail-end of the windows,
411- # since they are currently all the same value. We want to apply the tapering-off
412- # like we did with the initial windows.
413-
414- output
415- end
416-
417- deftransformp generate_leading_window_indices ( window_length , stride ) do
418348 half_window = div ( window_length , 2 )
349+ tensor = Nx . reflect ( tensor , padding_config: [ { half_window , half_window } ] )
419350
420- for offset <- 0 .. half_window // stride do
421- partial_length = offset + half_window
422- padding_length = window_length - partial_length
423-
424- { partial_length }
425- |> Nx . iota ( )
426- |> Nx . reflect ( padding_config: [ { padding_length , 0 } ] )
427- end
428- |> Nx . stack ( )
351+ as_windowed_apply ( tensor , stride , output_shape , window_length )
429352 end
430353
431- deftransformp generate_trailing_window_indices ( tensor_size , window_length , stride ) do
432- min_index = tensor_size - window_length + 1
354+ defnp as_windowed_apply ( tensor , stride , output_shape , window_length ) do
355+ output = Nx . broadcast ( Nx . tensor ( 0 , type: tensor . type ) , output_shape )
356+ { num_windows , _ } = Nx . shape ( output )
433357
434- for { offset , add } <- Enum . with_index ( min_index .. ( tensor_size - 1 ) // stride ) do
435- partial_length = tensor_size - offset
436- padding_length = window_length - partial_length
358+ { output , _ , _ , _ } =
359+ while { output , i = 0 , current_window = 0 , t = tensor } , current_window < num_windows do
360+ window = t |> Nx . slice ( [ i ] , [ window_length ] )
361+ updated = Nx . put_slice ( output , [ current_window , 0 ] , Nx . new_axis ( window , 0 ) )
362+ { updated , i + stride , current_window + 1 , t }
363+ end
437364
438- { partial_length }
439- |> Nx . iota ( )
440- |> Nx . add ( min_index + add - rem ( window_length , 2 ) )
441- |> Nx . reflect ( padding_config: [ { 0 , padding_length } ] )
442- end
443- |> Nx . stack ( )
365+ output
444366 end
445367
446368 @ doc """
@@ -548,15 +470,16 @@ defmodule NxSignal do
548470 iex> Nx.axis_size(z, :frequencies)
549471 16
550472 iex> Nx.axis_size(z, :frames)
551- 5
473+ 6
552474 iex> NxSignal.stft_to_mel(z, sampling_rate, fft_length: fft_length, mel_bins: 4)
553475 #Nx.Tensor<
554- f32[frames: 5 ][mel: 4]
476+ f32[frames: 6][mel: 4]
555477 [
556478 [0.2900530695915222, 0.17422175407409668, 0.18422472476959229, 0.09807997941970825],
557479 [0.6093881130218506, 0.5647397041320801, 0.4353824257850647, 0.08635270595550537],
558480 [0.7584103345870972, 0.7085014581680298, 0.5636920928955078, 0.179118812084198],
559481 [0.8461772203445435, 0.7952491044998169, 0.6470762491226196, 0.2520409822463989],
482+ [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721],
560483 [0.908548891544342, 0.8572604656219482, 0.7078656554222107, 0.3086767792701721]
561484 ]
562485 >
0 commit comments