Commit a8f275f6 authored by Roman Alifanov's avatar Roman Alifanov

Sync AWK codegen with bash codegen, fix dict.field assignment

AWK codegen: - Add string interpolation: "Hello {name}" -> ("Hello " name "!") - Add env.VAR support via ENVIRON["VAR"] - Add time.now() via systime() - Add str.chr() via sprintf("%c", n) - Add assert_eq() for equality assertions - Add print() as statement - Add method calls as statements (arr.push, dict.set) Bash codegen: - Fix dict.field = value for dict variables (was using wrong storage) - Fix dict.field read to use dict["field"] syntax
parent 5e6d896e
......@@ -8,7 +8,7 @@ from .ast_nodes import (
BoolLiteral, NilLiteral, ArrayLiteral, DictLiteral, BinaryOp,
UnaryOp, CallExpr, IndexAccess, MemberAccess
)
from .methods import get_awk_builtin, generate_awk, MATH_METHODS
from .methods import get_awk_builtin, generate_awk, MATH_METHODS, TIME_METHODS
from .constants import RET_VAR
......@@ -347,12 +347,36 @@ class AwkCodegenMixin:
elif isinstance (stmt, ExpressionStmt):
if isinstance (stmt.expression, CallExpr) and isinstance (stmt.expression.callee, Identifier):
if stmt.expression.callee.name == "assert":
args = stmt.expression.arguments
func_name = stmt.expression.callee.name
args = stmt.expression.arguments
if func_name == "assert":
cond = self._awk_cond (args[0]) if args else "1"
msg = self._awk_expr (args[1]) if len (args) >= 2 else '"Assertion failed"'
emit (f"if (!({cond})) {{ print {msg} > \"/dev/stderr\"; exit 1 }}")
return
if func_name == "assert_eq":
expected = self._awk_expr (args[0]) if args else '""'
actual = self._awk_expr (args[1]) if len (args) >= 2 else '""'
msg = self._awk_expr (args[2]) if len (args) >= 3 else '"Values not equal"'
emit (f"if ({expected} != {actual}) {{ print {msg} > \"/dev/stderr\"; exit 1 }}")
return
if func_name == "print":
awk_args = [self._awk_expr (a) for a in args]
emit (f"print {', '.join (awk_args)}" if awk_args else "print")
return
if isinstance (stmt.expression, CallExpr) and isinstance (stmt.expression.callee, MemberAccess):
if isinstance (stmt.expression.callee.object, Identifier):
obj_name = stmt.expression.callee.object.name
method = stmt.expression.callee.member
args = stmt.expression.arguments
var_types = getattr (self, '_awk_var_types', {})
var_type = var_types.get (obj_name, "string")
type_name = {"array": "array", "dict": "dict"}.get (var_type, "string")
awk_args = [self._awk_expr (a) for a in args]
awk_code = generate_awk (type_name, method, obj_name, awk_args)
if awk_code:
emit (awk_code)
return
expr = self._awk_expr (stmt.expression)
if expr:
emit (expr)
......@@ -409,6 +433,8 @@ class AwkCodegenMixin:
idx = self._awk_expr (expr.index)
return f"{obj}[{idx}]"
if isinstance (expr, MemberAccess):
if isinstance (expr.object, Identifier) and expr.object.name == "env":
return f'ENVIRON["{expr.member}"]'
obj = self._awk_expr (expr.object)
return f'{obj}["{expr.member}"]'
return self._awk_expr (expr)
......@@ -425,6 +451,8 @@ class AwkCodegenMixin:
return str (expr.value)
if isinstance (expr, StringLiteral):
if expr.has_interpolation:
return self._awk_interpolate_string (expr.value)
value = expr.value
value = value.replace ('\\', '\\\\')
value = value.replace ('\n', '\\n')
......@@ -483,6 +511,12 @@ class AwkCodegenMixin:
awk_args = [self._awk_expr(a) for a in args]
return math_method.awk_builtin(awk_args)
if ns == "time" and method in TIME_METHODS:
time_method = TIME_METHODS[method]
if time_method.awk_builtin:
awk_args = [self._awk_expr(a) for a in args]
return time_method.awk_builtin(awk_args)
var_types = getattr (self, '_awk_var_types', {})
var_type = var_types.get (ns, "string")
type_name = {"array": "array", "dict": "dict"}.get(var_type, "string")
......@@ -509,6 +543,8 @@ class AwkCodegenMixin:
return f"{obj}[{idx}]"
if isinstance (expr, MemberAccess):
if isinstance (expr.object, Identifier) and expr.object.name == "env":
return f'ENVIRON["{expr.member}"]'
obj = self._awk_expr (expr.object)
return f'{obj}["{expr.member}"]'
......@@ -584,6 +620,80 @@ class AwkCodegenMixin:
return '""'
return self._awk_expr (body)
def _awk_interpolate_string (self, value: str) -> str:
"""Convert string with {var} interpolation to AWK concatenation."""
parts = []
i = 0
while i < len (value):
if value[i] == '{' and (i == 0 or value[i - 1] != '\\'):
end = value.find ('}', i)
if end != -1:
if i > 0 and parts == []:
parts.append (f'"{self._awk_escape_str (value[:i])}"')
elif i > 0:
text_before = value[parts[-1][1] if isinstance (parts[-1], tuple) else 0:i]
if text_before and not text_before.startswith ('{'):
pass
var_name = value[i + 1:end]
if '.' in var_name:
obj, member = var_name.split ('.', 1)
if obj == 'env':
parts.append (f'ENVIRON["{member}"]')
else:
parts.append (f'{obj}["{member}"]')
else:
parts.append (var_name)
i = end + 1
continue
i += 1
if not parts:
escaped = self._awk_escape_str (value)
return f'"{escaped}"'
result_parts = []
last_end = 0
i = 0
while i < len (value):
if value[i] == '{' and (i == 0 or value[i - 1] != '\\'):
end = value.find ('}', i)
if end != -1:
if i > last_end:
text = value[last_end:i]
if text:
result_parts.append (f'"{self._awk_escape_str (text)}"')
var_name = value[i + 1:end]
if '.' in var_name:
obj, member = var_name.split ('.', 1)
if obj == 'env':
result_parts.append (f'ENVIRON["{member}"]')
else:
result_parts.append (f'{obj}["{member}"]')
else:
result_parts.append (var_name)
last_end = end + 1
i = end + 1
continue
i += 1
if last_end < len (value):
text = value[last_end:]
if text:
result_parts.append (f'"{self._awk_escape_str (text)}"')
if len (result_parts) == 1:
return result_parts[0]
return '(' + ' '.join (result_parts) + ')'
def _awk_escape_str (self, s: str) -> str:
"""Escape string for AWK double quotes."""
s = s.replace ('\\', '\\\\')
s = s.replace ('\n', '\\n')
s = s.replace ('\t', '\\t')
s = s.replace ('"', '\\"')
s = s.replace ("'", "\\047")
return s
def _awk_emit_validation (self, params, emit):
"""Generate AWK validation code for @validate decorator."""
validate = self._awk_validate
......
......@@ -178,6 +178,9 @@ class DispatchMixin:
if obj_name in self.object_vars:
self._generate_obj_field_assignment(stmt, obj_name)
return
if obj_name in self.dict_vars:
self._generate_dict_field_assignment(stmt, obj_name)
return
target = self.generate_lvalue(stmt.target)
......@@ -394,6 +397,23 @@ class DispatchMixin:
else:
self.emit(f'__CT_OBJ["${{{obj_name}}}.{field}"]="{value}"')
def _generate_dict_field_assignment(self, stmt: Assignment, dict_name: str):
"""Generate dict.field = value assignment for dict variables."""
field = stmt.target.member
value = self.generate_expr(stmt.value)
param_map = getattr(self, 'param_name_map', {})
mapped_name = param_map.get(dict_name, dict_name)
if stmt.operator == "=":
self.emit(f'{mapped_name}["{field}"]="{value}"')
elif stmt.operator == "+=":
self.emit(f'{mapped_name}["{field}"]="$(( ${{{mapped_name}["{field}"]}} + {value} ))"')
elif stmt.operator == "-=":
self.emit(f'{mapped_name}["{field}"]="$(( ${{{mapped_name}["{field}"]}} - {value} ))"')
elif stmt.operator == "..=":
self.emit(f'{mapped_name}["{field}"]="${{{mapped_name}["{field}"]}}{value}"')
else:
self.emit(f'{mapped_name}["{field}"]="{value}"')
def _generate_method_call_assignment(self, stmt: Assignment, target: str) -> bool:
"""Generate method call assignment. Returns True if handled."""
callee = stmt.value.callee
......
......@@ -437,6 +437,13 @@ class ExprMixin:
if isinstance(expr.object, Identifier) and expr.object.name == "env":
return f'${{{expr.member}}}'
if isinstance(expr.object, Identifier):
obj_name = expr.object.name
param_map = getattr(self, 'param_name_map', {})
mapped_name = param_map.get(obj_name, obj_name)
if mapped_name in getattr(self, 'dict_vars', set()):
return f'${{{mapped_name}["{expr.member}"]}}'
field_name = expr.member
is_array_field = any(
self.class_field_types.get((cls, field_name)) == "array"
......
......@@ -104,4 +104,5 @@ class StringMethods:
name="chr",
bash_func="__ct_str_chr",
bash_impl="printf -v __CT_RET '%b' \"\\\\x$(printf '%02x' \"$1\")\"; echo \"$__CT_RET\"",
awk_gen=lambda obj, args: f'sprintf("%c", {args[0]})',
)
......@@ -2,5 +2,14 @@ from .base import Method
class TimeMethods:
now = Method(name="now", bash_func="__ct_time_now", bash_impl='date +%s')
ms = Method(name="ms", bash_func="__ct_time_ms", bash_impl='date +%s%3N')
now = Method(
name="now",
bash_func="__ct_time_now",
bash_impl='date +%s',
awk_builtin=lambda a: "systime()",
)
ms = Method(
name="ms",
bash_func="__ct_time_ms",
bash_impl='date +%s%3N',
)
......@@ -1962,3 +1962,149 @@ print(c.value)
''')
assert code == 0
assert "20" in stdout
class TestAwkSync:
def test_awk_string_interpolation(self):
code, stdout, stderr = run_ct('''
@awk
func greet(name) {
return "Hello, {name}!"
}
print(greet("World"))
''')
assert code == 0
assert "Hello, World!" in stdout
def test_awk_env_var(self):
code, stdout, stderr = run_ct('''
@awk
func get_home() {
return env.HOME
}
result = get_home()
print(result)
''')
assert code == 0
assert "/" in stdout
def test_awk_time_now(self):
code, stdout, stderr = run_ct('''
@awk
func get_time() {
return time.now()
}
ts = get_time()
print(ts)
''')
assert code == 0
assert len(stdout.strip()) >= 10
def test_awk_assert_eq(self):
code, stdout, stderr = run_ct('''
@awk
func test_eq() {
x = 10
assert_eq(10, x, "should be 10")
return "passed"
}
print(test_eq())
''')
assert code == 0
assert "passed" in stdout
def test_awk_assert_eq_fails(self):
code, stdout, stderr = run_ct('''
@awk
func test_eq() {
x = 5
assert_eq(10, x, "not equal")
return "passed"
}
test_eq()
''')
assert code != 0
assert "not equal" in stdout or "not equal" in stderr
def test_awk_array_push_statement(self):
code, stdout, stderr = run_ct('''
@awk
func test_push() {
arr = []
arr.push(1)
arr.push(2)
arr.push(3)
return arr[1] .. "-" .. arr[2] .. "-" .. arr[3]
}
print(test_push())
''')
assert code == 0
assert "1-2-3" in stdout
def test_awk_dict_set_statement(self):
code, stdout, stderr = run_ct('''
@awk
func test_dict() {
d = {}
d.set("key", "value")
return d.get("key")
}
print(test_dict())
''')
assert code == 0
assert "value" in stdout
def test_awk_print_statement(self):
code, stdout, stderr = run_ct('''
@awk
func test_print() {
print("inside awk")
return "done"
}
result = test_print()
print(result)
''')
assert code == 0
assert "inside awk" in stdout
assert "done" in stdout
class TestDictFieldAssignment:
def test_dict_field_assign_simple(self):
code, stdout, stderr = run_ct('''
func test() {
obj = {}
obj.name = "Alice"
return obj.name
}
print(test())
''')
assert code == 0
assert "Alice" in stdout
def test_dict_field_assign_multiple(self):
code, stdout, stderr = run_ct('''
func test() {
obj = {}
obj.name = "Bob"
obj.age = 30
return obj.name .. " is " .. obj.age
}
print(test())
''')
assert code == 0
assert "Bob is 30" in stdout
def test_dict_field_assign_compound(self):
code, stdout, stderr = run_ct('''
func test() {
obj = {}
obj.count = 10
obj.count += 5
return obj.count
}
print(test())
''')
assert code == 0
assert "15" in stdout
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment