Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion core/ast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
This module provides the node types and classes for representing SQL query structures.
"""

from .node_type import NodeType
from .enums import NodeType
from .node import (
Node,
TableNode,
Expand All @@ -18,6 +18,7 @@
SelectNode,
FromNode,
WhereNode,
JoinNode,
GroupByNode,
HavingNode,
OrderByNode,
Expand All @@ -40,9 +41,11 @@
'SelectNode',
'FromNode',
'WhereNode',
'JoinNode',
'GroupByNode',
'HavingNode',
'OrderByNode',
'OrderByItemNode',
'LimitNode',
'OffsetNode',
'QueryNode'
Expand Down
60 changes: 60 additions & 0 deletions core/ast/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from enum import Enum

# ============================================================================
# Node Type Enumeration
# ============================================================================

class NodeType(Enum):
"""Node type enumeration"""

# Operands
TABLE = "table"
SUBQUERY = "subquery"
COLUMN = "column"
LITERAL = "literal"
# VarSQL specific
VAR = "var"
VARSET = "varset"

# Operators
OPERATOR = "operator"
FUNCTION = "function"

# Query structure
SELECT = "select"
FROM = "from"
WHERE = "where"
JOIN = "join"
GROUP_BY = "group_by"
HAVING = "having"
ORDER_BY = "order_by"
ORDER_BY_ITEM = "order_by_item"
LIMIT = "limit"
OFFSET = "offset"
QUERY = "query"

# ============================================================================
# Join Type Enumeration
# ============================================================================

class JoinType(Enum):
"""Join type enumeration"""
INNER = "inner"
OUTER = "outer"
LEFT = "left"
RIGHT = "right"
FULL = "full"
CROSS = "cross"
NATURAL = "natural"
SEMI = "semi"
ANTI = "anti"


# ============================================================================
# Sort Order Enumeration
# ============================================================================

class SortOrder(Enum):
"""Sort order enum"""
ASC = "ASC"
DESC = "DESC"
145 changes: 137 additions & 8 deletions core/ast/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Set, Optional
from abc import ABC

from .node_type import NodeType
from .enums import NodeType, JoinType, SortOrder

# ============================================================================
# Base Node Structure
Expand All @@ -13,6 +13,31 @@ class Node(ABC):
def __init__(self, type: NodeType, children: Optional[Set['Node']|List['Node']] = None):
self.type = type
self.children = children if children is not None else set()

def __eq__(self, other):
if not isinstance(other, Node):
return False
if self.type != other.type:
return False
if len(self.children) != len(other.children):
return False
# Compare children
if isinstance(self.children, set) and isinstance(other.children, set):
return self.children == other.children
elif isinstance(self.children, list) and isinstance(other.children, list):
return self.children == other.children
else:
return False

def __hash__(self):
# Make nodes hashable by using their type and a hash of their children
if isinstance(self.children, set):
# For sets, create a deterministic hash by sorting children by their string representation
children_hash = hash(tuple(sorted(self.children, key=lambda x: str(x))))
else:
# For lists, just hash the tuple directly
children_hash = hash(tuple(self.children))
return hash((self.type, children_hash))


# ============================================================================
Expand All @@ -25,6 +50,16 @@ def __init__(self, _name: str, _alias: Optional[str] = None, **kwargs):
super().__init__(NodeType.TABLE, **kwargs)
self.name = _name
self.alias = _alias

def __eq__(self, other):
if not isinstance(other, TableNode):
return False
return (super().__eq__(other) and
self.name == other.name and
self.alias == other.alias)

def __hash__(self):
return hash((super().__hash__(), self.name, self.alias))


# TODO - including query structure arguments (similar to QueryNode) in constructor.
Expand All @@ -43,6 +78,17 @@ def __init__(self, _name: str, _alias: Optional[str] = None, _parent_alias: Opti
self.alias = _alias
self.parent_alias = _parent_alias
self.parent = _parent

def __eq__(self, other):
if not isinstance(other, ColumnNode):
return False
return (super().__eq__(other) and
self.name == other.name and
self.alias == other.alias and
self.parent_alias == other.parent_alias)

def __hash__(self):
return hash((super().__hash__(), self.name, self.alias, self.parent_alias))


class LiteralNode(Node):
Expand All @@ -51,6 +97,15 @@ def __init__(self, _value: str|int|float|bool|datetime|None, **kwargs):
super().__init__(NodeType.LITERAL, **kwargs)
self.value = _value

def __eq__(self, other):
if not isinstance(other, LiteralNode):
return False
return (super().__eq__(other) and
self.value == other.value)

def __hash__(self):
return hash((super().__hash__(), self.value))


class VarNode(Node):
"""VarSQL variable node"""
Expand All @@ -72,37 +127,78 @@ def __init__(self, _left: Node, _name: str, _right: Optional[Node] = None, **kwa
children = [_left, _right] if _right else [_left]
super().__init__(NodeType.OPERATOR, children=children, **kwargs)
self.name = _name

def __eq__(self, other):
if not isinstance(other, OperatorNode):
return False
return (super().__eq__(other) and
self.name == other.name)

def __hash__(self):
return hash((super().__hash__(), self.name))


class FunctionNode(Node):
"""Function call node"""
def __init__(self, _name: str, _args: Optional[List[Node]] = None, **kwargs):
def __init__(self, _name: str, _args: Optional[List[Node]] = None, _alias: Optional[str] = None, **kwargs):
if _args is None:
_args = []
super().__init__(NodeType.FUNCTION, children=_args, **kwargs)
self.name = _name

self.alias = _alias

def __eq__(self, other):
if not isinstance(other, FunctionNode):
return False
return (super().__eq__(other) and
self.name == other.name and
self.alias == other.alias)

def __hash__(self):
return hash((super().__hash__(), self.name, self.alias))


class JoinNode(Node):
"""JOIN clause node"""
def __init__(self, _left_table: 'TableNode', _right_table: 'TableNode', _join_type: JoinType = JoinType.INNER, _on_condition: Optional['Node'] = None, **kwargs):
children = [_left_table, _right_table]
if _on_condition:
children.append(_on_condition)
super().__init__(NodeType.JOIN, children=children, **kwargs)
self.left_table = _left_table
self.right_table = _right_table
self.join_type = _join_type
self.on_condition = _on_condition

def __eq__(self, other):
if not isinstance(other, JoinNode):
return False
return (super().__eq__(other) and
self.join_type == other.join_type)

def __hash__(self):
return hash((super().__hash__(), self.join_type))

# ============================================================================
# Query Structure Nodes
# ============================================================================

class SelectNode(Node):
"""SELECT clause node"""
def __init__(self, _items: Set['Node'], **kwargs):
def __init__(self, _items: List['Node'], **kwargs):
super().__init__(NodeType.SELECT, children=_items, **kwargs)


# TODO - confine the valid NodeTypes as children of FromNode
class FromNode(Node):
"""FROM clause node"""
def __init__(self, _sources: Set['Node'], **kwargs):
def __init__(self, _sources: List['Node'], **kwargs):
super().__init__(NodeType.FROM, children=_sources, **kwargs)


class WhereNode(Node):
"""WHERE clause node"""
def __init__(self, _predicates: Set['Node'], **kwargs):
def __init__(self, _predicates: List['Node'], **kwargs):
super().__init__(NodeType.WHERE, children=_predicates, **kwargs)


Expand All @@ -114,13 +210,28 @@ def __init__(self, _items: List['Node'], **kwargs):

class HavingNode(Node):
"""HAVING clause node"""
def __init__(self, _predicates: Set['Node'], **kwargs):
def __init__(self, _predicates: List['Node'], **kwargs):
super().__init__(NodeType.HAVING, children=_predicates, **kwargs)


class OrderByItemNode(Node):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference between OrderByItemNode and OrderByNode?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OrdrByItemNode is for single items in ORDER BY clauses like
OrderByItemNode(ColumnNode("salary"), _sort = SortOrder.DESC).
OrderByNode is used to build full ORDER BY clauses formed by multiple OrderByItemNode like

OrderByNode([
    OrderByItemNode(ColumnNode("salary"), _sort=SortOrder.DESC), 
    OrderByItemNode(ColumnNode("age"), _sort=SortOrder.ASC)
])

"""Single ORDER BY item"""
def __init__(self, _column: Node, _sort: SortOrder = SortOrder.ASC, **kwargs):
super().__init__(NodeType.ORDER_BY_ITEM, children=[_column], **kwargs)
self.sort = _sort

def __eq__(self, other):
if not isinstance(other, OrderByItemNode):
return False
return (super().__eq__(other) and
self.sort == other.sort)

def __hash__(self):
return hash((super().__hash__(), self.sort))

class OrderByNode(Node):
"""ORDER BY clause node"""
def __init__(self, _items: List['Node'], **kwargs):
def __init__(self, _items: List[OrderByItemNode], **kwargs):
super().__init__(NodeType.ORDER_BY, children=_items, **kwargs)


Expand All @@ -129,13 +240,31 @@ class LimitNode(Node):
def __init__(self, _limit: int, **kwargs):
super().__init__(NodeType.LIMIT, **kwargs)
self.limit = _limit

def __eq__(self, other):
if not isinstance(other, LimitNode):
return False
return (super().__eq__(other) and
self.limit == other.limit)

def __hash__(self):
return hash((super().__hash__(), self.limit))


class OffsetNode(Node):
"""OFFSET clause node"""
def __init__(self, _offset: int, **kwargs):
super().__init__(NodeType.OFFSET, **kwargs)
self.offset = _offset

def __eq__(self, other):
if not isinstance(other, OffsetNode):
return False
return (super().__eq__(other) and
self.offset == other.offset)

def __hash__(self):
return hash((super().__hash__(), self.offset))


class QueryNode(Node):
Expand Down
32 changes: 0 additions & 32 deletions core/ast/node_type.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode,
OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode
)
from core.ast.node_type import NodeType
from core.ast.enums import NodeType, JoinType, SortOrder
from data.queries import get_query

parser = QueryParser()
Expand Down