Commit 0e7f2d60 authored by Roman Alifanov's avatar Roman Alifanov

Improve parameter passing for arrays, dicts and class instances

- Add proper type tracking for function and method parameters - DCE now tracks array-returning methods (keys, split, slice) - Fix split/slice to return arrays via __CT_RET_ARR - Add foreach support for inline split expressions - Detect object parameters by unknown method calls - Add comprehensive tests for parameter passing New test classes: - TestFunctionParameterPassing (11 tests) - TestClassInstancePassing (4 tests)
parent a9fe7b9d
......@@ -1010,12 +1010,35 @@ Error: Unknown method 'badMethod' for type 'fs'. Available: append, exists, list
```
CodeGenerator
├── StdlibMixin # stdlib.py — встроенные функции
├── AwkCodegenMixin # awk_codegen.py — @awk компиляция
├── ExprMixin # expr_codegen.py — выражения
├── StmtMixin # stmt_codegen.py — statements
├── ClassMixin # class_codegen.py — классы/методы
├── DecoratorMixin # decorator_codegen.py — декораторы
├── DispatchMixin # dispatch_codegen.py — диспатч, присваивания
└── CseMixin # cse_codegen.py — CSE оптимизации
├── CseMixin # cse_codegen.py — CSE оптимизации
├── StdlibMixin # stdlib.py — встроенные функции
└── AwkCodegenMixin # awk_codegen.py — @awk компиляция
Вспомогательные модули:
├── constants.py # Константы (RET_VAR, TMP_PREFIX, CLASS_FUNC_PREFIX, etc.)
└── methods.py # Единый реестр методов для bash/awk синхронизации
```
### Добавление новых методов
Для добавления нового метода достаточно обновить `methods.py`:
```python
STRING_METHODS = {
...
"new_method": MethodDef(
"new_method",
min_args=1,
max_args=1,
bash_func="__ct_str_new_method",
awk_gen=lambda obj, args: f"awk_impl({obj}, {args[0]})"
),
}
```
Bash и AWK codegen автоматически подхватят изменения.
......@@ -358,33 +358,35 @@ Uses `json.get()` for parsing Telegram API responses and `str.urlencode()` for U
## Project Structure
```
bootstrap/ # Bootstrap compiler (Python)
├── main.py # CLI entry point
├── lexer.py # Tokenizer
├── tokens.py # Token type definitions
├── parser.py # Recursive descent parser, AST generation
├── ast_nodes.py # AST node classes
├── codegen.py # Main Bash code generator
├── stmt_codegen.py # Statement generation (mixin)
├── expr_codegen.py # Expression generation (mixin)
├── class_codegen.py # Class/method generation (mixin)
bootstrap/ # Bootstrap compiler (Python)
├── main.py # CLI entry point
├── lexer.py # Tokenizer
├── tokens.py # Token type definitions
├── parser.py # Recursive descent parser, AST generation
├── ast_nodes.py # AST node classes
├── errors.py # Error handling
├── constants.py # Codegen constants (RET_VAR, TMP_PREFIX, etc.)
├── methods.py # Unified method registry for bash/awk sync
├── dce.py # Dead code elimination
├── codegen.py # Main Bash code generator (mixin coordinator)
├── expr_codegen.py # Expression generation (mixin)
├── stmt_codegen.py # Statement generation (mixin)
├── class_codegen.py # Class/method generation (mixin)
├── dispatch_codegen.py # Method dispatch, assignments (mixin)
├── decorator_codegen.py # Decorator wrappers (mixin)
├── awk_codegen.py # AWK generator for @awk (mixin)
├── stdlib.py # Standard library generation (mixin)
├── cse_codegen.py # Common subexpression elimination (mixin)
├── dce.py # Dead code elimination
└── errors.py # Error handling
lib/ # ContenT libraries
└── cli.ct # CLI library (urfave/cli style)
tests/ # Test suite
├── test_lexer.py # Lexer tests
├── test_parser.py # Parser tests
├── cse_codegen.py # Common subexpression elimination (mixin)
├── stdlib.py # Standard library generation (mixin)
└── awk_codegen.py # AWK generator for @awk (mixin)
lib/ # ContenT libraries
└── cli.ct # CLI library (urfave/cli style)
tests/ # Test suite
├── test_lexer.py # Lexer tests
├── test_parser.py # Parser tests
└── test_integration.py # Integration tests
examples/ # Example .ct programs
examples/ # Example .ct programs
```
## License
......
......@@ -358,25 +358,27 @@ python3 content run examples/telegram_echobot/echobot.ct
## Структура проекта
```
bootstrap/ # Bootstrap-компилятор (Python)
├── main.py # CLI точка входа
├── lexer.py # Токенизатор
├── tokens.py # Определения типов токенов
├── parser.py # Рекурсивный спуск, генерация AST
├── ast_nodes.py # Классы узлов AST
├── codegen.py # Основной генератор Bash-кода
├── stmt_codegen.py # Генерация statements (миксин)
├── expr_codegen.py # Генерация выражений (миксин)
├── class_codegen.py # Генерация классов/методов (миксин)
bootstrap/ # Bootstrap-компилятор (Python)
├── main.py # CLI точка входа
├── lexer.py # Токенизатор
├── tokens.py # Определения типов токенов
├── parser.py # Рекурсивный спуск, генерация AST
├── ast_nodes.py # Классы узлов AST
├── errors.py # Обработка ошибок
├── constants.py # Константы кодогенерации (RET_VAR, TMP_PREFIX, etc.)
├── methods.py # Единый реестр методов для bash/awk синхронизации
├── dce.py # Устранение мёртвого кода
├── codegen.py # Основной генератор Bash-кода (координатор миксинов)
├── expr_codegen.py # Генерация выражений (миксин)
├── stmt_codegen.py # Генерация statements (миксин)
├── class_codegen.py # Генерация классов/методов (миксин)
├── dispatch_codegen.py # Диспатч методов, присваивания (миксин)
├── decorator_codegen.py # Обёртки декораторов (миксин)
├── awk_codegen.py # AWK-генератор для @awk (миксин)
├── stdlib.py # Генерация стандартной библиотеки (миксин)
├── cse_codegen.py # Устранение общих подвыражений (миксин)
├── dce.py # Устранение мёртвого кода
└── errors.py # Обработка ошибок
├── cse_codegen.py # Устранение общих подвыражений (миксин)
├── stdlib.py # Генерация стандартной библиотеки (миксин)
└── awk_codegen.py # AWK-генератор для @awk (миксин)
lib/ # Библиотеки на ContenT
lib/ # Библиотеки на ContenT
└── cli.ct # CLI-библиотека (стиль urfave/cli)
tests/ # Тестовый набор
......
from .ast_nodes import *
from .ast_nodes import (
ClassDecl, FunctionDecl, ArrayLiteral, DictLiteral, NilLiteral, NewExpr,
CallExpr, Identifier, Assignment, MemberAccess, ThisExpr, ReturnStmt,
ConstructorDecl, Parameter, Block, ForeachStmt, IfStmt, WhileStmt, ForStmt,
ExpressionStmt, BinaryOp, IndexAccess
)
from .methods import ARRAY_METHODS, DICT_METHODS
ARRAY_ONLY_METHODS = {"push", "pop", "shift", "join", "slice", "map", "filter"}
DICT_ONLY_METHODS = {"has", "del", "keys"}
ARRAY_METHODS_ALL = ARRAY_ONLY_METHODS | {"get", "set", "len"}
DICT_METHODS_ALL = DICT_ONLY_METHODS | {"get", "set", "len"}
STRING_METHODS_ALL = {"upper", "lower", "trim", "len", "contains", "starts", "ends",
"index", "replace", "substr", "split", "charAt", "urlencode"}
ALL_KNOWN_METHODS = ARRAY_METHODS_ALL | DICT_METHODS_ALL | STRING_METHODS_ALL
class ClassMixin:
......@@ -176,9 +191,11 @@ class ClassMixin:
self.in_class_method = True
old_in_function = self.in_function
old_local_vars = self.local_vars.copy()
old_object_vars = self.object_vars.copy()
self.in_function = True
self.local_vars = set()
param_types = self._analyze_param_types(method)
for i, param in enumerate(method.params):
if param.is_variadic:
self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")')
......@@ -189,6 +206,8 @@ class ClassMixin:
else:
self.emit(f'local {param.name}="${{{i + 1}}}"')
self.local_vars.add(param.name)
if param_types.get(param.name) == "object":
self.object_vars.add(param.name)
for stmt in method.body.statements:
self.generate_statement(stmt)
......@@ -196,6 +215,7 @@ class ClassMixin:
self.in_class_method = False
self.in_function = old_in_function
self.local_vars = old_local_vars
self.object_vars = old_object_vars
self.indent_level -= 1
self.emit("}")
self.emit()
......@@ -222,9 +242,11 @@ class ClassMixin:
self.in_class_method = True
old_in_function = self.in_function
old_local_vars = self.local_vars.copy()
old_object_vars = self.object_vars.copy()
self.in_function = True
self.local_vars = set()
param_types = self._analyze_param_types(method)
for i, param in enumerate(method.params):
if param.is_variadic:
self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")')
......@@ -235,6 +257,8 @@ class ClassMixin:
else:
self.emit(f'local {param.name}="${{{i + 1}}}"')
self.local_vars.add(param.name)
if param_types.get(param.name) == "object":
self.object_vars.add(param.name)
for stmt in method.body.statements:
self.generate_statement(stmt)
......@@ -242,6 +266,7 @@ class ClassMixin:
self.in_class_method = False
self.in_function = old_in_function
self.local_vars = old_local_vars
self.object_vars = old_object_vars
self.indent_level -= 1
self.emit("}")
self.emit()
......@@ -313,6 +338,86 @@ class ClassMixin:
self.inlineable_methods[(cls.name, method.name)] = \
f'${{__CT_OBJ["$this.{arg0.member}"]:${{__CT_OBJ["$this.{arg1.member}"]}}:1}}'
def _analyze_param_types(self, func: FunctionDecl) -> dict:
"""Analyze function body to determine parameter types (array/dict/scalar)."""
param_names = {p.name for p in func.params}
param_types = {p.name: "scalar" for p in func.params}
param_methods = {p.name: set() for p in func.params}
def analyze_expr(expr):
if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberAccess):
if isinstance(expr.callee.object, Identifier):
var_name = expr.callee.object.name
method = expr.callee.member
if var_name in param_names:
param_methods[var_name].add(method)
for arg in expr.arguments:
analyze_expr(arg)
elif isinstance(expr, BinaryOp):
analyze_expr(expr.left)
analyze_expr(expr.right)
elif isinstance(expr, IndexAccess):
if isinstance(expr.object, Identifier):
var_name = expr.object.name
if var_name in param_names and param_types[var_name] == "scalar":
param_types[var_name] = "array"
analyze_expr(expr.index)
def analyze_stmt(stmt):
if isinstance(stmt, Assignment):
analyze_expr(stmt.value)
if isinstance(stmt.target, IndexAccess):
if isinstance(stmt.target.object, Identifier):
var_name = stmt.target.object.name
if var_name in param_names:
param_types[var_name] = "array"
elif isinstance(stmt, ExpressionStmt):
analyze_expr(stmt.expression)
elif isinstance(stmt, ForeachStmt):
if isinstance(stmt.iterable, Identifier):
var_name = stmt.iterable.name
if var_name in param_names:
param_types[var_name] = "array"
if stmt.body:
for s in stmt.body.statements:
analyze_stmt(s)
elif isinstance(stmt, (IfStmt,)):
analyze_expr(stmt.condition)
if stmt.then_branch:
for s in stmt.then_branch.statements:
analyze_stmt(s)
for _, branch in stmt.elif_branches:
for s in branch.statements:
analyze_stmt(s)
if stmt.else_branch:
for s in stmt.else_branch.statements:
analyze_stmt(s)
elif isinstance(stmt, (WhileStmt, ForStmt)):
if hasattr(stmt, 'condition'):
analyze_expr(stmt.condition)
if stmt.body:
for s in stmt.body.statements:
analyze_stmt(s)
elif isinstance(stmt, ReturnStmt) and stmt.value:
analyze_expr(stmt.value)
if func.body:
for stmt in func.body.statements:
analyze_stmt(stmt)
for param_name, methods in param_methods.items():
unknown_methods = methods - ALL_KNOWN_METHODS
if unknown_methods:
param_types[param_name] = "object"
continue
if methods & ARRAY_ONLY_METHODS:
param_types[param_name] = "array"
elif methods & DICT_ONLY_METHODS:
param_types[param_name] = "dict"
elif methods & ARRAY_METHODS_ALL and not (methods & DICT_METHODS_ALL - {"get", "set", "len"}):
param_types[param_name] = "array"
return param_types
def generate_function(self, func: FunctionDecl):
test_decorator = None
......@@ -357,9 +462,23 @@ class ClassMixin:
self.emit(f"{name} () {{")
self.indent_level += 1
param_types = self._analyze_param_types(func)
old_param_name_map = getattr(self, 'param_name_map', {})
self.param_name_map = {}
for i, param in enumerate(func.params):
ptype = param_types.get(param.name, "scalar")
if param.is_variadic:
self.emit(f'local -a {param.name}=("${{@:{i + 1}}}")')
elif ptype in ("array", "dict"):
nameref_name = f"__ct_{func.name}_{param.name}"
self.emit(f'local -n {nameref_name}="${{{i + 1}}}"')
self.param_name_map[param.name] = nameref_name
if ptype == "array":
self.array_vars.add(nameref_name)
else:
self.dict_vars.add(nameref_name)
else:
if param.default is not None:
default_val = self.generate_expr(param.default)
......@@ -387,6 +506,7 @@ class ClassMixin:
self.deferred_calls = old_deferred
self.in_function = old_in_function
self.local_vars = old_local_vars
self.param_name_map = old_param_name_map
self.indent_level -= 1
self.emit("}")
......
from typing import List, Dict, Optional, Set
from .ast_nodes import *
from .ast_nodes import Program, ClassDecl, FunctionDecl, Assignment, Identifier
from .errors import ErrorCollector
from .stdlib import StdlibMixin
from .awk_codegen import AwkCodegenMixin
......@@ -52,6 +52,7 @@ class CodeGenerator(StdlibMixin, AwkCodegenMixin, ExprMixin, StmtMixin,
self.nameref_vars: Set[str] = set() # vars that are namerefs to arrays/dicts
self.instance_vars: Dict[str, str] = {} # var_name -> class_name
self.class_field_types: Dict[tuple, str] = {}
self.func_param_types: Dict[tuple, str] = {} # (func_name, param_name) -> "array"/"dict"
self.local_vars: Set[str] = set()
self.current_param_positions: Dict[str, int] = {} # param_name -> position (1-based)
......@@ -80,6 +81,21 @@ class CodeGenerator(StdlibMixin, AwkCodegenMixin, ExprMixin, StmtMixin,
def indent(self) -> str:
return " " * self.indent_level
class _IndentContext:
def __init__(self, gen):
self.gen = gen
def __enter__(self):
self.gen.indent_level += 1
return self
def __exit__(self, *_):
self.gen.indent_level -= 1
def indented(self):
"""Context manager for indented code blocks."""
return self._IndentContext(self)
def emit(self, line: str = ""):
if line:
self.output.append(f"{self.indent()}{line}")
......
"""Constants for bash code generation."""
RET_VAR = "__CT_RET"
TMP_PREFIX = "__ct_tmp_"
CLASS_FUNC_PREFIX = "__ct_class_"
LAMBDA_PREFIX = "__ct_lambda_"
OBJ_STORE = "__CT_OBJ"
THIS_INSTANCE = "__ct_this_instance"
ARR_FUNC_PREFIX = "__ct_arr_"
DICT_FUNC_PREFIX = "__ct_dict_"
STR_FUNC_PREFIX = "__ct_str_"
FH_FUNC_PREFIX = "__ct_fh_"
HTTP_FUNC_PREFIX = "__ct_http_"
FS_FUNC_PREFIX = "__ct_fs_"
JSON_FUNC_PREFIX = "__ct_json_"
REGEX_FUNC_PREFIX = "__ct_regex_"
MATH_FUNC_PREFIX = "__ct_math_"
from .ast_nodes import *
from typing import Dict, Any
from .ast_nodes import (
Expression, CallExpr, MemberAccess, ThisExpr, Identifier,
BinaryOp, UnaryOp, BoolLiteral
)
class NodeIdMap:
"""Mapping from AST nodes to values using id() with reference retention."""
def __init__(self):
self._map: Dict[int, Any] = {}
self._refs = []
def set(self, node, value):
self._refs.append(node)
self._map[id(node)] = value
def get(self, node, default=None):
return self._map.get(id(node), default)
def __contains__(self, node):
return id(node) in self._map
def __getitem__(self, node):
return self._map[id(node)]
class CseMixin:
......@@ -39,7 +64,7 @@ class CseMixin:
self.collect_method_calls(condition, calls)
seen = {}
mapping = {}
mapping = NodeIdMap()
regen_code = []
for call in calls:
......@@ -56,7 +81,7 @@ class CseMixin:
self.emit(assign_line)
seen[key] = temp
regen_code.append((call_line, assign_line))
mapping[id(call)] = seen[key]
mapping.set(call, seen[key])
return mapping, regen_code
......@@ -66,7 +91,7 @@ class CseMixin:
self.collect_all_calls(condition, calls)
seen = {}
mapping = {}
mapping = NodeIdMap()
regen_code = []
for call in calls:
......@@ -85,7 +110,7 @@ class CseMixin:
self.emit(assign_line)
seen[key] = temp
regen_code.append((call_line, assign_line))
mapping[id(call)] = seen[key]
mapping.set(call, seen[key])
elif isinstance(call.callee.object, Identifier):
obj_name = call.callee.object.name
......@@ -104,7 +129,7 @@ class CseMixin:
self.emit(call_line)
seen[key] = temp
regen_code.append((call_line, ""))
mapping[id(call)] = seen[key]
mapping.set(call, seen[key])
elif isinstance(call.callee, Identifier):
func_name = call.callee.name
......@@ -120,7 +145,7 @@ class CseMixin:
self.emit(assign_line)
seen[key] = temp
regen_code.append((call_line, assign_line))
mapping[id(call)] = seen[key]
mapping.set(call, seen[key])
return mapping, regen_code
......@@ -170,19 +195,19 @@ class CseMixin:
if isinstance(expr, BoolLiteral):
return "true" if expr.value else "false"
if isinstance(expr, CallExpr) and id(expr) in mapping:
return f'[[ "${{{mapping[id(expr)]}}}" == "true" ]]'
if isinstance(expr, CallExpr) and expr in mapping:
return f'[[ "${{{mapping[expr]}}}" == "true" ]]'
return self.generate_condition(expr)
def generate_expr_with_precompute(self, expr: Expression, mapping: dict) -> str:
def generate_expr_with_precompute(self, expr: Expression, mapping: NodeIdMap) -> str:
"""Generate expression using pre-computed values."""
if isinstance(expr, CallExpr) and id(expr) in mapping:
return f'${mapping[id(expr)]}'
if isinstance(expr, CallExpr) and expr in mapping:
return f'${mapping[expr]}'
if isinstance(expr, MemberAccess):
if isinstance(expr.object, CallExpr) and id(expr.object) in mapping:
temp = mapping[id(expr.object)]
if isinstance(expr.object, CallExpr) and expr.object in mapping:
temp = mapping[expr.object]
return f'${{__CT_OBJ["${temp}.{expr.member}"]}}'
return self.generate_expr(expr)
......
"""Dead Code Elimination."""
from .ast_nodes import *
import re
from .ast_nodes import (
ClassDecl, NewExpr, CallExpr, Identifier, FunctionDecl, Assignment,
ExpressionStmt, IfStmt, ForStmt, ForeachStmt, WhileStmt, WhenStmt,
WhenBranch, TryStmt, ThrowStmt, DeferStmt, ReturnStmt, ArrayLiteral,
DictLiteral, IndexAccess, Lambda, MemberAccess, ThisExpr, Block,
BinaryOp, UnaryOp, WithStmt, Program, StringLiteral
)
class UsageAnalyzer:
......@@ -11,19 +19,26 @@ class UsageAnalyzer:
'args', 'misc',
}
ARRAY_RETURNING_METHODS = {'keys', 'split', 'slice'}
DICT_RETURNING_METHODS = set()
def __init__(self):
self.used: set = set()
self.has_classes = False
self.has_awk = False
self.test_mode = False
self.defined_classes: dict = {}
self.defined_functions: dict = {}
self.used_classes: set = set()
self.used_methods: dict = {}
self.class_fields: dict = {}
self.variable_types: dict = {}
self.array_variables: set = set()
self.dict_variables: set = set()
self.current_class_name: str = None
self.current_method_name: str = None
self.method_calls: dict = {}
self.func_param_types: dict = {}
def analyze(self, programs: list, test_mode: bool = False) -> set:
self.used = {'core'}
......@@ -34,6 +49,8 @@ class UsageAnalyzer:
if isinstance(stmt, ClassDecl):
self.defined_classes[stmt.name] = stmt
self._collect_class_fields(stmt)
elif isinstance(stmt, FunctionDecl):
self.defined_functions[stmt.name] = stmt
for program in programs:
for stmt in program.statements:
......@@ -140,7 +157,10 @@ class UsageAnalyzer:
for dec in stmt.decorators:
if dec.name == 'awk':
self.has_awk = True
old_func_name = getattr(self, 'current_func_name', None)
self.current_func_name = stmt.name
self._analyze_body(stmt.body)
self.current_func_name = old_func_name
elif isinstance(stmt, Assignment):
self._analyze_expr(stmt.value)
......@@ -148,10 +168,21 @@ class UsageAnalyzer:
var_name = stmt.target.name
if isinstance(stmt.value, NewExpr):
self.variable_types[var_name] = stmt.value.class_name
elif isinstance(stmt.value, CallExpr) and isinstance(stmt.value.callee, Identifier):
callee_name = stmt.value.callee.name
if callee_name in self.defined_classes:
self.variable_types[var_name] = callee_name
elif isinstance(stmt.value, CallExpr):
if isinstance(stmt.value.callee, Identifier):
callee_name = stmt.value.callee.name
if callee_name in self.defined_classes:
self.variable_types[var_name] = callee_name
elif isinstance(stmt.value.callee, MemberAccess):
method = stmt.value.callee.member
if method in self.ARRAY_RETURNING_METHODS:
self.array_variables.add(var_name)
elif method in self.DICT_RETURNING_METHODS:
self.dict_variables.add(var_name)
elif isinstance(stmt.value, ArrayLiteral):
self.array_variables.add(var_name)
elif isinstance(stmt.value, DictLiteral):
self.dict_variables.add(var_name)
elif isinstance(stmt, ExpressionStmt):
self._analyze_expr(stmt.expression)
......@@ -262,6 +293,36 @@ class UsageAnalyzer:
elif isinstance(expr, Identifier):
pass
elif isinstance(expr, StringLiteral):
if getattr(expr, 'has_interpolation', False):
self._analyze_string_interpolation(expr.value)
def _analyze_string_interpolation(self, value: str):
"""Analyze method calls in string interpolation like {var.method()}."""
pattern = r'\{(\w+)\.(\w+)\s*\([^)]*\)\}'
for match in re.finditer(pattern, value):
var_name = match.group(1)
method = match.group(2)
if var_name in self.variable_types:
obj_class = self.variable_types[var_name]
if obj_class not in self.used_methods:
self.used_methods[obj_class] = set()
self.used_methods[obj_class].add(method)
elif hasattr(self, 'current_func_name') and self.current_func_name:
key = (self.current_func_name, var_name)
if key in self.func_param_types:
for obj_class in self.func_param_types[key]:
if obj_class not in self.used_methods:
self.used_methods[obj_class] = set()
self.used_methods[obj_class].add(method)
else:
for cls_name, cls_decl in self.defined_classes.items():
for m in cls_decl.methods:
if m.name == method:
if cls_name not in self.used_methods:
self.used_methods[cls_name] = set()
self.used_methods[cls_name].add(method)
def _analyze_call(self, expr: CallExpr):
callee = expr.callee
......@@ -274,6 +335,8 @@ class UsageAnalyzer:
self.used.add('array')
elif callee.name in ('random', 'random_range'):
self.used.add('misc')
elif callee.name in self.defined_functions:
self._analyze_function_call_with_types(callee.name, expr.arguments)
if isinstance(callee, MemberAccess):
if isinstance(callee.object, ThisExpr):
......@@ -315,6 +378,19 @@ class UsageAnalyzer:
if obj_class not in self.used_methods:
self.used_methods[obj_class] = set()
self.used_methods[obj_class].add(method)
elif ns in self.array_variables:
self.used.add('array')
elif ns in self.dict_variables:
self.used.add('dict')
elif hasattr(self, 'current_func_name') and self.current_func_name:
key = (self.current_func_name, ns)
if key in self.func_param_types:
for obj_class in self.func_param_types[key]:
if obj_class not in self.used_methods:
self.used_methods[obj_class] = set()
self.used_methods[obj_class].add(method)
else:
self._check_method(method)
elif ns == 'http':
self.used.add('http')
elif ns == 'fs':
......@@ -347,6 +423,46 @@ class UsageAnalyzer:
if not found_in_class:
self._check_method(method)
def _analyze_function_call_with_types(self, func_name: str, arguments: list):
"""Analyze function call and propagate object types to parameters."""
func_decl = self.defined_functions.get(func_name)
if not func_decl:
return
new_types_added = False
for i, arg in enumerate(arguments):
if i >= len(func_decl.params):
break
param_name = func_decl.params[i].name
arg_type = None
if isinstance(arg, Identifier):
arg_type = self.variable_types.get(arg.name)
if not arg_type and hasattr(self, 'current_func_name') and self.current_func_name:
key = (self.current_func_name, arg.name)
if key in self.func_param_types:
for t in self.func_param_types[key]:
arg_type = t
break
elif isinstance(arg, NewExpr):
arg_type = arg.class_name
elif isinstance(arg, CallExpr) and isinstance(arg.callee, Identifier):
if arg.callee.name in self.defined_classes:
arg_type = arg.callee.name
if arg_type and arg_type in self.defined_classes:
key = (func_name, param_name)
if key not in self.func_param_types:
self.func_param_types[key] = set()
if arg_type not in self.func_param_types[key]:
self.func_param_types[key].add(arg_type)
new_types_added = True
if new_types_added:
old_func = getattr(self, 'current_func_name', None)
self.current_func_name = func_name
self._analyze_body(func_decl.body)
self.current_func_name = old_func
def _analyze_member_access(self, expr: MemberAccess):
if isinstance(expr.object, Identifier):
pass
......
import re
from typing import List
from .ast_nodes import *
from .ast_nodes import Decorator, Parameter
class DecoratorMixin:
......
import re
from .ast_nodes import *
from .ast_nodes import (
Expression, IntegerLiteral, FloatLiteral, StringLiteral, BoolLiteral,
NilLiteral, Identifier, ThisExpr, ArrayLiteral, DictLiteral, BinaryOp,
UnaryOp, CallExpr, MemberAccess, IndexAccess, Lambda, NewExpr, BaseCall,
Block, ReturnStmt
)
class ExprMixin:
......@@ -22,7 +27,11 @@ class ExprMixin:
return ""
if isinstance(expr, Identifier):
return f"${{{expr.name}}}"
name = expr.name
param_map = getattr(self, 'param_name_map', {})
if name in param_map:
name = param_map[name]
return f"${{{name}}}"
if isinstance(expr, ThisExpr):
return "$this"
......
......@@ -176,7 +176,7 @@ def cmd_run (args):
finally:
try:
os.unlink (temp_path)
except:
except OSError:
pass
except Exception as e:
......@@ -238,7 +238,7 @@ def cmd_test (args):
finally:
try:
os.unlink (temp_path)
except:
except OSError:
pass
except Exception as e:
......
"""Unified method registry for bash and AWK code generation.
This module provides a single source of truth for all builtin methods,
ensuring consistency between bash and AWK code generators.
"""
from dataclasses import dataclass
from typing import Optional, Callable, List
@dataclass
class MethodDef:
"""Definition of a builtin method."""
name: str
min_args: int = 0
max_args: Optional[int] = None
bash_func: Optional[str] = None
awk_gen: Optional[Callable[[str, List[str]], str]] = None
returns_array: bool = False
STRING_METHODS = {
"len": MethodDef("len", 0, 0, "__ct_str_len",
lambda obj, args: f"length({obj})"),
"upper": MethodDef("upper", 0, 0, "__ct_str_upper",
lambda obj, args: f"toupper({obj})"),
"lower": MethodDef("lower", 0, 0, "__ct_str_lower",
lambda obj, args: f"tolower({obj})"),
"trim": MethodDef("trim", 0, 0, "__ct_str_trim",
lambda obj, args: f'(gsub(/^[ \\t]+|[ \\t]+$/, "", {obj}) ? {obj} : {obj})'),
"contains": MethodDef("contains", 1, 1, "__ct_str_contains",
lambda obj, args: f"(index({obj}, {args[0]}) > 0)"),
"starts": MethodDef("starts", 1, 1, "__ct_str_starts",
lambda obj, args: f"(substr({obj}, 1, length({args[0]})) == {args[0]})"),
"ends": MethodDef("ends", 1, 1, "__ct_str_ends",
lambda obj, args: f"(substr({obj}, length({obj}) - length({args[0]}) + 1) == {args[0]})"),
"index": MethodDef("index", 1, 1, "__ct_str_index",
lambda obj, args: f"(index({obj}, {args[0]}) - 1)"),
"replace": MethodDef("replace", 2, 2, "__ct_str_replace",
lambda obj, args: f"(gsub({args[0]}, {args[1]}, {obj}) ? {obj} : {obj})"),
"substr": MethodDef("substr", 2, 2, "__ct_str_substr",
lambda obj, args: f"substr({obj}, {args[0]} + 1, {args[1]})"),
"split": MethodDef("split", 1, 1, "__ct_str_split",
lambda obj, args: f"split({obj}, __split_arr, {args[0]})",
returns_array=True),
"charAt": MethodDef("charAt", 1, 1, "__ct_str_char_at",
lambda obj, args: f"substr({obj}, {args[0]} + 1, 1)"),
"urlencode": MethodDef("urlencode", 0, 0, "__ct_str_urlencode", None),
}
ARRAY_METHODS = {
"len": MethodDef("len", 0, 0, "__ct_arr_len",
lambda obj, args: f"length({obj})"),
"push": MethodDef("push", 1, 1, "__ct_arr_push",
lambda obj, args: f"{obj}[length({obj}) + 1] = {args[0]}"),
"pop": MethodDef("pop", 0, 0, "__ct_arr_pop",
lambda obj, args: f"delete {obj}[length({obj})]"),
"shift": MethodDef("shift", 0, 0, "__ct_arr_shift",
lambda obj, args: f"delete {obj}[1]"),
"join": MethodDef("join", 1, 1, "__ct_arr_join",
lambda obj, args: f"__ct_awk_join({obj}, {args[0]})"),
"get": MethodDef("get", 1, 1, "__ct_arr_get",
lambda obj, args: f"{obj}[{args[0]}]"),
"set": MethodDef("set", 2, 2, "__ct_arr_set",
lambda obj, args: f"{obj}[{args[0]}] = {args[1]}"),
"slice": MethodDef("slice", 2, 2, "__ct_arr_slice", None, returns_array=True),
"map": MethodDef("map", 1, 1, "__ct_arr_map", None, returns_array=True),
"filter": MethodDef("filter", 1, 1, "__ct_arr_filter", None, returns_array=True),
}
DICT_METHODS = {
"get": MethodDef("get", 1, 1, "__ct_dict_get",
lambda obj, args: f"{obj}[{args[0]}]"),
"set": MethodDef("set", 2, 2, "__ct_dict_set",
lambda obj, args: f"{obj}[{args[0]}] = {args[1]}"),
"has": MethodDef("has", 1, 1, "__ct_dict_has",
lambda obj, args: f"({args[0]} in {obj})"),
"del": MethodDef("del", 1, 1, "__ct_dict_del",
lambda obj, args: f"delete {obj}[{args[0]}]"),
"keys": MethodDef("keys", 0, 0, "__ct_dict_keys", None, returns_array=True),
}
FILE_HANDLE_METHODS = {
"read": MethodDef("read", 0, 0, "__ct_fh_read", None),
"readline": MethodDef("readline", 0, 0, "__ct_fh_readline", None),
"write": MethodDef("write", 1, 1, "__ct_fh_write", None),
"writeln": MethodDef("writeln", 1, 1, "__ct_fh_writeln", None),
"close": MethodDef("close", 0, 0, "__ct_fh_close", None),
}
NAMESPACE_METHODS = {
"fs": {"read", "write", "append", "exists", "remove", "mkdir", "list", "open"},
"http": {"get", "post", "put", "delete"},
"json": {"parse", "stringify", "get"},
"logger": {"info", "warn", "error", "debug"},
"regex": {"match", "extract"},
"args": {"count", "get"},
"shell": {"exec", "capture", "source"},
"time": {"now", "ms"},
"math": {"add", "sub", "mul", "div", "mod", "min", "max", "abs"},
}
BUILTIN_NAMESPACES = set(NAMESPACE_METHODS.keys())
BUILTIN_FUNCS = {"print", "exit", "len", "range", "ngrep", "is_number",
"is_empty", "chr", "ord", "assert", "assert_eq", "random", "random_range"}
def get_method_names(type_name: str) -> set:
"""Get all available method names for a type."""
if type_name == "string":
return set(STRING_METHODS.keys())
elif type_name == "array":
return set(ARRAY_METHODS.keys())
elif type_name == "dict":
return set(DICT_METHODS.keys())
elif type_name == "file_handle":
return set(FILE_HANDLE_METHODS.keys())
return set()
def get_method_def(type_name: str, method_name: str) -> Optional[MethodDef]:
"""Get method definition by type and name."""
methods = {
"string": STRING_METHODS,
"array": ARRAY_METHODS,
"dict": DICT_METHODS,
"file_handle": FILE_HANDLE_METHODS,
}
return methods.get(type_name, {}).get(method_name)
def get_bash_func(type_name: str, method_name: str) -> Optional[str]:
"""Get bash function name for a method."""
method = get_method_def(type_name, method_name)
return method.bash_func if method else None
def generate_awk(type_name: str, method_name: str, obj: str, args: List[str]) -> Optional[str]:
"""Generate AWK code for a method call."""
method = get_method_def(type_name, method_name)
if method and method.awk_gen:
return method.awk_gen(obj, args)
return None
from typing import List, Optional, Callable
from typing import List, Optional, Callable, Union
from .tokens import Token, TokenType
from .ast_nodes import *
from .ast_nodes import (
SourceLocation, Program, Declaration, Statement, Decorator, FunctionDecl,
Parameter, ClassDecl, ConstructorDecl, ImportStmt, Block, ReturnStmt,
BreakStmt, ContinueStmt, IfStmt, WhileStmt, ForStmt, ForeachStmt, WithStmt,
TryStmt, ThrowStmt, DeferStmt, WhenStmt, WhenBranch, RangePattern,
ExpressionStmt, Assignment, IntegerLiteral, FloatLiteral, StringLiteral,
BoolLiteral, NilLiteral, ThisExpr, ArrayLiteral, DictLiteral, Identifier,
BinaryOp, UnaryOp, CallExpr, MemberAccess, IndexAccess, NewExpr, Lambda,
BaseCall, Expression
)
from .errors import CompileError, ErrorCollector
......@@ -741,7 +750,7 @@ class Parser:
return self.parse_lambda_body (params, loc)
self.pos = saved_pos
except:
except Exception:
self.pos = saved_pos
expr = self.parse_expression ()
......
......@@ -458,7 +458,7 @@ class StdlibMixin:
self.emit ("__ct_str_starts () { [[ \"$1\" == \"$2\"* ]] && __CT_RET=true || __CT_RET=false; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_ends () { [[ \"$1\" == *\"$2\" ]] && __CT_RET=true || __CT_RET=false; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_replace () { __CT_RET=\"${1//\"$2\"/\"$3\"}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_split () { local IFS=\"$2\"; read -ra __arr <<< \"$1\"; printf '%s\\n' \"${__arr[@]}\"; }")
self.emit ("__ct_str_split () { local IFS=\"$2\"; read -ra __CT_RET_ARR <<< \"$1\"; }")
self.emit ("__ct_str_trim () { local s=\"$1\"; s=\"${s#\"${s%%[![:space:]]*}\"}\"; __CT_RET=\"${s%\"${s##*[![:space:]]}\"}\" ; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_upper () { __CT_RET=\"${1^^}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_lower () { __CT_RET=\"${1,,}\"; echo \"$__CT_RET\"; }")
......@@ -482,7 +482,7 @@ class StdlibMixin:
self.emit ("__ct_arr_len () { local -n __a=$1; __CT_RET=${#__a[@]}; echo \"$__CT_RET\"; }")
self.emit ("__ct_arr_get () { local -n __a=$1; __CT_RET=\"${__a[$2]}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_arr_set () { local -n __a=$1; __a[$2]=\"$3\"; }")
self.emit ("__ct_arr_slice () { local -n __a=$1; local -a __r=(\"${__a[@]:$2:$3}\"); printf '%s\\n' \"${__r[@]}\"; }")
self.emit ("__ct_arr_slice () { local -n __a=$1; __CT_RET_ARR=(\"${__a[@]:$2:$3}\"); }")
self.emit ()
self.emit ("# Array map/filter with lambda functions")
......@@ -628,7 +628,8 @@ class StdlibMixin:
self.emit ('__ct_dict_get () { local -n __d="$1"; __CT_RET="${__d[$2]}"; echo "$__CT_RET"; }')
self.emit ('__ct_dict_has () { local -n __d="$1"; [[ -v "__d[$2]" ]] && __CT_RET=true || __CT_RET=false; echo "$__CT_RET"; }')
self.emit ('__ct_dict_del () { local -n __d="$1"; unset "__d[$2]"; }')
self.emit ('__ct_dict_keys () { local -n __d="$1"; printf \'%s\\n\' "${!__d[@]}"; }')
self.emit ('__ct_dict_keys () { local -n __d="$1"; __CT_RET_ARR=("${!__d[@]}"); }')
self.emit ('__ct_dict_len () { local -n __d="$1"; __CT_RET=${#__d[@]}; echo "$__CT_RET"; }')
self.emit ()
def _emit_misc (self):
......
from .ast_nodes import *
from .ast_nodes import (
FunctionDecl, ClassDecl, ImportStmt, Assignment, ExpressionStmt, IfStmt,
WhileStmt, ForStmt, ForeachStmt, WithStmt, TryStmt, ThrowStmt, DeferStmt,
WhenStmt, RangePattern, ReturnStmt, BreakStmt, ContinueStmt, Block,
CallExpr, Identifier, MemberAccess, ThisExpr, StringLiteral, NewExpr,
BinaryOp, DictLiteral, ArrayLiteral, WhenBranch
)
class StmtMixin:
......@@ -250,6 +256,8 @@ class StmtMixin:
if isinstance(stmt.iterable, Identifier):
arr_name = stmt.iterable.name
param_map = getattr(self, 'param_name_map', {})
arr_name = param_map.get(arr_name, arr_name)
if len(stmt.variables) == 1:
var = stmt.variables[0]
self.emit(f'for {var} in "${{{arr_name}[@]}}"; do')
......@@ -293,6 +301,27 @@ class StmtMixin:
self.emit("done")
return
if stmt.iterable.callee.member == "split" and len(stmt.iterable.arguments) == 1:
str_expr = self.generate_expr(stmt.iterable.callee.object)
delim_arg = self.generate_expr(stmt.iterable.arguments[0])
var = stmt.variables[0]
self.emit(f'__ct_str_split "{str_expr}" "{delim_arg}"')
if len(stmt.variables) == 1:
self.emit(f'for {var} in "${{__CT_RET_ARR[@]}}"; do')
else:
idx_var = stmt.variables[0]
val_var = stmt.variables[1]
self.emit(f'{idx_var}=0')
self.emit(f'for {val_var} in "${{__CT_RET_ARR[@]}}"; do')
self.indent_level += 1
for s in stmt.body.statements:
self.generate_statement(s)
if len(stmt.variables) == 2:
self.emit(f'((++{stmt.variables[0]}))')
self.indent_level -= 1
self.emit("done")
return
iterable = self.generate_expr(stmt.iterable)
var = stmt.variables[0]
self.emit(f'for {var} in {iterable}; do')
......@@ -569,26 +598,29 @@ class StmtMixin:
args_str = " ".join([f'"{a}"' for a in args])
if obj_name in self.array_vars and method in arr_methods:
param_map = getattr(self, 'param_name_map', {})
mapped_name = param_map.get(obj_name, obj_name)
if mapped_name in self.array_vars and method in arr_methods:
func_name = arr_methods[method]
self.emit(f'{func_name} "{obj_name}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit(f'{func_name} "{mapped_name}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return True
elif obj_name in self.dict_vars and method in dict_methods:
elif mapped_name in self.dict_vars and method in dict_methods:
func_name = dict_methods[method]
self.emit(f'{func_name} "{obj_name}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit(f'{func_name} "{mapped_name}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return True
elif method in str_methods:
func_name = str_methods[method]
self.emit(f'{func_name} "${{{obj_name}}}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit(f'{func_name} "${{{mapped_name}}}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return True
self.emit(f'__ct_call_method "${{{obj_name}}}" "{method}" {args_str} >/dev/null')
self.emit(f'__ct_call_method "${{{mapped_name}}}" "{method}" {args_str} >/dev/null')
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return True
......@@ -1077,3 +1077,330 @@ x = arr.foo()
assert "Available:" in stdout
assert "push" in stdout
assert "pop" in stdout
class TestFunctionParameterPassing:
"""Tests for passing arrays and dicts to user-defined functions."""
def test_array_modification_in_function(self):
code, stdout, _ = run_ct('''
func modify_arr(arr) {
arr.push(99)
arr.push(100)
}
nums = [1, 2, 3]
print("Before: {nums.len()}")
modify_arr(nums)
print("After: {nums.len()}")
''')
assert code == 0
assert "Before: 3" in stdout
assert "After: 5" in stdout
def test_array_return_length_from_function(self):
code, stdout, _ = run_ct('''
func get_len(arr) {
return arr.len()
}
nums = [1, 2, 3, 4, 5]
len = get_len(nums)
print("Length: {len}")
''')
assert code == 0
assert "Length: 5" in stdout
def test_dict_modification_in_function(self):
code, stdout, _ = run_ct('''
func add_key(d, key, value) {
d.set(key, value)
}
config = {"host": "localhost"}
add_key(config, "port", "8080")
port = config.get("port")
print("Port: {port}")
''')
assert code == 0
assert "Port: 8080" in stdout
def test_dict_has_in_function(self):
code, stdout, _ = run_ct('''
func has_key(d, key) {
return d.has(key)
}
config = {"host": "localhost", "port": "8080"}
has_host = has_key(config, "host")
has_debug = has_key(config, "debug")
print("Has host: {has_host}")
print("Has debug: {has_debug}")
''')
assert code == 0
assert "Has host: true" in stdout
assert "Has debug: false" in stdout
def test_array_get_in_function(self):
code, stdout, _ = run_ct('''
func first_element(arr) {
return arr.get(0)
}
nums = [42, 100, 200]
first = first_element(nums)
print("First: {first}")
''')
assert code == 0
assert "First: 42" in stdout
def test_array_slice_in_function(self):
code, stdout, _ = run_ct('''
func double_all(arr) {
for i in range(arr.len()) {
val = arr.get(i)
arr.set(i, val * 2)
}
}
nums = [1, 2, 3]
double_all(nums)
print("{nums.get(0)} {nums.get(1)} {nums.get(2)}")
''')
assert code == 0
assert "2 4 6" in stdout
def test_nested_function_array_passing(self):
code, stdout, _ = run_ct('''
func inner(arr) {
arr.push("inner")
}
func outer(arr) {
arr.push("outer")
inner(arr)
}
items = []
outer(items)
print("Items: {items.len()}")
''')
assert code == 0
assert "Items: 2" in stdout
def test_dict_keys_in_function(self):
code, stdout, _ = run_ct('''
func count_keys(d) {
keys = d.keys()
return keys.len()
}
config = {"a": "1", "b": "2", "c": "3"}
count = count_keys(config)
print("Key count: {count}")
''')
assert code == 0
assert "Key count: 3" in stdout
def test_mixed_parameters_scalar_and_array(self):
code, stdout, _ = run_ct('''
func add_items(prefix, arr, count) {
for i in range(count) {
arr.push("{prefix}_{i}")
}
}
items = []
add_items("item", items, 3)
print("Count: {items.len()}")
''')
assert code == 0
assert "Count: 3" in stdout
def test_dict_get_in_function(self):
code, stdout, _ = run_ct('''
func get_value(d, key) {
if d.has(key) {
return d.get(key)
}
return ""
}
data = {"name": "Alice", "age": "30"}
name = get_value(data, "name")
print("Name: {name}")
''')
assert code == 0
assert "Name: Alice" in stdout
def test_array_foreach_in_function(self):
code, stdout, _ = run_ct('''
func sum_array(arr) {
total = 0
foreach n in arr {
total += n
}
return total
}
nums = [1, 2, 3, 4, 5]
result = sum_array(nums)
print("Sum: {result}")
''')
assert code == 0
assert "Sum: 15" in stdout
def test_string_split_returns_array(self):
code, stdout, _ = run_ct('''
func count_words(text) {
words = text.split(" ")
return words.len()
}
sentence = "hello world foo bar"
count = count_words(sentence)
print("Word count: {count}")
''')
assert code == 0
assert "Word count: 4" in stdout
def test_array_slice_returns_array(self):
code, stdout, _ = run_ct('''
func first_two(arr) {
sub = arr.slice(0, 2)
return sub.len()
}
nums = [1, 2, 3, 4, 5]
count = first_two(nums)
print("Slice len: {count}")
''')
assert code == 0
assert "Slice len: 2" in stdout
class TestClassInstancePassing:
"""Tests for passing class instances to functions and methods."""
def test_class_to_function_basic(self):
code, stdout, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func increment() {
this.count += 1
}
func get() {
return this.count
}
}
func add_ten(c) {
for i in range(10) {
c.increment()
}
}
counter = new Counter(5)
add_ten(counter)
print("Result: {counter.get()}")
''')
assert code == 0
assert "Result: 15" in stdout
def test_class_to_function_nested(self):
code, stdout, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func increment() {
this.count += 1
}
func get() {
return this.count
}
}
func inner(c) {
c.increment()
return c.get()
}
func outer(c) {
c.increment()
return inner(c)
}
counter = new Counter(0)
result = outer(counter)
print("Result: {result}")
''')
assert code == 0
assert "Result: 2" in stdout
def test_class_to_method(self):
code, stdout, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func get() {
return this.count
}
func add(amount) {
this.count += amount
}
}
class Calculator {
func double_counter(c) {
val = c.get()
c.add(val)
}
}
counter = new Counter(5)
calc = new Calculator()
calc.double_counter(counter)
print("Result: {counter.get()}")
''')
assert code == 0
assert "Result: 10" in stdout
def test_multiple_class_params_to_method(self):
code, stdout, _ = run_ct('''
class Counter {
count = 0
construct(initial) {
this.count = initial
}
func get() {
return this.count
}
func add(amount) {
this.count += amount
}
}
class Calculator {
func process(c1, c2) {
v1 = c1.get()
v2 = c2.get()
c1.add(v2)
c2.add(v1)
}
}
a = new Counter(5)
b = new Counter(10)
calc = new Calculator()
calc.process(a, b)
print("a={a.get()}, b={b.get()}")
''')
assert code == 0
assert "a=15, b=15" in stdout
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment