From 6839ae1f6dde4c0442619e351b3f0442312ab4f9 Mon Sep 17 00:00:00 2001 From: Simon Sawicki Date: Fri, 10 Feb 2023 03:56:26 +0530 Subject: [PATCH] [utils] `traverse_obj`: Fix more bugs and cleanup uses of `default=[]` Continued from b1bde57bef878478e3503ab07190fd207914ade9 --- test/test_utils.py | 75 ++++++++++++++++++++++------------- yt_dlp/downloader/fragment.py | 2 +- yt_dlp/extractor/abematv.py | 4 +- yt_dlp/extractor/gamejolt.py | 2 +- yt_dlp/extractor/iqiyi.py | 8 ++-- yt_dlp/extractor/panopto.py | 4 +- yt_dlp/extractor/patreon.py | 2 +- yt_dlp/extractor/tiktok.py | 4 +- yt_dlp/extractor/youtube.py | 35 ++++++++-------- yt_dlp/utils.py | 15 ++++--- 10 files changed, 84 insertions(+), 67 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 190e4ef9b..3045b6d7e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2000,7 +2000,7 @@ def test_traverse_obj(self): # Test Ellipsis behavior self.assertCountEqual(traverse_obj(_TEST_DATA, ...), - (item for item in _TEST_DATA.values() if item not in (None, [], {})), + (item for item in _TEST_DATA.values() if item not in (None, {})), msg='`...` should give all non discarded values') self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, ...)), _TEST_DATA['urls'][0].values(), msg='`...` selection for dicts should select all values') @@ -2095,7 +2095,7 @@ def test_traverse_obj(self): msg='remove empty values when nested dict key fails') self.assertEqual(traverse_obj(None, {0: 'fail'}), {}, msg='default to dict if pruned') - self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {}, + self.assertEqual(traverse_obj(None, {0: 'fail'}, default=...), {0: ...}, msg='default to dict if pruned and default is given') self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=...), {0: {0: ...}}, msg='use nested `default` when nested dict key fails and `default`') @@ -2124,34 +2124,55 @@ def test_traverse_obj(self): msg='if branched but not successful return `[]`, not `default`') self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', ...)), [], msg='if branched but object is empty return `[]`, not `default`') + self.assertEqual(traverse_obj(None, ...), [], + msg='if branched but object is `None` return `[]`, not `default`') + self.assertEqual(traverse_obj({0: None}, (0, ...)), [], + msg='if branched but state is `None` return `[]`, not `default`') + + branching_paths = [ + ('fail', ...), + (..., 'fail'), + 100 * ('fail',) + (...,), + (...,) + 100 * ('fail',), + ] + for branching_path in branching_paths: + self.assertEqual(traverse_obj({}, branching_path), [], + msg='if branched but state is `None`, return `[]` (not `default`)') + self.assertEqual(traverse_obj({}, 'fail', branching_path), [], + msg='if branching in last alternative and previous did not match, return `[]` (not `default`)') + self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x', + msg='if branching in last alternative and previous did match, return single value') + self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x', + msg='if branching in first alternative and non-branching path does match, return single value') + self.assertEqual(traverse_obj({}, branching_path, 'fail'), None, + msg='if branching in first alternative and non-branching path does not match, return `default`') # Testing expected_type behavior _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0} - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), 'str', - msg='accept matching `expected_type` type') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None, - msg='reject non matching `expected_type` type') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), '0', - msg='transform type using type function') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', - expected_type=lambda _: 1 / 0), None, - msg='wrap expected_type fuction in try_call') - self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], - msg='eliminate items that expected_type fails on') - self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), {0: 100}, - msg='type as expected_type should filter dict values') - self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'}, - msg='function as expected_type should transform dict values') - self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), 1, - msg='expected_type should not filter non final dict values') - self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}}, - msg='expected_type should transform deep dict values') - self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), [{0: ...}, {0: ...}], - msg='expected_type should transform branched dict values') - self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), [4], - msg='expected_type regression for type matching in tuple branching') - self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), [], - msg='expected_type regression for type matching in dict result') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str), + 'str', msg='accept matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), + None, msg='reject non matching `expected_type` type') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)), + '0', msg='transform type using type function') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0), + None, msg='wrap expected_type fuction in try_call') + self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), + ['str'], msg='eliminate items that expected_type fails on') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), + {0: 100}, msg='type as expected_type should filter dict values') + self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), + {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values') + self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), + 1, msg='expected_type should not filter non final dict values') + self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), + {0: {0: 100}}, msg='expected_type should transform deep dict values') + self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), + [{0: ...}, {0: ...}], msg='expected_type should transform branched dict values') + self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), + [4], msg='expected_type regression for type matching in tuple branching') + self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), + [], msg='expected_type regression for type matching in dict result') # Test get_all behavior _GET_ALL_DATA = {'key': [0, 1, 2]} diff --git a/yt_dlp/downloader/fragment.py b/yt_dlp/downloader/fragment.py index 02f8559cc..039cb1492 100644 --- a/yt_dlp/downloader/fragment.py +++ b/yt_dlp/downloader/fragment.py @@ -383,7 +383,7 @@ def download_and_append_fragments_multiple(self, *args, **kwargs): max_workers = self.params.get('concurrent_fragment_downloads', 1) if max_progress > 1: self._prepare_multiline_status(max_progress) - is_live = any(traverse_obj(args, (..., 2, 'is_live'), default=[])) + is_live = any(traverse_obj(args, (..., 2, 'is_live'))) def thread_func(idx, ctx, fragments, info_dict, tpe): ctx['max_progress'] = max_progress diff --git a/yt_dlp/extractor/abematv.py b/yt_dlp/extractor/abematv.py index 9955fb289..7552e3e57 100644 --- a/yt_dlp/extractor/abematv.py +++ b/yt_dlp/extractor/abematv.py @@ -416,7 +416,7 @@ def _real_extract(self, url): f'https://api.abema.io/v1/video/programs/{video_id}', video_id, note='Checking playability', headers=headers) - ondemand_types = traverse_obj(api_response, ('terms', ..., 'onDemandType'), default=[]) + ondemand_types = traverse_obj(api_response, ('terms', ..., 'onDemandType')) if 3 not in ondemand_types: # cannot acquire decryption key for these streams self.report_warning('This is a premium-only stream') @@ -489,7 +489,7 @@ def _fetch_page(self, playlist_id, series_version, page): }) yield from ( self.url_result(f'https://abema.tv/video/episode/{x}') - for x in traverse_obj(programs, ('programs', ..., 'id'), default=[])) + for x in traverse_obj(programs, ('programs', ..., 'id'))) def _entries(self, playlist_id, series_version): return OnDemandPagedList( diff --git a/yt_dlp/extractor/gamejolt.py b/yt_dlp/extractor/gamejolt.py index 440b832fc..8ec046bb3 100644 --- a/yt_dlp/extractor/gamejolt.py +++ b/yt_dlp/extractor/gamejolt.py @@ -48,7 +48,7 @@ def _get_comments(self, post_num_id, post_hash_id): post_hash_id, note='Downloading comments list page %d' % page) if not comments_data.get('comments'): break - for comment in traverse_obj(comments_data, (('comments', 'childComments'), ...), expected_type=dict, default=[]): + for comment in traverse_obj(comments_data, (('comments', 'childComments'), ...), expected_type=dict): yield { 'id': comment['id'], 'text': self._parse_content_as_text( diff --git a/yt_dlp/extractor/iqiyi.py b/yt_dlp/extractor/iqiyi.py index eba89f787..4443b1991 100644 --- a/yt_dlp/extractor/iqiyi.py +++ b/yt_dlp/extractor/iqiyi.py @@ -585,7 +585,7 @@ def _real_extract(self, url): 'langCode': self._get_cookie('lang', 'en_us'), 'deviceId': self._get_cookie('QC005', '') }, fatal=False) - ut_list = traverse_obj(vip_data, ('data', 'all_vip', ..., 'vipType'), expected_type=str_or_none, default=[]) + ut_list = traverse_obj(vip_data, ('data', 'all_vip', ..., 'vipType'), expected_type=str_or_none) else: ut_list = ['0'] @@ -617,7 +617,7 @@ def _real_extract(self, url): self.report_warning('This preview video is limited%s' % format_field(preview_time, None, ' to %s seconds')) # TODO: Extract audio-only formats - for bid in set(traverse_obj(initial_format_data, ('program', 'video', ..., 'bid'), expected_type=str_or_none, default=[])): + for bid in set(traverse_obj(initial_format_data, ('program', 'video', ..., 'bid'), expected_type=str_or_none)): dash_path = dash_paths.get(bid) if not dash_path: self.report_warning(f'Unknown format id: {bid}. It is currently not being extracted') @@ -628,7 +628,7 @@ def _real_extract(self, url): fatal=False), 'data', expected_type=dict) video_format = traverse_obj(format_data, ('program', 'video', lambda _, v: str(v['bid']) == bid), - expected_type=dict, default=[], get_all=False) or {} + expected_type=dict, get_all=False) or {} extracted_formats = [] if video_format.get('m3u8Url'): extracted_formats.extend(self._extract_m3u8_formats( @@ -669,7 +669,7 @@ def _real_extract(self, url): }) formats.extend(extracted_formats) - for sub_format in traverse_obj(initial_format_data, ('program', 'stl', ...), expected_type=dict, default=[]): + for sub_format in traverse_obj(initial_format_data, ('program', 'stl', ...), expected_type=dict): lang = self._LID_TAGS.get(str_or_none(sub_format.get('lid')), sub_format.get('_name')) subtitles.setdefault(lang, []).extend([{ 'ext': format_ext, diff --git a/yt_dlp/extractor/panopto.py b/yt_dlp/extractor/panopto.py index 32c103bc1..6e3c9f442 100644 --- a/yt_dlp/extractor/panopto.py +++ b/yt_dlp/extractor/panopto.py @@ -412,7 +412,7 @@ def _real_extract(self, url): return { 'id': video_id, 'title': delivery.get('SessionName'), - 'cast': traverse_obj(delivery, ('Contributors', ..., 'DisplayName'), default=[], expected_type=lambda x: x or None), + 'cast': traverse_obj(delivery, ('Contributors', ..., 'DisplayName'), expected_type=lambda x: x or None), 'timestamp': session_start_time - 11640000000 if session_start_time else None, 'duration': delivery.get('Duration'), 'thumbnail': base_url + f'/Services/FrameGrabber.svc/FrameRedirect?objectId={video_id}&mode=Delivery&random={random()}', @@ -563,7 +563,7 @@ def _extract_folder_metadata(self, base_url, folder_id): base_url, '/Services/Data.svc/GetFolderInfo', folder_id, data={'folderID': folder_id}, fatal=False) return { - 'title': get_first(response, 'Name', default=[]) + 'title': get_first(response, 'Name') } def _real_extract(self, url): diff --git a/yt_dlp/extractor/patreon.py b/yt_dlp/extractor/patreon.py index 529aba178..e93e37eb9 100644 --- a/yt_dlp/extractor/patreon.py +++ b/yt_dlp/extractor/patreon.py @@ -310,7 +310,7 @@ def _get_comments(self, post_id): f'posts/{post_id}/comments', post_id, query=params, note='Downloading comments page %d' % page) cursor = None - for comment in traverse_obj(response, (('data', ('included', lambda _, v: v['type'] == 'comment')), ...), default=[]): + for comment in traverse_obj(response, (('data', ('included', lambda _, v: v['type'] == 'comment')), ...)): count += 1 comment_id = comment.get('id') attributes = comment.get('attributes') or {} diff --git a/yt_dlp/extractor/tiktok.py b/yt_dlp/extractor/tiktok.py index cc96de364..096748bf7 100644 --- a/yt_dlp/extractor/tiktok.py +++ b/yt_dlp/extractor/tiktok.py @@ -285,7 +285,7 @@ def extract_addr(addr, add_meta={}): user_url = self._UPLOADER_URL_FORMAT % (traverse_obj(author_info, 'sec_uid', 'id', 'uid', 'unique_id', expected_type=str_or_none, get_all=False)) - labels = traverse_obj(aweme_detail, ('hybrid_label', ..., 'text'), expected_type=str, default=[]) + labels = traverse_obj(aweme_detail, ('hybrid_label', ..., 'text'), expected_type=str) contained_music_track = traverse_obj( music_info, ('matched_song', 'title'), ('matched_pgc_sound', 'title'), expected_type=str) @@ -355,7 +355,7 @@ def _parse_aweme_video_web(self, aweme_detail, webpage_url): 'ext': 'mp4', 'width': width, 'height': height, - } for url in traverse_obj(play_url, (..., 'src'), expected_type=url_or_none, default=[]) if url] + } for url in traverse_obj(play_url, (..., 'src'), expected_type=url_or_none) if url] download_url = url_or_none(video_info.get('downloadAddr')) or traverse_obj(video_info, ('download', 'url'), expected_type=url_or_none) if download_url: diff --git a/yt_dlp/extractor/youtube.py b/yt_dlp/extractor/youtube.py index f7b0772df..aff89f8ac 100644 --- a/yt_dlp/extractor/youtube.py +++ b/yt_dlp/extractor/youtube.py @@ -745,7 +745,7 @@ def _extract_badges(self, renderer: dict): } badges = [] - for badge in traverse_obj(renderer, ('badges', ..., 'metadataBadgeRenderer'), default=[]): + for badge in traverse_obj(renderer, ('badges', ..., 'metadataBadgeRenderer')): badge_type = ( privacy_icon_map.get(traverse_obj(badge, ('icon', 'iconType'), expected_type=str)) or badge_style_map.get(traverse_obj(badge, 'style')) @@ -785,7 +785,7 @@ def _get_text(data, *path_list, max_runs=None): runs = item runs = runs[:min(len(runs), max_runs or len(runs))] - text = ''.join(traverse_obj(runs, (..., 'text'), expected_type=str, default=[])) + text = ''.join(traverse_obj(runs, (..., 'text'), expected_type=str)) if text: return text @@ -805,7 +805,7 @@ def _extract_thumbnails(data, *path_list): """ thumbnails = [] for path in path_list or [()]: - for thumbnail in traverse_obj(data, (*variadic(path), 'thumbnails', ...), default=[]): + for thumbnail in traverse_obj(data, (*variadic(path), 'thumbnails', ...)): thumbnail_url = url_or_none(thumbnail.get('url')) if not thumbnail_url: continue @@ -2668,11 +2668,10 @@ def refetch_manifest(format_id, delay): return _, _, prs, player_url = self._download_player_responses(url, smuggled_data, video_id, webpage_url) - video_details = traverse_obj( - prs, (..., 'videoDetails'), expected_type=dict, default=[]) + video_details = traverse_obj(prs, (..., 'videoDetails'), expected_type=dict) microformats = traverse_obj( prs, (..., 'microformat', 'playerMicroformatRenderer'), - expected_type=dict, default=[]) + expected_type=dict) _, live_status, _, formats, _ = self._list_formats(video_id, microformats, video_details, prs, player_url) is_live = live_status == 'is_live' start_time = time.time() @@ -3173,7 +3172,7 @@ def _extract_chapters_from_engagement_panel(self, data, duration): content_list = traverse_obj( data, ('engagementPanels', ..., 'engagementPanelSectionListRenderer', 'content', 'macroMarkersListRenderer', 'contents'), - expected_type=list, default=[]) + expected_type=list) chapter_time = lambda chapter: parse_duration(self._get_text(chapter, 'timeDescription')) chapter_title = lambda chapter: self._get_text(chapter, 'title') @@ -3450,7 +3449,7 @@ def _is_agegated(player_response): if traverse_obj(player_response, ('playabilityStatus', 'desktopLegacyAgeGateReason')): return True - reasons = traverse_obj(player_response, ('playabilityStatus', ('status', 'reason')), default=[]) + reasons = traverse_obj(player_response, ('playabilityStatus', ('status', 'reason'))) AGE_GATE_REASONS = ( 'confirm your age', 'age-restricted', 'inappropriate', # reason 'age_verification_required', 'age_check_required', # status @@ -3606,7 +3605,7 @@ def _extract_formats_and_subtitles(self, streaming_data, video_id, player_url, l 'audio_quality_ultralow', 'audio_quality_low', 'audio_quality_medium', 'audio_quality_high', # Audio only formats 'small', 'medium', 'large', 'hd720', 'hd1080', 'hd1440', 'hd2160', 'hd2880', 'highres' ]) - streaming_formats = traverse_obj(streaming_data, (..., ('formats', 'adaptiveFormats'), ...), default=[]) + streaming_formats = traverse_obj(streaming_data, (..., ('formats', 'adaptiveFormats'), ...)) for fmt in streaming_formats: if fmt.get('targetDurationSec'): @@ -3872,7 +3871,7 @@ def _list_formats(self, video_id, microformats, video_details, player_responses, else 'was_live' if live_content else 'not_live' if False in (is_live, live_content) else None) - streaming_data = traverse_obj(player_responses, (..., 'streamingData'), default=[]) + streaming_data = traverse_obj(player_responses, (..., 'streamingData')) *formats, subtitles = self._extract_formats_and_subtitles(streaming_data, video_id, player_url, live_status, duration) return live_broadcast_details, live_status, streaming_data, formats, subtitles @@ -3887,7 +3886,7 @@ def _real_extract(self, url): webpage, master_ytcfg, player_responses, player_url = self._download_player_responses(url, smuggled_data, video_id, webpage_url) playability_statuses = traverse_obj( - player_responses, (..., 'playabilityStatus'), expected_type=dict, default=[]) + player_responses, (..., 'playabilityStatus'), expected_type=dict) trailer_video_id = get_first( playability_statuses, @@ -3900,11 +3899,10 @@ def _real_extract(self, url): search_meta = ((lambda x: self._html_search_meta(x, webpage, default=None)) if webpage else (lambda x: None)) - video_details = traverse_obj( - player_responses, (..., 'videoDetails'), expected_type=dict, default=[]) + video_details = traverse_obj(player_responses, (..., 'videoDetails'), expected_type=dict) microformats = traverse_obj( player_responses, (..., 'microformat', 'playerMicroformatRenderer'), - expected_type=dict, default=[]) + expected_type=dict) translated_title = self._get_text(microformats, (..., 'title')) video_title = (self._preferred_lang and translated_title @@ -4110,10 +4108,10 @@ def get_lang_code(track): # Converted into dicts to remove duplicates captions = { get_lang_code(sub): sub - for sub in traverse_obj(pctr, (..., 'captionTracks', ...), default=[])} + for sub in traverse_obj(pctr, (..., 'captionTracks', ...))} translation_languages = { lang.get('languageCode'): self._get_text(lang.get('languageName'), max_runs=1) - for lang in traverse_obj(pctr, (..., 'translationLanguages', ...), default=[])} + for lang in traverse_obj(pctr, (..., 'translationLanguages', ...))} def process_language(container, base_url, lang_code, sub_name, query): lang_subs = container.setdefault(lang_code, []) @@ -4267,9 +4265,8 @@ def process_language(container, base_url, lang_code, sub_name, query): list) or []): tbrs = variadic( traverse_obj( - tlb, 'toggleButtonRenderer', - ('segmentedLikeDislikeButtonRenderer', ..., 'toggleButtonRenderer'), - default=[])) + tlb, ('toggleButtonRenderer', ...), + ('segmentedLikeDislikeButtonRenderer', ..., 'toggleButtonRenderer'))) for tbr in tbrs: for getter, regex in [( lambda x: x['defaultText']['accessibility']['accessibilityData'], diff --git a/yt_dlp/utils.py b/yt_dlp/utils.py index 878b2b6a8..7cf151e3a 100644 --- a/yt_dlp/utils.py +++ b/yt_dlp/utils.py @@ -5420,7 +5420,7 @@ def traverse_obj( Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. Supported values for traversal are `Mapping`, `Sequence` and `re.Match`. - Unhelpful values (`[]`, `{}`, `None`) are treated as the absence of a value and discarded. + Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. @@ -5484,7 +5484,7 @@ def apply_key(key, obj, is_last): branching = False result = None - if obj is None: + if obj is None and traverse_string: pass elif key is None: @@ -5558,14 +5558,13 @@ def apply_key(key, obj, is_last): result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None) elif isinstance(key, (int, slice)): - if not is_sequence(obj): - if traverse_string: - with contextlib.suppress(IndexError): - result = str(obj)[key] - else: + if is_sequence(obj): branching = isinstance(key, slice) with contextlib.suppress(IndexError): result = obj[key] + elif traverse_string: + with contextlib.suppress(IndexError): + result = str(obj)[key] return branching, result if branching else (result,) @@ -5617,7 +5616,7 @@ def apply_path(start_obj, path, test_type): def _traverse_obj(obj, path, allow_empty, test_type): results, has_branched, is_dict = apply_path(obj, path, test_type) - results = LazyList(item for item in results if item not in (None, [], {})) + results = LazyList(item for item in results if item not in (None, {})) if get_all and has_branched: if results: return results.exhaust()