-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathtest_video_only.py
More file actions
75 lines (56 loc) · 2.5 KB
/
Copy pathtest_video_only.py
File metadata and controls
75 lines (56 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Test Video Only - Simple video-based audio separation using SAM Audio
Uses the facebook/sam-audio-base checkpoint, with bfloat16 on CUDA for lower memory usage
"""
import torch
import torchaudio
import gc
def _report_gpu_memory(stage: str) -> None:
    """Print the CUDA memory currently allocated, labelled with *stage*."""
    allocated_gb = torch.cuda.memory_allocated() / 1024**3
    print(f"GPU Memory {stage}: {allocated_gb:.2f} GB")


def main(
    video_file: str = "office.mp4",
    description: str = "walking sound",
    model_name: str = "facebook/sam-audio-base",
) -> None:
    """Separate the sound matching *description* from the audio of *video_file*.

    Loads the SAM Audio checkpoint *model_name*, runs separation, and writes
    two files into the current working directory:

    * ``video_target.wav``   - the extracted (target) audio
    * ``video_residual.wav`` - everything else (residual)

    On CUDA the model is cast to bfloat16 and inference runs under autocast;
    on CPU everything stays in float32 and autocast is not entered.

    Args:
        video_file: Path of the media file handed to the processor as audio.
        description: Text description of the sound to extract.
        model_name: Checkpoint identifier passed to ``from_pretrained``.
    """
    from contextlib import nullcontext

    # Setup
    on_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if on_cuda else "cpu")
    dtype = torch.bfloat16 if on_cuda else torch.float32
    print(f"Device: {device}, dtype: {dtype}")

    # Clear GPU memory. Collect Python garbage first so that any blocks it
    # frees can then be released from the CUDA caching allocator.
    if on_cuda:
        gc.collect()
        torch.cuda.empty_cache()
        _report_gpu_memory("before loading")

    # Load the model and its processor. sam_audio is imported lazily, after
    # the baseline memory report above.
    from sam_audio import SAMAudio, SAMAudioProcessor

    print(f"Loading {model_name}...")
    model = SAMAudio.from_pretrained(model_name).to(device, dtype).eval()
    processor = SAMAudioProcessor.from_pretrained(model_name)

    if on_cuda:
        _report_gpu_memory("after loading")

    print(f"\nProcessing video: {video_file}")
    print(f"Description: '{description}'")

    # Process
    inputs = processor(audios=[video_file], descriptions=[description]).to(device)

    print("Running separation...")
    if on_cuda:
        _report_gpu_memory("before separation")

    # CPU autocast does not support float32 as its target dtype (PyTorch warns
    # and disables it), so autocast is only entered on CUDA.
    if on_cuda:
        amp_context = torch.autocast(device_type=device.type, dtype=dtype)
    else:
        amp_context = nullcontext()

    with torch.inference_mode(), amp_context:
        result = model.separate(inputs)

    if on_cuda:
        _report_gpu_memory("after separation")

    # Save results. Convert to float32 on the CPU before writing, and add a
    # leading dimension to the first item of each output.
    sample_rate = processor.audio_sampling_rate
    target_audio = result.target[0].float().unsqueeze(0).cpu()
    residual_audio = result.residual[0].float().unsqueeze(0).cpu()

    torchaudio.save("video_target.wav", target_audio, sample_rate)
    torchaudio.save("video_residual.wav", residual_audio, sample_rate)

    print("\nDone!")
    print("- video_target.wav: Extracted audio (target)")
    print("- video_residual.wav: Remaining audio (residual)")

    # Cleanup: drop the references, collect garbage, then release cached blocks.
    del model, processor, inputs, result
    if on_cuda:
        gc.collect()
        torch.cuda.empty_cache()
        _report_gpu_memory("after cleanup")
# Run the separation demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()