File size: 2,398 Bytes
43cd37c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# html_to_markdown/ast_utils.py

from typing import Callable, Optional, List, Union
from s_types import SemanticMarkdownAST

def find_in_ast(ast: Union[SemanticMarkdownAST, List[SemanticMarkdownAST]], predicate: Callable[[SemanticMarkdownAST], bool]) -> Optional[SemanticMarkdownAST]:
    """Return the first node in *ast* for which *predicate* is true.

    The search is depth-first and pre-order: a node is tested before any of
    its children, and children are visited in order. Children are looked up
    under ``content`` (a list of nodes or a single node), then ``items``,
    ``rows`` and ``cells``, whichever of these the node has.

    Args:
        ast: A single node or a list of nodes to search.
        predicate: Called on each visited node; the first node for which it
            returns a truthy value is returned and the search stops.

    Returns:
        The first matching node, or ``None`` when nothing matches.
    """
    if isinstance(ast, list):
        for node in ast:
            result = find_in_ast(node, predicate)
            # Compare with None instead of testing truthiness: a matching node
            # that happens to be falsy must still end the search.
            if result is not None:
                return result
        return None

    if predicate(ast):
        return ast

    content = getattr(ast, 'content', None)
    # Only recurse into content that is itself AST (a list of nodes or a
    # single node); anything else, e.g. the text of a leaf node, is skipped.
    if isinstance(content, list) or isinstance(content, SemanticMarkdownAST):
        result = find_in_ast(content, predicate)
        if result is not None:
            return result

    # Attributes that hold sequences of child nodes. ``cells`` is included so
    # that rows reached via ``rows`` are searched too -- assumes table-row
    # nodes keep their cells there; TODO confirm against s_types.
    for attr in ('items', 'rows', 'cells'):
        children = getattr(ast, attr, None)
        if children is None:
            # Absent or unset optional container: nothing to search.
            continue
        for child in children:
            result = find_in_ast(child, predicate)
            if result is not None:
                return result
    return None

def find_all_in_ast(ast: Union[SemanticMarkdownAST, List[SemanticMarkdownAST]], predicate: Callable[[SemanticMarkdownAST], bool]) -> List[SemanticMarkdownAST]:
    """Return every node in *ast* for which *predicate* is true.

    Nodes are visited depth-first and pre-order (a node before its
    children), so the result is ordered the same way. Children are looked
    up under ``content`` (a list of nodes or a single node), then ``items``,
    ``rows`` and ``cells``, whichever of these the node has.

    Args:
        ast: A single node or a list of nodes to search.
        predicate: Called on each visited node; nodes for which it returns a
            truthy value are collected.

    Returns:
        A new list of all matching nodes (empty when nothing matches).
    """
    results: List[SemanticMarkdownAST] = []

    def visit(node) -> None:
        # Append matches to the single shared ``results`` list rather than
        # building and merging a fresh list at every level of recursion.
        if isinstance(node, list):
            for child in node:
                visit(child)
            return

        if predicate(node):
            results.append(node)

        content = getattr(node, 'content', None)
        # Only recurse into content that is itself AST (a list of nodes or a
        # single node); anything else, e.g. leaf text, is skipped.
        if isinstance(content, list) or isinstance(content, SemanticMarkdownAST):
            visit(content)

        # Attributes that hold sequences of child nodes. ``cells`` is included
        # so rows reached via ``rows`` are searched too -- assumes table-row
        # nodes keep their cells there; TODO confirm against s_types.
        for attr in ('items', 'rows', 'cells'):
            children = getattr(node, attr, None)
            if children is None:
                # Absent or unset optional container: nothing to search.
                continue
            for child in children:
                visit(child)

    visit(ast)
    return results