-
Notifications
You must be signed in to change notification settings - Fork 1
/
TorchTimeStretch.py
128 lines (114 loc) · 4.11 KB
/
TorchTimeStretch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Code Source: https://github.com/KentoNishi/torch-time-stretch
# Thanks to the great work of KentoNishi(https://github.com/KentoNishi)
from collections import Counter
from fractions import Fraction
from functools import reduce
from itertools import chain, count, islice, repeat
from typing import Union, Callable, List, Optional
from torch.nn.functional import pad
import torch
import torchaudio.transforms as T
from primePy import primes
from math import log2
import warnings
warnings.simplefilter("ignore")
# https://stackoverflow.com/a/46623112/9325832
def _combinations_without_repetition(r, iterable=None, values=None, counts=None):
    """Yield the distinct r-length combinations of a multiset.

    Equal elements are treated as interchangeable, so each distinct
    multiset of size ``r`` is produced exactly once (unlike
    ``itertools.combinations``, which distinguishes equal elements by
    position). Supply either ``iterable`` or the precomputed
    ``values``/``counts`` pair (unique values and their multiplicities).
    """
    if iterable:
        # Collapse the input multiset into unique values + multiplicities.
        values, counts = zip(*Counter(iterable).items())
    # f expands (items, counts) into a flat stream where each item is
    # repeated by its count, e.g. f((0, 1), (2, 1)) -> 0, 0, 1.
    f = lambda i, c: chain.from_iterable(map(repeat, i, c))
    n = len(counts)
    # Start from the lexicographically smallest multiset of value-indices.
    indices = list(islice(f(count(), counts), r))
    if len(indices) < r:
        # Fewer than r elements in total: no combination exists.
        return
    while True:
        yield tuple(values[i] for i in indices)
        # Scan from the right for the first index position that can still
        # be advanced (i.e. it differs from the maximal index sequence).
        for i, j in zip(reversed(range(r)), f(reversed(range(n)), reversed(counts))):
            if indices[i] != j:
                break
        else:
            # Every position is at its ceiling: enumeration is complete.
            return
        j = indices[i] + 1
        # Reset positions i..r-1 to the smallest indices available from j on.
        for i, j in zip(range(i, r), f(count(j), counts[j:])):
            indices[i] = j
def get_fast_stretches(
    sample_rate: int,
    condition: Optional[Callable] = lambda x: x >= 0.5 and x <= 2 and x != 1,
) -> List[Fraction]:
    """
    Search for time-stretch targets that can be computed quickly for a given sample rate.

    A stretch ratio is "fast" when both its numerator and denominator are
    products of prime factors of the sample rate, so resampling stays exact.

    Parameters
    ----------
    sample_rate: int
        The sample rate of an audio clip.
    condition: Callable [optional]
        A function to validate fast stretch ratios; ``None`` selects the default.
        Default is `lambda x: x >= 0.5 and x <= 2 and x != 1` (between 50% and 200% speed).

    Returns
    -------
    output: List[Fraction]
        A list of fast time-stretch target ratios.
    """
    from itertools import combinations  # local: top-of-file import lacks it

    if condition is None:
        # The Optional annotation promises None is acceptable; honor it.
        condition = lambda x: x >= 0.5 and x <= 2 and x != 1

    def _prime_factors(n: int) -> List[int]:
        # Trial-division prime factorization with multiplicity
        # (stdlib replacement for primePy's primes.factors).
        out = []
        d = 2
        while d * d <= n:
            while n % d == 0:
                out.append(d)
                n //= d
            d += 1
        if n > 1:
            out.append(n)
        return out

    factors = _prime_factors(sample_rate)
    # Every distinct product of a non-empty sub-multiset of the factors.
    # factors is sorted (trial division emits ascending), so a set over
    # positional combinations dedupes to exactly the distinct multisets.
    products = []
    for size in range(1, len(factors) + 1):
        for combo in set(combinations(factors, size)):
            products.append(reduce(lambda a, b: a * b, combo))
    fast_shifts = set()
    for numerator in products:
        for denominator in products:
            ratio = Fraction(numerator, denominator)
            if condition(ratio):
                fast_shifts.add(ratio)
    return list(fast_shifts)
def time_stretch(
    input: torch.Tensor,
    stretch: Union[float, Fraction],
    sample_rate: int,
    n_fft: int = 0,
    hop_length: int = 0,
) -> torch.Tensor:
    """
    Stretch a batch of waveforms by a given amount without altering the pitch.

    Modified so that the function takes input[shape=(batch_size, samples)].

    Parameters
    ----------
    input: torch.Tensor [shape=(batch_size, samples) or shape=(samples)]
        Input audio clips of shape (batch_size, samples) or (samples).
    stretch: float OR Fraction
        Indicates the stretch ratio. Usually an element in `get_fast_stretches()`.
    sample_rate: int
        The sample rate of the input audio clips.
    n_fft: int [optional]
        Size of FFT. Pass 0 (the default) for `sample_rate // 64`.
    hop_length: int [optional]
        Size of hop length. Pass 0 (the default) for `n_fft // 32`.

    Returns
    -------
    output: torch.Tensor [shape=(batch_size, samples)]
        The time-stretched batch of audio clips.
    """
    # Promote a single clip (samples,) to a batch of one.
    if len(input.shape) == 1:
        input = input.unsqueeze(0)
    if not n_fft:
        n_fft = sample_rate // 64
    if not hop_length:
        hop_length = n_fft // 32
    # STFT -> phase-vocoder time stretch -> inverse STFT. The leading
    # [None, ...] adds the channel axis TimeStretch expects; it is
    # stripped again before the inverse transform.
    spec = torch.stft(input, n_fft, hop_length, return_complex=True)[None, ...]
    # TimeStretch's rate is speed-up factor, hence the reciprocal:
    # stretch=2 means twice as long, i.e. playback rate 1/2.
    stretcher = T.TimeStretch(
        fixed_rate=float(1 / stretch), n_freq=spec.shape[2], hop_length=hop_length
    ).to(input.device)
    stretched = stretcher(spec)
    return torch.istft(stretched[0], n_fft, hop_length)