Commit a9fe7b9d authored by Roman Alifanov

Fix critical codegen bugs for complex expressions

- String interpolation with operators: {a == b} now generates a proper bash comparison instead of the invalid ${a == b}.
- Double method call in conditions: CSE precompute now correctly reuses the temp variable for a CallExpr to avoid calling the method twice.
- arr.push() with a method call: detect side effects and call the method separately to avoid subshell state isolation.
- charAt() with newline: use a printf X marker to preserve trailing newlines.
- this.field.method() and obj.field.method(): handle nested MemberAccess in return statements and assignments.
- Namespace vs. variable collision: check whether an identifier is a known variable before treating it as a namespace.
- return func() with arrays: use _generate_call_arg for a proper nameref.
parent dbd5bb5d
......@@ -88,7 +88,7 @@ class ClassMixin:
else:
self.emit(f'declare -ga "${{__ct_this_instance}}_{field_name}=()"')
elif isinstance(default_value, DictLiteral):
self.emit(f'declare -gA "${{__ct_this_instance}}_{field_name}=()"')
self.emit(f'eval "declare -gA ${{__ct_this_instance}}_{field_name}=()"')
self.emit(f'__CT_OBJ["$__ct_this_instance.{field_name}"]="${{__ct_this_instance}}_{field_name}"')
elif default_value:
val = self.generate_expr(default_value)
......@@ -121,10 +121,19 @@ class ClassMixin:
self.in_class_method = True
old_in_function = self.in_function
old_local_vars = self.local_vars.copy()
old_param_positions = self.current_param_positions.copy()
self.in_function = True
self.local_vars = set()
self.current_param_positions = {}
for i, param in enumerate(cls.constructor.params):
self.current_param_positions[param.name] = i + 1
array_field_params = self._find_array_field_params(cls)
for i, param in enumerate(cls.constructor.params):
if param.name in array_field_params:
continue
if param.default is not None:
default_val = self.generate_expr(param.default)
self.emit(f'local {param.name}="${{{i + 1}:-{default_val}}}"')
......@@ -138,10 +147,25 @@ class ClassMixin:
self.in_class_method = False
self.in_function = old_in_function
self.local_vars = old_local_vars
self.current_param_positions = old_param_positions
self.indent_level -= 1
self.emit("}")
self.emit()
def _find_array_field_params(self, cls: ClassDecl) -> set:
    """Collect identifiers that the constructor assigns to array fields.

    Walks the top-level statements of the constructor body looking for
    assignments shaped like ``this.<field> = <identifier>`` where
    ``self.class_field_types`` records ``(cls.name, <field>)`` as ``"array"``.

    Returns:
        The set of right-hand-side identifier names from those assignments
        (they are not cross-checked against the parameter list here).
        Empty when the class has no constructor.
    """
    array_params = set()
    if not cls.constructor:
        return array_params

    for statement in cls.constructor.body.statements:
        # Only plain assignments can bind a value to a field.
        if not isinstance(statement, Assignment):
            continue
        target = statement.target
        # The target must be a direct member of ``this``.
        if not (isinstance(target, MemberAccess) and isinstance(target.object, ThisExpr)):
            continue
        declared_type = self.class_field_types.get((cls.name, target.member))
        if declared_type == "array" and isinstance(statement.value, Identifier):
            array_params.add(statement.value.name)
    return array_params
def _generate_plain_method(self, cls: ClassDecl, method: FunctionDecl):
"""Generate a plain class method."""
self.emit(f"__ct_class_{cls.name}_{method.name} () {{")
......
......@@ -49,9 +49,12 @@ class CodeGenerator(StdlibMixin, AwkCodegenMixin, ExprMixin, StmtMixin,
self.dict_vars: Set[str] = set()
self.object_vars: Set[str] = set()
self.file_handle_vars: Set[str] = set()
self.nameref_vars: Set[str] = set() # vars that are namerefs to arrays/dicts
self.instance_vars: Dict[str, str] = {} # var_name -> class_name
self.class_field_types: Dict[tuple, str] = {}
self.local_vars: Set[str] = set()
self.current_param_positions: Dict[str, int] = {} # param_name -> position (1-based)
self.global_vars: Set[str] = {
'L_SRC', 'L_POS', 'L_LEN', 'L_LINE', 'L_COL', 'L_FILE',
'T_TYPES', 'T_VALUES', 'T_LINES', 'T_COUNT',
......
......@@ -170,6 +170,9 @@ class CseMixin:
if isinstance(expr, BoolLiteral):
return "true" if expr.value else "false"
if isinstance(expr, CallExpr) and id(expr) in mapping:
return f'[[ "${{{mapping[id(expr)]}}}" == "true" ]]'
return self.generate_condition(expr)
def generate_expr_with_precompute(self, expr: Expression, mapping: dict) -> str:
......
......@@ -110,6 +110,11 @@ class ExprMixin:
return match.group(0)
if '(' in content:
return self._handle_interpolation_call(content)
comparison_result = self._handle_interpolation_operator(content)
if comparison_result:
return comparison_result
if '.' in content:
parts = content.split('.', 1)
if parts[0] == 'this':
......@@ -132,11 +137,81 @@ class ExprMixin:
return value
def _handle_interpolation_operator(self, content: str) -> "str | None":
    """Translate an operator expression inside a string interpolation to bash.

    Comparisons (``==``, ``!=``, ``<=``, ``>=``, ``<``, ``>``) are split at the
    first occurrence of the operator and become a
    ``$([[ ... ]] && echo true || echo false)`` command substitution; the
    equality forms quote both operands, the numeric forms do not.

    Arithmetic (``+``, ``-``, ``*``, ``/``, ``%``) becomes one ``$(( ... ))``
    expansion. The whole operator chain is tokenized, so ``a - b + c`` and
    ``a * -1`` each convert every operand individually instead of pushing a
    sub-expression into a single ``${...}``.

    Args:
        content: Raw text found between the interpolation braces.

    Returns:
        The bash snippet, or ``None`` when ``content`` contains no operator
        expression this method can translate (callers test the result for
        truthiness).
    """
    # Two-character operators are listed before "<" and ">" so that "<=" and
    # ">=" are not split at their first character.
    comparison_ops = (
        ('==', '=='), ('!=', '!='), ('<=', '-le'),
        ('>=', '-ge'), ('<', '-lt'), ('>', '-gt'),
    )
    for ct_op, bash_op in comparison_ops:
        left, found, right = content.partition(ct_op)
        if not found:
            continue
        left_bash = self._interpolation_operand_to_bash(left.strip())
        right_bash = self._interpolation_operand_to_bash(right.strip())
        if ct_op in ('==', '!='):
            return f'$([[ "{left_bash}" {bash_op} "{right_bash}" ]] && echo true || echo false)'
        return f'$([[ {left_bash} {bash_op} {right_bash} ]] && echo true || echo false)'

    arith_ops = '+-*/%'
    pieces = []          # rendered operands and binary operators, alternating
    operand = ''         # raw text of the operand currently being collected
    negate = False       # current operand is preceded by a unary minus
    in_quotes = False    # operators inside a "..." literal are plain text
    binary_seen = False

    def render(text: str, negative: bool) -> str:
        rendered = self._interpolation_operand_to_bash(text)
        return f'-{rendered}' if negative else rendered

    for ch in content:
        if ch == '"':
            in_quotes = not in_quotes
        if in_quotes or ch not in arith_ops:
            operand += ch
            continue
        text = operand.strip()
        if text:
            # Operator following an operand: it is binary.
            pieces.append(render(text, negate))
            pieces.append(ch)
            operand, negate, binary_seen = '', False, True
        elif ch == '-' and not negate:
            # "-" with nothing before it (start, or right after another
            # operator) is a sign on the upcoming operand.
            negate = True
        else:
            # Any other operator without a left operand is malformed.
            return None

    text = operand.strip()
    if not text or not binary_seen:
        # Trailing operator, or no binary operator at all (e.g. "-5", "name").
        return None
    pieces.append(render(text, negate))
    return f'$(({" ".join(pieces)}))'
def _interpolation_operand_to_bash(self, operand: str) -> str:
    """Convert one operand of an interpolated expression to bash syntax.

    Handled forms, checked in this order:

    * ``"text"``      -> ``text`` (surrounding double quotes removed)
    * ``42``, ``-7``, ``3.14``, ``-2.5`` -> emitted unchanged
    * ``this.field``  -> ``${__CT_OBJ["$this.field"]}``
    * ``obj.field``   -> ``${__CT_OBJ["${obj}.field"]:-}``
    * anything else   -> ``${operand}``

    Args:
        operand: Operand text; surrounding whitespace is ignored.

    Returns:
        The bash text to splice into the generated expression.
    """
    operand = operand.strip()
    if operand.startswith('"') and operand.endswith('"'):
        return operand[1:-1]

    # Numeric literals pass through untouched. Decimal literals must be caught
    # here: they contain a "." and would otherwise be read as a member access
    # on a positional parameter (e.g. "3.14" -> field "14" of ${3}).
    unsigned = operand[1:] if operand.startswith('-') else operand
    whole, dot, fraction = unsigned.partition('.')
    if whole.isdigit() and (not dot or fraction.isdigit()):
        return operand

    if '.' in operand:
        obj, field = operand.split('.', 1)
        if obj == 'this':
            return f'${{__CT_OBJ["$this.{field}"]}}'
        return f'${{__CT_OBJ["${{{obj}}}.{field}"]:-}}'
    return f'${{{operand}}}'
def _handle_interpolation_call(self, content: str) -> str:
"""Handle function/method calls in string interpolation."""
if '.' in content and not content.startswith('('):
dot_idx = content.find('.')
paren_idx = content.find('(')
before_paren = content[:paren_idx]
dots_count = before_paren.count('.')
if dots_count >= 2:
parts = before_paren.split('.')
obj = parts[0].strip()
field = parts[1].strip()
method = parts[2].strip()
args_csv = content[paren_idx + 1:-1]
args_list = self._parse_args_with_parens(args_csv) if args_csv else []
args_bash = ' '.join([self._convert_arg_to_bash(a) for a in args_list])
is_array_field = any(
getattr(self, 'class_field_types', {}).get((cls, field)) == "array"
for cls in getattr(self, 'classes', [])
)
is_dict_field = any(
getattr(self, 'class_field_types', {}).get((cls, field)) == "dict"
for cls in getattr(self, 'classes', [])
)
if is_array_field and method in ('len', 'get', 'shift', 'join', 'slice'):
return f'$(__ct_arr_{method} "${{{obj}}}_{field}" {args_bash})'.replace(' ', ' ').strip()
elif is_dict_field and method in ('get', 'has', 'keys', 'del'):
return f'$(__ct_dict_{method} "${{{obj}}}_{field}" {args_bash})'.replace(' ', ' ').strip()
dot_idx = content.find('.')
if dot_idx < paren_idx:
parts = content.split('.', 1)
obj = parts[0]
......@@ -146,15 +221,33 @@ class ExprMixin:
args_csv = rest[method_paren_idx + 1:-1]
args_list = self._parse_args_with_parens(args_csv) if args_csv else []
args_bash = ' '.join([self._convert_arg_to_bash(a) for a in args_list])
if obj == 'str':
is_var = (obj in getattr(self, 'array_vars', set()) or
obj in getattr(self, 'dict_vars', set()) or
obj in getattr(self, 'local_vars', set()) or
obj in getattr(self, 'global_vars', set()))
if obj == 'str' and not is_var:
return f'$(__ct_str_{method} {args_bash})'
elif obj == 'arr':
elif obj == 'arr' and not is_var:
return f'$(__ct_arr_{method} {args_bash})'
elif obj == 'args':
elif obj == 'args' and not is_var:
if method == 'count':
return '$(__ct_args_count)'
elif method == 'get':
return f'$(__ct_args_get {args_bash})'
if obj in getattr(self, 'array_vars', set()) and method in ('len', 'get', 'shift', 'join', 'slice', 'push', 'pop', 'set'):
arr_ref = f'${{{obj}}}' if obj in getattr(self, 'nameref_vars', set()) else obj
return f'$(__ct_arr_{method} "{arr_ref}" {args_bash})'.replace(' ', ' ').strip()
if obj in getattr(self, 'dict_vars', set()) and method in ('get', 'has', 'keys', 'set', 'del'):
dict_ref = f'${{{obj}}}' if obj in getattr(self, 'nameref_vars', set()) else obj
return f'$(__ct_dict_{method} "{dict_ref}" {args_bash})'.replace(' ', ' ').strip()
str_methods = ('len', 'upper', 'lower', 'trim', 'contains', 'starts', 'ends', 'index', 'replace', 'substr', 'split', 'charAt')
if method in str_methods:
if method == 'charAt':
return f'${{$(__ct_str_char_at "${{{obj}}}" {args_bash})%X}}'.replace(' ', ' ').strip()
return f'$(__ct_str_{method} "${{{obj}}}" {args_bash})'.replace(' ', ' ').strip()
return f'$(__ct_call_method "${{{obj}}}" "{method}" {args_bash})'
paren_idx = content.find('(')
func_name = content[:paren_idx].strip()
......@@ -332,12 +425,30 @@ class ExprMixin:
return None
def _generate_member_access(self, expr: MemberAccess) -> str:
    """Emit the bash expression that reads ``<object>.<member>``.

    A member whose name is registered as an ``"array"`` or ``"dict"`` field
    for any class in ``self.classes`` resolves to the variable name
    ``<instance>_<member>``. Every other member is looked up in the
    ``__CT_OBJ`` associative array; lookups on something other than ``this``
    carry a ``:-`` empty default.
    """
    member = expr.member
    collection_kinds = ("array", "dict")
    is_collection = any(
        self.class_field_types.get((class_name, member)) in collection_kinds
        for class_name in self.classes
    )

    if isinstance(expr.object, ThisExpr):
        if is_collection:
            return f'${{this}}_{member}'
        return f'${{__CT_OBJ["$this.{member}"]}}'

    # When the receiver renders as "${name}", stripping and re-wrapping the
    # braces reproduces the same text, so one code path covers both shapes.
    receiver = self.generate_expr(expr.object)
    if is_collection:
        return f'{receiver}_{member}'
    return f'${{__CT_OBJ["{receiver}.{member}"]:-}}'
def _generate_index_access(self, expr: IndexAccess) -> str:
......
......@@ -462,7 +462,7 @@ class StdlibMixin:
self.emit ("__ct_str_trim () { local s=\"$1\"; s=\"${s#\"${s%%[![:space:]]*}\"}\"; __CT_RET=\"${s%\"${s##*[![:space:]]}\"}\" ; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_upper () { __CT_RET=\"${1^^}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_lower () { __CT_RET=\"${1,,}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_char_at () { __CT_RET=\"${1:$2:1}\"; echo \"$__CT_RET\"; }")
self.emit ("__ct_str_char_at () { __CT_RET=\"${1:$2:1}\"; printf '%sX' \"$__CT_RET\"; }")
self.emit ("__ct_str_urlencode () { __CT_RET=$(printf '%s' \"$1\" | jq -sRr @uri); echo \"$__CT_RET\"; }")
self.emit ("__ct_str_concat () { __CT_RET=\"$1$2\"; echo \"$__CT_RET\"; }")
self.emit ()
......
......@@ -436,6 +436,29 @@ class StmtMixin:
self.emit("return 0")
return
if isinstance(stmt.value, DictLiteral):
self.emit('local __ct_ret_dict="__ct_dict_$RANDOM$RANDOM"')
self.emit('declare -gA "$__ct_ret_dict"')
for k, v in stmt.value.pairs:
key = self.generate_expr(k)
val = self.generate_expr(v)
self.emit(f'eval "$__ct_ret_dict[{key}]=\\"{val}\\""')
self.emit('__CT_RET="$__ct_ret_dict"')
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return
if isinstance(stmt.value, ArrayLiteral):
self.emit('local __ct_ret_arr="__ct_arr_$RANDOM$RANDOM"')
self.emit('declare -ga "$__ct_ret_arr"')
elements = [self.generate_expr(e) for e in stmt.value.elements]
for i, elem in enumerate(elements):
self.emit(f'eval "$__ct_ret_arr[{i}]=\\"{elem}\\""')
self.emit('__CT_RET="$__ct_ret_arr"')
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return
value = self.generate_expr(stmt.value)
self.emit(f'__CT_RET="{value}"')
self.emit('echo "$__CT_RET"')
......@@ -472,6 +495,23 @@ class StmtMixin:
def _generate_call_return(self, expr: CallExpr) -> bool:
"""Generate return for function/method calls."""
if isinstance(expr.callee, MemberAccess):
if isinstance(expr.callee.object, MemberAccess) and isinstance(expr.callee.object.object, ThisExpr):
field = expr.callee.object.member
method = expr.callee.member
args = [self.generate_expr(arg) for arg in expr.arguments]
args_str = " ".join([f'"{a}"' for a in args])
str_methods = {
"len": "__ct_str_len", "upper": "__ct_str_upper", "lower": "__ct_str_lower",
"trim": "__ct_str_trim", "contains": "__ct_str_contains", "starts": "__ct_str_starts",
"ends": "__ct_str_ends", "index": "__ct_str_index", "replace": "__ct_str_replace",
"substr": "__ct_str_substr", "split": "__ct_str_split", "charAt": "__ct_str_char_at",
}
if method in str_methods:
func_name = str_methods[method]
self.emit(f'{func_name} "${{__CT_OBJ["$this.{field}"]}}" {args_str} >/dev/null'.replace(' ', ' '))
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return True
if isinstance(expr.callee.object, ThisExpr) and self.current_class:
method = expr.callee.member
args = [self.generate_expr(arg) for arg in expr.arguments]
......@@ -491,8 +531,16 @@ class StmtMixin:
return True
elif isinstance(expr.callee, Identifier):
func_name = expr.callee.name
if func_name in self.classes:
args = [self._generate_call_arg(arg) for arg in expr.arguments]
args_str = " ".join([f'"{a}"' for a in args])
self.emit(f'{func_name} {args_str}')
self.emit('__CT_RET="$__ct_last_instance"')
self.emit('echo "$__CT_RET"')
self.emit("return 0")
return True
if func_name in self.functions:
args = [self.generate_expr(arg) for arg in expr.arguments]
args = [self._generate_call_arg(arg) for arg in expr.arguments]
args_str = " ".join([f'"{a}"' for a in args])
self.emit(f'{func_name} {args_str} >/dev/null')
self.emit('echo "$__CT_RET"')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment