From 324ad82006748ebfe4b3fa8f67f160eb000ee6eb Mon Sep 17 00:00:00 2001 From: pukkandan Date: Tue, 8 Jun 2021 14:23:56 +0530 Subject: [PATCH] [utils] Generalize `traverse_dict` to `traverse_obj` --- yt_dlp/YoutubeDL.py | 6 ++--- yt_dlp/postprocessor/ffmpeg.py | 4 ++-- yt_dlp/utils.py | 41 ++++++++++++++++++++++++---------- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/yt_dlp/YoutubeDL.py b/yt_dlp/YoutubeDL.py index 2997b19ca3..1643649fba 100644 --- a/yt_dlp/YoutubeDL.py +++ b/yt_dlp/YoutubeDL.py @@ -101,7 +101,7 @@ strftime_or_none, subtitles_filename, to_high_limit_path, - traverse_dict, + traverse_obj, UnavailableVideoError, url_basename, version_tuple, @@ -855,7 +855,7 @@ def prepare_outtmpl(self, outtmpl, info_dict, sanitize=None): def get_value(mdict): # Object traversal fields = mdict['fields'].split('.') - value = traverse_dict(info_dict, fields) + value = traverse_obj(info_dict, fields) # Negative if mdict['negate']: value = float_or_none(value) @@ -872,7 +872,7 @@ def get_value(mdict): item, multiplier = (item[1:], -1) if item[0] == '-' else (item, 1) offset = float_or_none(item) if offset is None: - offset = float_or_none(traverse_dict(info_dict, item.split('.'))) + offset = float_or_none(traverse_obj(info_dict, item.split('.'))) try: value = operator(value, multiplier * offset) except (TypeError, ZeroDivisionError): diff --git a/yt_dlp/postprocessor/ffmpeg.py b/yt_dlp/postprocessor/ffmpeg.py index d9f816b043..374da8c02b 100644 --- a/yt_dlp/postprocessor/ffmpeg.py +++ b/yt_dlp/postprocessor/ffmpeg.py @@ -23,7 +23,7 @@ ISO639Utils, process_communicate_or_kill, replace_extension, - traverse_dict, + traverse_obj, ) @@ -229,7 +229,7 @@ def get_metadata_object(self, path, opts=[]): def get_stream_number(self, path, keys, value): streams = self.get_metadata_object(path)['streams'] num = next( - (i for i, stream in enumerate(streams) if traverse_dict(stream, keys, casesense=False) == value), + (i for i, stream in enumerate(streams) if traverse_obj(stream, keys, casesense=False) == value), None) return num, len(streams) diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 72fd8a0e7d..6737c1965e 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -6181,21 +6181,38 @@ def load_plugins(name, suffix, namespace): return classes -def traverse_dict(dictn, keys, casesense=True): +def traverse_obj(obj, keys, *, casesense=True, is_user_input=False, traverse_string=False): + ''' Traverse nested list/dict/tuple + @param casesense Whether to consider dictionary keys as case sensitive + @param is_user_input Whether the keys are generated from user input. If True, + strings are converted to int/slice if necessary + @param traverse_string Whether to traverse inside strings. If True, any + non-compatible object will also be converted into a string + ''' keys = list(keys)[::-1] while keys: key = keys.pop() - if isinstance(dictn, dict): + if isinstance(obj, dict): + assert isinstance(key, compat_str) if not casesense: - dictn = {k.lower(): v for k, v in dictn.items()} + obj = {k.lower(): v for k, v in obj.items()} key = key.lower() - dictn = dictn.get(key) - elif isinstance(dictn, (list, tuple, compat_str)): - if ':' in key: - key = slice(*map(int_or_none, key.split(':'))) - else: - key = int_or_none(key) - dictn = try_get(dictn, lambda x: x[key]) + obj = obj.get(key) else: - return None - return dictn + if is_user_input: + key = (int_or_none(key) if ':' not in key + else slice(*map(int_or_none, key.split(':')))) + if not isinstance(obj, (list, tuple)): + if traverse_string: + obj = compat_str(obj) + else: + return None + assert isinstance(key, (int, slice)) + obj = try_get(obj, lambda x: x[key]) + return obj + + +def traverse_dict(dictn, keys, casesense=True): + ''' For backward compatibility. Do not use ''' + return traverse_obj(dictn, keys, casesense=casesense, + is_user_input=True, traverse_string=True)