1+ import os
2+ import struct
3+ from io import BytesIO
4+
5+ EBML_HEADER = 0x1A45DFA3
6+ SEGMENT = 0x18538067
7+ SEEK_HEAD = 0x114D9B74
8+ INFO = 0x1549A966
9+ TAGS = 0x1254C367
10+ SEEK = 0x4DBB
11+ SEEK_ID = 0x53AB
12+ SEEK_POSITION = 0x53AC
13+ TITLE = 0x7BA9
14+ TAG = 0x7373
15+ SIMPLE_TAG = 0x67C8
16+ TAG_NAME = 0x45A3
17+ TAG_STRING = 0x4487
18+ TAG_BINARY = 0x4485
19+
20+
21+ def _read_vint_raw (stream ):
22+ first_byte_data = stream .read (1 )
23+ if not first_byte_data :
24+ raise IOError ("Reached end of stream while reading VINT" )
25+ first_byte = first_byte_data [0 ]
26+
27+ width = 0
28+ for i in range (8 ):
29+ if (first_byte >> (7 - i )) & 1 :
30+ width = i + 1
31+ break
32+ else :
33+ raise ValueError ("Invalid VINT start byte" )
34+
35+ data = bytes ([first_byte ]) + stream .read (width - 1 )
36+ if len (data ) < width :
37+ raise IOError ("Reached end of stream while reading VINT data" )
38+
39+ value = int .from_bytes (data , 'big' )
40+ return value , width
41+
42+
43+ def _read_vint_value (stream ):
44+ first_byte_data = stream .read (1 )
45+ if not first_byte_data :
46+ raise IOError ("Reached end of stream while reading VINT" )
47+ first_byte = first_byte_data [0 ]
48+
49+ width = 0
50+ for i in range (8 ):
51+ if (first_byte >> (7 - i )) & 1 :
52+ width = i + 1
53+ break
54+ else :
55+ raise ValueError ("Invalid VINT start byte" )
56+
57+ data = bytes ([first_byte ]) + stream .read (width - 1 )
58+ if len (data ) < width :
59+ raise IOError ("Reached end of stream while reading VINT data" )
60+
61+ value = data [0 ] & ((1 << (8 - width )) - 1 )
62+ for i in range (1 , width ):
63+ value = (value << 8 ) | data [i ]
64+
65+ return value , width
66+
67+
68+ def _write_vint_size (value ):
69+ if value < (2 ** 7 ) - 1 :
70+ return (value | 0x80 ).to_bytes (1 , 'big' )
71+ elif value < (2 ** 14 ) - 1 :
72+ return (value | 0x4000 ).to_bytes (2 , 'big' )
73+ elif value < (2 ** 21 ) - 1 :
74+ return (value | 0x200000 ).to_bytes (3 , 'big' )
75+ elif value < (2 ** 28 ) - 1 :
76+ return (value | 0x10000000 ).to_bytes (4 , 'big' )
77+ elif value < (2 ** 35 ) - 1 :
78+ return (value | 0x0800000000 ).to_bytes (5 , 'big' )
79+ elif value < (2 ** 42 ) - 1 :
80+ return (value | 0x040000000000 ).to_bytes (6 , 'big' )
81+ elif value < (2 ** 49 ) - 1 :
82+ return (value | 0x02000000000000 ).to_bytes (7 , 'big' )
83+ elif value < (2 ** 56 ) - 1 :
84+ return (value | 0x0100000000000000 ).to_bytes (8 , 'big' )
85+ else :
86+ raise ValueError ("VINT size too large" )
87+
88+
89+ def _write_element (element_id , data ):
90+ id_bytes = element_id .to_bytes (4 , 'big' )
91+ if element_id < 0x1000000 :
92+ id_bytes = id_bytes [1 :]
93+ if element_id < 0x10000 :
94+ id_bytes = id_bytes [1 :]
95+ if element_id < 0x100 :
96+ id_bytes = id_bytes [1 :]
97+ return id_bytes + _write_vint_size (len (data )) + data
98+
99+
100+ def _iter_elements (stream , end_pos = None ):
101+ while end_pos is None or stream .tell () < end_pos :
102+ current_pos = stream .tell ()
103+ try :
104+ element_id , id_len = _read_vint_raw (stream )
105+ size , size_len = _read_vint_value (stream )
106+ except (IOError , ValueError ):
107+ break
108+
109+ data_pos = stream .tell ()
110+ if end_pos is not None and data_pos + size > end_pos :
111+ break
112+
113+ yield element_id , stream .read (size )
114+ stream .seek (current_pos + id_len + size_len + size )
115+
116+
117+ class MKVTags (dict ):
118+ def __setitem__ (self , key , value ):
119+ if not isinstance (value , list ):
120+ value = [value ]
121+ super ().__setitem__ (key .upper (), value )
122+
123+ def __getitem__ (self , key ):
124+ return super ().__getitem__ (key .upper ())
125+
126+ def add_tag (self , key , value ):
127+ key = key .upper ()
128+ if key in self :
129+ super ().__getitem__ (key ).append (value )
130+ else :
131+ # Creates the initial list
132+ super ().__setitem__ (key , [value ])
133+
134+
135+ class MKVFile :
136+ def __init__ (self , filename ):
137+ self .filename = filename
138+ self .tags = MKVTags ()
139+ self ._load ()
140+
141+ def _load (self ):
142+ with open (self .filename , 'rb' ) as f :
143+ header_id_bytes = f .read (4 )
144+ if not header_id_bytes or int .from_bytes (header_id_bytes , 'big' ) != EBML_HEADER :
145+ return
146+
147+ try :
148+ header_size , _ = _read_vint_value (f )
149+ f .seek (header_size , 1 )
150+ except (IOError , ValueError ):
151+ return
152+
153+ try :
154+ segment_pos = f .tell ()
155+ segment_id_val , id_len = _read_vint_raw (f )
156+ segment_size , size_len = _read_vint_value (f )
157+ if segment_id_val != SEGMENT :
158+ return
159+ segment_data_pos = f .tell ()
160+ except (IOError , ValueError ):
161+ return
162+
163+ tags_pos , info_pos = None , None
164+ f .seek (segment_data_pos )
165+ for eid , edata in _iter_elements (f , end_pos = segment_data_pos + 4096 ):
166+ if eid == SEEK_HEAD :
167+ for seek_id , seek_data in _iter_elements (BytesIO (edata )):
168+ if seek_id == SEEK :
169+ s_id , s_pos = None , None
170+ for sub_id , sub_data in _iter_elements (BytesIO (seek_data )):
171+ if sub_id == SEEK_ID : s_id = int .from_bytes (sub_data , 'big' )
172+ elif sub_id == SEEK_POSITION : s_pos = int .from_bytes (sub_data , 'big' )
173+ if s_id == TAGS : tags_pos = segment_data_pos + s_pos
174+ elif s_id == INFO : info_pos = segment_data_pos + s_pos
175+ break
176+
177+ if info_pos is not None :
178+ f .seek (info_pos )
179+ try :
180+ info_id , info_data = next (_iter_elements (f ))
181+ if info_id == INFO : self ._parse_info (info_data )
182+ else : info_pos = None
183+ except (StopIteration , IOError , ValueError ): info_pos = None
184+
185+ tags_data = None
186+ if tags_pos is not None :
187+ f .seek (tags_pos )
188+ try :
189+ tags_id , tags_data_read = next (_iter_elements (f ))
190+ if tags_id == TAGS : tags_data = tags_data_read
191+ except (StopIteration , IOError , ValueError ): pass
192+ if tags_data is None or info_pos is None :
193+ f .seek (segment_data_pos )
194+ scan_end_pos = min (segment_data_pos + segment_size , segment_data_pos + 10 * 1024 * 1024 )
195+ for eid , edata in _iter_elements (f , scan_end_pos ):
196+ if eid == TAGS and tags_data is None : tags_data = edata
197+ elif eid == INFO and info_pos is None :
198+ self ._parse_info (edata )
199+ info_pos = True
200+ if tags_data is not None and info_pos is not None : break
201+
202+ if tags_data :
203+ self ._parse_tags_element (tags_data )
204+
205+ def _parse_info (self , info_data ):
206+ for eid , edata in _iter_elements (BytesIO (info_data )):
207+ if eid == TITLE :
208+ self .tags .add_tag ('TITLE' , edata .decode ('utf-8' , 'replace' ))
209+ break
210+
211+ def _parse_tags_element (self , tags_data ):
212+ for eid , edata in _iter_elements (BytesIO (tags_data )):
213+ if eid == TAG :
214+ self ._parse_tag (edata )
215+
216+ def _parse_tag (self , tag_data ):
217+ for eid , edata in _iter_elements (BytesIO (tag_data )):
218+ if eid == SIMPLE_TAG :
219+ self ._parse_simple_tag_recursive (edata )
220+
221+ def _parse_simple_tag_recursive (self , simple_tag_data ):
222+ tag_name , tag_value = None , None
223+ nested_tags_data = []
224+
225+ for eid , edata in _iter_elements (BytesIO (simple_tag_data )):
226+ if eid == TAG_NAME : tag_name = edata .decode ('utf-8' , 'replace' )
227+ elif eid == TAG_STRING : tag_value = edata .decode ('utf-8' , 'replace' )
228+ elif eid == TAG_BINARY : tag_value = edata
229+ elif eid == SIMPLE_TAG : nested_tags_data .append (edata )
230+
231+ if tag_name and tag_value is not None :
232+ self .tags .add_tag (tag_name , tag_value )
233+
234+ for nested_data in nested_tags_data :
235+ self ._parse_simple_tag_recursive (nested_data )
236+
237+ def add_tags (self ):
238+ if not self .tags :
239+ self .tags = MKVTags ()
240+
241+ def delete (self , filename = None ):
242+ if filename is None :
243+ filename = self .filename
244+ self .tags .clear ()
245+ self .save (filename , delete_tags = True )
246+
247+ def _render_tags (self ):
248+ tags_payload = b""
249+ if self .tags :
250+ for key , values in sorted (self .tags .items ()):
251+ for v in values :
252+ simple_tag_payload = b""
253+ simple_tag_payload += _write_element (TAG_NAME , key .encode ('utf-8' ))
254+ if isinstance (v , str ):
255+ simple_tag_payload += _write_element (TAG_STRING , v .encode ('utf-8' ))
256+ elif isinstance (v , bytes ):
257+ simple_tag_payload += _write_element (TAG_BINARY , v )
258+ else :
259+ continue
260+ tags_payload += _write_element (TAG , _write_element (SIMPLE_TAG , simple_tag_payload ))
261+ return _write_element (TAGS , tags_payload ) if tags_payload else b""
262+
263+ def save (self , filename = None , delete_tags = False ):
264+ if filename is None :
265+ filename = self .filename
266+
267+ temp_filename = filename + ".tmp"
268+
269+ with open (self .filename , 'rb' ) as f_in , open (temp_filename , 'wb' ) as f_out :
270+ f_in .seek (0 )
271+ ebml_header_end = 0
272+ try :
273+ eid , id_len = _read_vint_raw (f_in )
274+ if eid != EBML_HEADER : raise ValueError ("No EBML Header" )
275+ size , size_len = _read_vint_value (f_in )
276+ ebml_header_end = f_in .tell () + size
277+ f_in .seek (0 )
278+ f_out .write (f_in .read (ebml_header_end ))
279+ except (IOError , ValueError ) as e :
280+ raise IOError (f"Cannot read MKV file structure: { e } " )
281+ f_in .seek (ebml_header_end )
282+ segment_data_start_pos = 0
283+ try :
284+ eid , id_len = _read_vint_raw (f_in )
285+ if eid != SEGMENT : raise ValueError ("No Segment found after EBML Header" )
286+ f_out .write (eid .to_bytes (4 , 'big' ))
287+ f_out .write (b'\x01 \xFF \xFF \xFF \xFF \xFF \xFF \xFF ' )
288+ _ , size_len = _read_vint_value (f_in )
289+ segment_data_start_pos = f_in .tell ()
290+ except (IOError , ValueError ):
291+ raise IOError ("Could not find or parse Segment element" )
292+ new_tags_element = b""
293+ if self .tags and not delete_tags :
294+ new_tags_element = self ._render_tags ()
295+
296+ tags_element_start_pos = - 1
297+ tags_element_len = 0
298+ insert_pos = - 1
299+
300+ f_in .seek (segment_data_start_pos )
301+ try :
302+ current_pos = f_in .tell ()
303+ eid , id_len = _read_vint_raw (f_in )
304+ size , size_len = _read_vint_value (f_in )
305+ insert_pos = current_pos + id_len + size_len + size
306+
307+ f_in .seek (segment_data_start_pos )
308+ while True :
309+ element_start = f_in .tell ()
310+ try :
311+ eid , id_len = _read_vint_raw (f_in )
312+ size , size_len = _read_vint_value (f_in )
313+ except (IOError , ValueError ): break
314+
315+ if eid == TAGS :
316+ tags_element_start_pos = element_start
317+ tags_element_len = id_len + size_len + size
318+ break
319+ f_in .seek (element_start + id_len + size_len + size )
320+ except (IOError , ValueError ):
321+ insert_pos = segment_data_start_pos
322+ f_in .seek (segment_data_start_pos )
323+ if tags_element_start_pos != - 1 :
324+ bytes_before = tags_element_start_pos - segment_data_start_pos
325+ f_out .write (f_in .read (bytes_before ))
326+ f_out .write (new_tags_element )
327+ f_in .seek (tags_element_start_pos + tags_element_len )
328+ f_out .write (f_in .read ())
329+ elif insert_pos != - 1 :
330+ bytes_before = insert_pos - segment_data_start_pos
331+ f_out .write (f_in .read (bytes_before ))
332+ f_out .write (new_tags_element )
333+ f_out .write (f_in .read ())
334+ else :
335+ f_out .write (new_tags_element )
336+ f_out .write (f_in .read ())
337+
338+ os .replace (temp_filename , filename )
0 commit comments