diff --git a/asteroid/frontend.py b/asteroid/frontend.py index 9f130ba8..b95fd20e 100644 --- a/asteroid/frontend.py +++ b/asteroid/frontend.py @@ -69,7 +69,7 @@ def dbg_print(string): 'TRY', 'WHILE', #'WITH', - } | primary_lookahead + } | exp_lookahead ########################################################################################### class Parser: @@ -105,6 +105,7 @@ def stmt_list(self): sl = [] while self.lexer.peek().type in stmt_lookahead: sl += [('lineinfo', state.lineinfo)] + sl += [('clear-ret-val',)] sl += [self.stmt()] return ('list', sl) @@ -131,7 +132,7 @@ def stmt_list(self): # | TRY stmt_list (CATCH pattern DO stmt_list)+ END # | THROW exp '.'? # | function_def - # | call_or_index '.'? + # | exp '.'? def stmt(self): dbg_print("parsing STMT") tt = self.lexer.peek().type # tt - Token Type @@ -391,10 +392,10 @@ def stmt(self): self.lexer.match_optional('DOT') return ('throw', e) - elif tt in primary_lookahead: - v = self.call_or_index() + elif tt in exp_lookahead: + v = self.exp() self.lexer.match_optional('DOT') - return v + return ('set-ret-val', v) else: raise SyntaxError("syntax error at '{}'" diff --git a/asteroid/grammar.txt b/asteroid/grammar.txt index 8626eb6f..04358b89 100644 --- a/asteroid/grammar.txt +++ b/asteroid/grammar.txt @@ -37,7 +37,7 @@ | BREAK '.'? | RETURN exp? '.'? | function_def - | call_or_index '.'? + | exp '.'? function_def : FUNCTION ID body_defs END diff --git a/asteroid/test-suites/regression-tests/test136.ast b/asteroid/test-suites/regression-tests/test136.ast new file mode 100644 index 00000000..05212d1f --- /dev/null +++ b/asteroid/test-suites/regression-tests/test136.ast @@ -0,0 +1,3 @@ +-- test expressions computing return values +let x = (lambda with i do i+1) 1. +assert(x is 2). diff --git a/asteroid/test-suites/regression-tests/test137.ast b/asteroid/test-suites/regression-tests/test137.ast new file mode 100644 index 00000000..58258031 --- /dev/null +++ b/asteroid/test-suites/regression-tests/test137.ast @@ -0,0 +1,18 @@ +-- test expressions computing return values + +load system io. + +function f1 with () do + 3+3. + io @println "hello". + 5. +end + +function f2 with () do + 3+3. + io @println "hello". +end + +assert(f1() is 5). +assert(f2() is none). + diff --git a/asteroid/walk.py b/asteroid/walk.py index b05fe42a..8bbb81e1 100644 --- a/asteroid/walk.py +++ b/asteroid/walk.py @@ -26,6 +26,12 @@ ######################################################################### __retval__ = None # return value register for escaped code +######################################################################### +# return values for function computed by the last expression executed +# in the context of a function. note that we consider global code +# to be part of the 'top-level' function +function_return_value = [None] + ########################################################################################### # check if the two type tags match def match(tag1, tag2): @@ -883,12 +889,19 @@ def handle_call(obj_ref, fval, actual_val_args, fname): # function calls transfer control - save our caller's lineinfo old_lineinfo = state.lineinfo + global function_return_value try: + function_return_value.append(None) walk(stmts) + val = function_return_value.pop() + if val: + return_value = val + else: + return_value = ('none', None) except ReturnValue as val: + # we got here because a return statement threw a return object + function_return_value.pop() return_value = val.value - else: - return_value = ('none', None) # need that in case function has no return statement # coming back from a function call - restore caller's env state.lineinfo = old_lineinfo @@ -1281,7 +1294,7 @@ def apply_exp(node): return handle_builtins(node) # handle function application - # retrive the function name from the AST + # retrieve the function name from the AST if f[0] in ['function-exp','apply']: # cannot use the function expression as a name, # could be a very complex computation. the apply @@ -1613,6 +1626,19 @@ def constraint_exp(node): "constraint pattern: '{}' cannot be used as a constructor." .format(term2string(node))) +######################################################################### +def set_ret_val(node): + + (SET_RET_VAL, exp) = node + assert_match(SET_RET_VAL,'set-ret-val') + + global function_return_value + val = walk(exp) + function_return_value.pop() + function_return_value.append(val) + + return + ######################################################################### # walk ######################################################################### @@ -1620,7 +1646,13 @@ def walk(node): # node format: (TYPE, [child1[, child2[, ...]]]) type = node[0] - if type in dispatch_dict: + if type == 'clear-ret-val': + # implemented here instead of dictionary for efficiency reasons + global function_return_value + function_return_value.pop() + function_return_value.append(None) + return + elif type in dispatch_dict: node_function = dispatch_dict[type] return node_function(node) else: @@ -1630,6 +1662,7 @@ def walk(node): dispatch_dict = { # statements - statements do not produce return values 'lineinfo' : process_lineinfo, + 'set-ret-val' : set_ret_val, 'noop' : lambda node : None, 'assert' : assert_stmt, 'unify' : unify_stmt, @@ -1651,7 +1684,6 @@ def walk(node): 'head-tail' : head_tail_exp, 'raw-to-list' : lambda node : walk(('to-list', node[1], node[2], node[3])), 'raw-head-tail' : lambda node : walk(('head-tail', node[1], node[2])), - 'seq' : lambda node : ('seq', walk(node[1]), walk(node[2])), 'none' : lambda node : node, 'nil' : lambda node : node, 'function-exp' : function_exp,