Source code for pyreason.scripts.utils.rule_parser

import numba
import numpy as np

import pyreason.scripts.numba_wrapper.numba_types.rule_type as rule
import pyreason.scripts.numba_wrapper.numba_types.label_type as label
import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval


[docs] def parse_rule(rule_text: str, name: str, custom_thresholds: list, infer_edges: bool = False, set_static: bool = False, immediate_rule: bool = False) -> rule.Rule: # First remove all spaces from line r = rule_text.replace(' ', '') # Separate into head and body head, body = r.split('<-') # Extract delta_t of rule if it exists else set it to 0 t = '' is_digit = True while is_digit: if body[0].isdigit(): t += body[0] body = body[1:] else: is_digit = False if t == '': t = 0 else: t = int(t) # Raw parsing steps # 1. Remove whitespaces # 2. replace ) by )) and ] by ]] so that we can split without damaging the string # 3. Split with ), and then for each element of list, split with ], and add to new list # 4. Then replace ]] with ] and )) with ) in for loop # 5. Add :[1,1] to the end of each element if a bound is not specified # 6. Then split each element with : # 7. Transform bound strings into pr.intervals # 2 body = body.replace(')', '))') body = body.replace(']', ']]') # 3 body = body.split('),') split_body = [] for b in body: split_body.extend(b.split('],')) # 4 for i in range(len(split_body)): split_body[i] = split_body[i].replace('))', ')') split_body[i] = split_body[i].replace(']]', ']') # 5 for i in range(len(split_body)): if split_body[i][-1] != ']': split_body[i] += ':[1,1]' # 6 body_clauses = [] body_bounds = [] for b in split_body: clause, bound = b.split(':') body_clauses.append(clause) body_bounds.append(bound) # 7 for i in range(len(body_bounds)): bound = body_bounds[i] l, u = _str_bound_to_bound(bound) body_bounds[i] = [l, u] # Find the target predicate and bounds and annotation function if any. # Possible heads: # pred(x) : [x,y] # pred(x) : f # pred(x) # This means there is no bound or annotation function specified if head[-1] == ')': head += ':[1,1]' head, head_bound = head.split(':') # Check if we have a bound or annotation function if _is_bound(head_bound): target_bound = list(_str_bound_to_bound(head_bound)) target_bound = interval.closed(*target_bound) ann_fn = '' else: target_bound = interval.closed(0, 1) ann_fn = head_bound idx = head.find('(') target = head[:idx] target = label.Label(target) # Variable(s) in the head of the rule end_idx = head.find(')') head_variables = head[idx + 1:end_idx].split(',') # Assign type of rule rule_type = 'node' if len(head_variables) == 1 else 'edge' # Get the variables in the body # If there's an operator in the body then discard anything that comes after the operator, but keep the variables body_predicates = [] body_variables = [] for clause in body_clauses: start_idx = clause.find('(') end_idx = clause.find(')') body_predicates.append(clause[:start_idx]) # Add body variables depending on whether there's an operator or not variables = clause[start_idx+1:end_idx].split(',') start_idx = clause.find('(', start_idx+1) end_idx = clause.find(')', end_idx+1) if start_idx != -1 and end_idx != -1: variables += clause[start_idx+1:end_idx].split(',') body_variables.append(variables) # Change infer edge parameter if it's a node rule if rule_type == 'node': infer_edges = False # Replace the variables in the body with source/target if they match the variables in the head # If infer_edges is true, then we consider all rules to be node rules, we infer the 2nd variable of the target predicate from the rule body # Else we consider the rule to be an edge rule and replace variables with source/target # Node rules with possibility of adding edges if infer_edges or len(head_variables) == 1: head_source_variable = head_variables[0] for i in range(len(body_variables)): for j in range(len(body_variables[i])): if body_variables[i][j] == head_source_variable: body_variables[i][j] = '__target' # Edge rule, no edges to be added elif len(head_variables) == 2: for i in range(len(body_variables)): for j in range(len(body_variables[i])): if body_variables[i][j] == head_variables[0]: body_variables[i][j] = '__source' elif body_variables[i][j] == head_variables[1]: body_variables[i][j] = '__target' # Start setting up clauses # clauses = [c1, c2, c3, c4] # thresholds = [t1, t2, t3, t4] # Array of thresholds to keep track of for each neighbor criterion. Form [(comparison, (number/percent, total/available), thresh)] thresholds = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, numba.types.UniTuple(numba.types.string, 2), numba.types.float64))) # Array to store clauses for nodes: node/edge, [subset]/[subset1, subset2], label, interval, operator clauses = numba.typed.List.empty_list(numba.types.Tuple((numba.types.string, label.label_type, numba.types.ListType(numba.types.string), interval.interval_type, numba.types.string))) # gather count of clauses for threshold validation num_clauses = len(body_clauses) if custom_thresholds and (len(custom_thresholds) != num_clauses): raise Exception('The length of custom thresholds {} is not equal to number of clauses {}' .format(len(custom_thresholds), num_clauses)) # If no custom thresholds provided, use defaults # otherwise loop through user-defined thresholds and convert to numba compatible format if not custom_thresholds: for _ in range(num_clauses): thresholds.append(('greater_equal', ('number', 'total'), 1.0)) else: for threshold in custom_thresholds: thresholds.append(threshold.to_tuple()) # # Loop though clauses for body_clause, predicate, variables, bounds in zip(body_clauses, body_predicates, body_variables, body_bounds): # Neigh criteria clause_type = 'node' if len(variables) == 1 else 'edge' op = _get_operator_from_clause(body_clause) if op: clause_type = 'comparison' subset = numba.typed.List(variables) l = label.Label(predicate) bnd = interval.closed(bounds[0], bounds[1]) clauses.append((clause_type, l, subset, bnd, op)) # Assert that there are two variables in the head of the rule if we infer edges # Add edges between head variables if necessary if infer_edges: var = '__target' if head_variables[0] == head_variables[1] else head_variables[1] edges = ('__target', var, target) else: edges = ('', '', label.Label('')) weights = np.ones(len(body_predicates), dtype=np.float64) weights = np.append(weights, 0) r = rule.Rule(name, rule_type, target, numba.types.uint16(t), clauses, target_bound, thresholds, ann_fn, weights, edges, set_static, immediate_rule) return r
def _str_bound_to_bound(str_bound): str_bound = str_bound.replace('[', '') str_bound = str_bound.replace(']', '') l, u = str_bound.split(',') return float(l), float(u) def _is_bound(str_bound): str_bound = str_bound.replace('[', '') str_bound = str_bound.replace(']', '') try: l, u = str_bound.split(',') l = l.replace('.', '') u = u.replace('.', '') if l.isdigit() and u.isdigit(): result = True else: result = False except: result = False return result def _get_operator_from_clause(clause): operators = ['<=', '>=', '<', '>', '==', '!='] for op in operators: if op in clause: return op # No operator found in clause return ''