/* Expression tree creation and evaluation. */ #include #include #include #include #include "etree.h" #include "funcs.h" /*************************************************************************/ /*************************************************************************/ /* Skip over whitespace in a string. */ static void skipspace(const char **strptr) { while (isspace(**strptr)) (*strptr)++; } /*************************************************************************/ /* Convert C-like string escapes to their corresponding characters. Note * that this can only ever shrink the string. */ static void parse_escapes(char *str) { int len = strlen(str); while (*str) { if (*str == '\\') { memmove(str, str+1, len--); switch (*str) { case 'a': *str = '\a'; break; case 'b': *str = '\b'; break; case 'f': *str = '\f'; break; case 'n': *str = '\n'; break; case 'r': *str = '\r'; break; case 't': *str = '\t'; break; case 'v': *str = '\v'; break; } } str++; len--; } } /*************************************************************************/ /* Parse and return an operand from the given string. This might include * parsing an entire parenthesized subexpression. */ static ETree *parse_operand(const char **strptr) { const char *s; ETree *node; skipspace(strptr); s = *strptr; /* Check for a parenthesized subexpression. */ if (*s == '(') { /* Note a convenient fact: If this call ends on a close-paren, * it will always be the one matching the open-paren we just * found. */ node = malloc(sizeof(*node)); node->type = ET_SUBEXPR; node->u.subexprval.expr = etree_create(s+1, (char **)&s); if (*s++ != ')') { etree_free(node); return NULL; } *strptr = s; return node; } /* From here, it's either a single operand or nothing. */ node = malloc(sizeof(*node)); if (isdigit(*s) || ((*s == '+' || *s == '-') && isdigit(s[1]))) { /* Integer */ node->type = ET_INT; node->u.intval = strtol(s, (char **)strptr, 0); return node; } else if (isalpha(*s) || *s == '_') { /* Single-word string or function invocation */ char *buf; s++; /* Allow periods, for e.g. hostnames */ while (isalnum(*s) || *s == '_' || *s == '.') s++; buf = malloc(s - *strptr + 1); strncpy(buf, *strptr, s - *strptr); buf[s - *strptr] = 0; skipspace(&s); if (*s == '(') { int id, nargs, maxargs; ETree **args; s++; if ((id = func_id(buf)) < 0) { fprintf(stderr, "Unknown function `%s'\n", buf); free(node); return NULL; } node->type = ET_FUNC; node->u.funcval.func = id; args = etree_create_list(s, (char **)&s, &nargs); skipspace(&s); if (*s != ')') { fprintf(stderr, "Missing right-paren for function call\n"); etree_free_list(args, nargs); free(node); return NULL; } s++; maxargs = func_maxargs(id); if (nargs < func_minargs(id)) { fprintf(stderr, "Too few arguments to function `%s'\n", buf); etree_free_list(args, nargs); free(node); return NULL; } else if (nargs > maxargs) { fprintf(stderr, "Too many arguments to function `%s'\n", buf); etree_free_list(args, nargs); free(node); return NULL; } /* XXX need to check arg types--need etree_eval_type() */ /* Pad out the args array to the maximum number of arguments */ if (nargs < maxargs) { args = realloc(args, sizeof(ETree *) * maxargs); while (nargs < maxargs) args[nargs++] = NULL; } node->u.funcval.nargs = nargs; node->u.funcval.args = args; *strptr = s; return node; } else { /* Single-word string */ node->type = ET_STRING; node->u.stringval = buf; *strptr = s; return node; } } else if (*s == '"') { /* String */ node->type = ET_STRING; *strptr += 1; s++; while (*s && *s != '"') { if (*s == '\\' && s[1] != 0) s++; s++; } node->u.stringval = malloc(s - *strptr + 1); strncpy(node->u.stringval, *strptr, s - *strptr); node->u.stringval[s - *strptr] = 0; parse_escapes(node->u.stringval); if (*s == '"') s++; *strptr = s; return node; } else { /* No operand here */ free(node); return NULL; } } /*************************************************************************/ /* Parse and return an operator from the given string. Return NULL if we * don't see one. */ static ETree *parse_operator(const char **strptr) { const char *s; ETree *node; skipspace(strptr); s = *strptr; node = malloc(sizeof(*node)); node->type = ET_OPERATOR; node->u.exprval.arg1 = node->u.exprval.arg2 = NULL; if (*s == '+') { node->u.exprval.operator = OP_ADD; *strptr += 1; return node; } else if (*s == '-') { node->u.exprval.operator = OP_SUBTRACT; *strptr += 1; return node; } else if (*s == '*') { node->u.exprval.operator = OP_MULTIPLY; *strptr += 1; return node; } else if (*s == '/') { node->u.exprval.operator = OP_DIVIDE; *strptr += 1; return node; } else if (*s == '%') { node->u.exprval.operator = OP_MODULO; *strptr += 1; return node; } else if (*s == '<' && s[1] == '<') { node->u.exprval.operator = OP_LSHIFT; *strptr += 2; return node; } else if (*s == '>' && s[1] == '>') { node->u.exprval.operator = OP_RSHIFT; *strptr += 2; return node; } else if (*s == '&' && s[1] != '&') { node->u.exprval.operator = OP_BAND; *strptr += 1; return node; } else if (*s == '|' && s[1] != '|') { node->u.exprval.operator = OP_BOR; *strptr += 1; return node; } else if (*s == '^') { node->u.exprval.operator = OP_BXOR; *strptr += 1; return node; } else if (*s == '~') { node->u.exprval.operator = OP_BNOT; *strptr += 1; return node; } else if (*s == '=' && s[1] == '=') { node->u.exprval.operator = OP_EQ; *strptr += 2; return node; } else if (*s == '!' && s[1] == '=') { node->u.exprval.operator = OP_NE; *strptr += 2; return node; } else if (*s == '<' && s[1] != '=') { node->u.exprval.operator = OP_LT; *strptr += 1; return node; } else if (*s == '<' && s[1] == '=') { node->u.exprval.operator = OP_LE; *strptr += 2; return node; } else if (*s == '>' && s[1] != '=') { node->u.exprval.operator = OP_GT; *strptr += 1; return node; } else if (*s == '>' && s[1] == '=') { node->u.exprval.operator = OP_GE; *strptr += 2; return node; } else if (*s == '&' && s[1] == '&') { node->u.exprval.operator = OP_LAND; *strptr += 2; return node; } else if (*s == '|' && s[1] == '|') { node->u.exprval.operator = OP_LOR; *strptr += 2; return node; } else if (*s == '!') { node->u.exprval.operator = OP_LNOT; *strptr += 1; return node; } else { free(node); return NULL; } } /*************************************************************************/ /* Build an expression tree from the given input string. Return in `tail' * the first character after the end of the valid expression string. If * *tail == str, then no valid expression string could be found. */ ETree *etree_create(const char *str, char **tail) { ETree *node, *arg1 = NULL, *arg2, *top, *parent; /* Check for a unary operand (NOT). If we don't see it, try to parse * an operand, which might be a parenthesized subexpression. */ if (*str != '!' && *str != '~') { arg1 = parse_operand(&str); if (!arg1) return NULL; } /* Now try to get operator-operand pairs. */ while ((node = parse_operator(&str))) { int op; op = node->u.exprval.operator; /* Check for appropriate presence or absence of a first argument. */ if ((op == OP_LNOT || op == OP_BNOT) ? arg1 != NULL : arg1 == NULL) { if (arg1) etree_free(arg1); etree_free(node); return NULL; } /* Get a second argument (always required). */ arg2 = parse_operand(&str); if (!arg2) { if (arg1) etree_free(arg1); etree_free(node); return NULL; } /* Assign arguments to operator. If arg1 is an operator, we may * need to do a bit of tree rotation depending on operator * precedence. */ node->u.exprval.arg2 = arg2; /* Go through tree until we find either a terminal node or an * operator with equal or lower precedence. Assign that node to * arg1 of the current operator; if it has a parent, assign the * current node to the parent's arg2, and set the top node to the * previous top node. */ top = arg1; parent = NULL; while (arg1->type==ET_OPERATOR && arg1->u.exprval.operator/10 > op/10) { parent = arg1; arg1 = arg1->u.exprval.arg2; } node->u.exprval.arg1 = arg1; if (parent) { parent->u.exprval.arg2 = node; node = top; } /* Move this subexpression to arg1 for the next go-round. */ arg1 = node; } /* Return whatever we found. */ *tail = (char *)str; return arg1; } /*************************************************************************/ /* Create an array of ETrees from a string containing a comma-separated * list of expressions. This returns an array of ETree *'s, each one * representing one of the expressions. The length of the array will be * returned via the `len' pointer. */ ETree **etree_create_list(const char *str, char **tail, int *len) { ETree *node, **array = NULL; int num = 0, size = 0; while ((node = etree_create(str, tail)) || *str == ',') { if (num >= size) { size += 8; array = realloc(array, sizeof(ETree *) * size); } array[num++] = node; str = (const char *) *tail; if (*str != ',') break; str++; } *len = num; return array; } /*************************************************************************/ /* Free an expression tree. */ void etree_free(ETree *expr) { if (!expr) return; if (expr->type == ET_OPERATOR) { etree_free(expr->u.exprval.arg1); etree_free(expr->u.exprval.arg2); } else if (expr->type == ET_SUBEXPR) { etree_free(expr->u.subexprval.expr); } else if (expr->type == ET_FUNC) { etree_free_list(expr->u.funcval.args, expr->u.funcval.nargs); } else if (expr->type == ET_STRING) { if (expr->u.stringval) free(expr->u.stringval); } free(expr); } /*************************************************************************/ /* Free an array of expression trees. */ void etree_free_list(ETree **array, int size) { int i; for (i = 0; i < size; i++) etree_free(array[i]); if (size > 0) free(array); } /*************************************************************************/ /*************************************************************************/ /* Return whether a terminal expression node has a "nonzero" value. */ static int nonzero(ETree *node) { if (node->type == ET_INT) return node->u.intval != 0; else if (node->type == ET_STRING) return node->u.stringval[0] != 0; else return 1; } /*************************************************************************/ /* Evaluate an expression tree. * A note about the staticness of `res' and the recursiveness of this * function: While it is generally a bad idea to use a static buffer for a * recursive function, we make certain to copy all values we need * (including strings) before trying to write to the result buffer, so this * is in fact safe. */ ETree *etree_eval(ETree *expr) { static ETree res; static char *stringres = NULL; static int stringres_size = 0; ETree arg1, arg2; int op; /* Skip over subexpression nodes. */ while (expr && expr->type == ET_SUBEXPR) expr = expr->u.subexprval.expr; /* Ask a NULL question, get a NULL answer. */ if (!expr) return NULL; /* Evaluate terminal nodes. */ if (expr->type == ET_INT || expr->type == ET_STRING) { return expr; } else if (expr->type == ET_FUNC) { return func_eval(expr->u.funcval.func, expr->u.funcval.args); } else if (expr->type != ET_OPERATOR) { res.type = ET_INT; res.u.intval = 0; return &res; } /* At this point we know that expr->type == ET_OPERATOR. Evaluate * its argument(s). */ op = expr->u.exprval.operator; if (expr->u.exprval.arg1) arg1 = *etree_eval(expr->u.exprval.arg1); else arg1.type = -1; /* Handle boolean short-circuiting with logical AND and OR. */ if (arg1.type >= 0) { if (op == OP_LAND && !nonzero(&arg1)) { res.type = ET_INT; res.u.intval = 0; return &res; } else if (op == OP_LOR && nonzero(&arg1)) { res.type = ET_INT; res.u.intval = 1; return &res; } } /* Finish evaluation of arguments. */ if (arg1.type == ET_STRING) arg1.u.stringval = strdup(arg1.u.stringval); arg2 = *etree_eval(expr->u.exprval.arg2); if (arg2.type == ET_STRING) arg2.u.stringval = strdup(arg2.u.stringval); /* Handle logical AND, OR, and NOT specially. Other than addition, * these are the only operations which can accept parameters of * different types. Also, we have already done some processing for AND * and OR, so don't throw that work away. */ switch (op) { case OP_LNOT: res.type = ET_INT; res.u.intval = !nonzero(&arg2); if (arg1.type == ET_STRING) free(arg1.u.stringval); if (arg2.type == ET_STRING) free(arg2.u.stringval); return &res; case OP_LAND: case OP_LOR: res.type = ET_INT; res.u.intval = nonzero(&arg2); if (arg1.type == ET_STRING) free(arg1.u.stringval); if (arg2.type == ET_STRING) free(arg2.u.stringval); return &res; } /* If we have only one argument, the operator must be a binary NOT; if * it isn't, bail. */ if (arg1.type < 0 && op != OP_BNOT) { if (arg2.type == ET_STRING) free(arg2.u.stringval); res.type = ET_INT; res.u.intval = 0; return &res; } /* Handle the case of arguments with different types. */ if (arg1.type >= 0 && (arg1.type != arg2.type)) { switch (op) { case OP_ADD: { int len = 0; char intbuf[32]; /* Big enough for 64-bit ints */ if (arg1.type == ET_INT) len += sprintf(intbuf, "%d", arg1.u.intval); else if (arg1.type == ET_STRING) len += strlen(arg1.u.stringval); if (arg2.type == ET_INT) len += sprintf(intbuf, "%d", arg2.u.intval); else if (arg2.type == ET_STRING) len += strlen(arg2.u.stringval); len++; /* Count null byte at end */ if (len > stringres_size) { stringres_size = len; stringres = malloc(len); } *stringres = 0; if (arg1.type == ET_INT) { sprintf(intbuf, "%d", arg1.u.intval); strcat(stringres, intbuf); } else if (arg1.type == ET_STRING) { strcat(stringres, arg1.u.stringval); free(arg1.u.stringval); } if (arg2.type == ET_INT) { sprintf(intbuf, "%d", arg2.u.intval); strcat(stringres, intbuf); } else if (arg2.type == ET_STRING) { strcat(stringres, arg2.u.stringval); free(arg2.u.stringval); } res.type = ET_STRING; res.u.stringval = stringres; return &res; } default: if (arg1.type == ET_STRING) free(arg1.u.stringval); if (arg2.type == ET_STRING) free(arg2.u.stringval); res.type = ET_INT; res.u.intval = 0; return &res; } } /* Handle the case of string arguments. */ if (arg2.type == ET_STRING) { switch (op) { case OP_ADD: { int len = strlen(arg1.u.stringval)+strlen(arg2.u.stringval)+1; if (len > stringres_size) { stringres_size = len; stringres = realloc(stringres, len); } sprintf(stringres, "%s%s", arg1.u.stringval, arg2.u.stringval); free(arg1.u.stringval); free(arg2.u.stringval); res.type = ET_STRING; res.u.stringval = stringres; return &res; } case OP_EQ: res.type = ET_INT; res.u.intval = (strcmp(arg1.u.stringval,arg2.u.stringval) == 0); free(arg1.u.stringval); free(arg2.u.stringval); return &res; case OP_NE: res.type = ET_INT; res.u.intval = (strcmp(arg1.u.stringval,arg2.u.stringval) != 0); free(arg1.u.stringval); free(arg2.u.stringval); return &res; case OP_LT: res.type = ET_INT; res.u.intval = (strcmp(arg1.u.stringval,arg2.u.stringval) < 0); free(arg1.u.stringval); free(arg2.u.stringval); return &res; case OP_LE: res.type = ET_INT; res.u.intval = (strcmp(arg1.u.stringval,arg2.u.stringval) <= 0); free(arg1.u.stringval); free(arg2.u.stringval); return &res; case OP_GT: res.type = ET_INT; res.u.intval = (strcmp(arg1.u.stringval,arg2.u.stringval) > 0); free(arg1.u.stringval); free(arg2.u.stringval); return &res; case OP_GE: res.type = ET_INT; res.u.intval = (strcmp(arg1.u.stringval,arg2.u.stringval) >= 0); free(arg1.u.stringval); free(arg2.u.stringval); return &res; default: if (stringres_size == 0) { stringres = malloc(256); stringres_size = 256; } res.u.stringval = stringres; return &res; } } /* Handle the remaining case (integer arguments). */ else { res.type = ET_INT; switch (op) { case OP_ADD: res.u.intval = arg1.u.intval + arg2.u.intval; return &res; case OP_SUBTRACT: res.u.intval = arg1.u.intval - arg2.u.intval; return &res; case OP_MULTIPLY: res.u.intval = arg1.u.intval * arg2.u.intval; return &res; case OP_DIVIDE: res.u.intval = arg1.u.intval / arg2.u.intval; return &res; case OP_MODULO: res.u.intval = arg1.u.intval % arg2.u.intval; return &res; case OP_LSHIFT: res.u.intval = arg1.u.intval << arg2.u.intval; return &res; case OP_RSHIFT: res.u.intval = arg1.u.intval >> arg2.u.intval; return &res; case OP_BAND: res.u.intval = arg1.u.intval & arg2.u.intval; return &res; case OP_BOR: res.u.intval = arg1.u.intval | arg2.u.intval; return &res; case OP_BXOR: res.u.intval = arg1.u.intval ^ arg2.u.intval; return &res; case OP_BNOT: res.u.intval = ~ arg2.u.intval; return &res; case OP_EQ: res.u.intval = arg1.u.intval == arg2.u.intval; return &res; case OP_NE: res.u.intval = arg1.u.intval != arg2.u.intval; return &res; case OP_LT: res.u.intval = arg1.u.intval < arg2.u.intval; return &res; case OP_LE: res.u.intval = arg1.u.intval <= arg2.u.intval; return &res; case OP_GT: res.u.intval = arg1.u.intval > arg2.u.intval; return &res; case OP_GE: res.u.intval = arg1.u.intval >= arg2.u.intval; return &res; default: res.u.intval = 0; return &res; } } } /*************************************************************************/ /* Evaluate an expression tree as an integer. */ int etree_eval_i(ETree *expr) { ETree *val = etree_eval(expr); if (!val) return 0; else if (val->type == ET_INT) return val->u.intval; else if (val->type == ET_STRING) return strtol(val->u.stringval, NULL, 0); else return 0; } /*************************************************************************/ /* Evaluate an expression tree as a string. */ char *etree_eval_s(ETree *expr) { static char buf[16]; ETree *val = etree_eval(expr); if (!val) { *buf = 0; return buf; } else if (val->type == ET_STRING) { return val->u.stringval; } else if (val->type == ET_INT) { snprintf(buf, sizeof(buf), "%d", val->u.intval); return buf; } else { *buf = 0; return buf; } } /*************************************************************************/ /*************************************************************************/ /* Print an expression tree. */ static void do_print(ETree *expr, int depth) { int i; /* Skip over parentheses */ while (expr && expr->type == ET_SUBEXPR) expr = expr->u.subexprval.expr; /* Indent for depth */ for (i = 0; i < depth; i++) printf(" "); /* Catch null expressions */ if (!expr) { printf("(nil)\n"); return; } /* Print the thing */ switch (expr->type) { case ET_INT: printf("%d\n", expr->u.intval); break; case ET_STRING: printf("\"%s\"\n", expr->u.stringval); break; case ET_FUNC: printf("func:%d (%d)\n", expr->u.funcval.func, expr->u.funcval.nargs); for (i = 0; i < expr->u.funcval.nargs; i++) do_print(expr->u.funcval.args[i], depth+1); break; case ET_OPERATOR: switch (expr->u.exprval.operator) { case OP_ADD: printf("+"); break; case OP_SUBTRACT: printf("-"); break; case OP_MULTIPLY: printf("*"); break; case OP_DIVIDE: printf("/"); break; case OP_MODULO: printf("%%"); break; case OP_LSHIFT: printf("<<"); break; case OP_RSHIFT: printf(">>"); break; case OP_BAND: printf("^"); break; case OP_BOR: printf("|"); break; case OP_BXOR: printf("^"); break; case OP_BNOT: printf("~"); break; case OP_EQ: printf("=="); break; case OP_NE: printf("!="); break; case OP_LT: printf("<"); break; case OP_LE: printf("<="); break; case OP_GT: printf(">"); break; case OP_GE: printf(">="); break; case OP_LAND: printf("&&"); break; case OP_LOR: printf("||"); break; case OP_LNOT: printf("!"); break; default: printf("op:???"); break; } putchar('\n'); if (expr->u.exprval.operator != OP_LNOT && expr->u.exprval.operator != OP_BNOT) { do_print(expr->u.exprval.arg1, depth+1); } do_print(expr->u.exprval.arg2, depth+1); break; default: printf("???\n"); break; } } /*************************************************************************/ /* External interface to do_print(). */ void etree_print(ETree *expr) { do_print(expr, 0); } /*************************************************************************/