Source code for pyreason.scripts.diffuse

import io
import time
import cProfile
import pstats
import sys
import memory_profiler as mp

import pyreason.scripts.numba_wrapper.numba_types.interval_type as interval
from pyreason.scripts.program.program import Program
import pyreason.scripts.utils.yaml_parser as yaml_parser
from pyreason.scripts.utils.graphml_parser import GraphmlParser
from pyreason.scripts.utils.filter import Filter
from pyreason.scripts.utils.output import Output
from pyreason.scripts.args import argparser



def _merge_specific_labels(target, extra):
    """Merge the ``label -> components`` mapping ``extra`` into ``target`` in place.

    A label present in both mappings keeps the components already stored in
    ``target`` and gets the components from ``extra`` appended to them; a label
    only present in ``extra`` is added to ``target`` unchanged.
    """
    for label_name, components in extra.items():
        if label_name in target:
            target[label_name].extend(components)
        else:
            target[label_name] = components


def main(args):
    """Parse the graph, labels, rules, facts and IPL, run the diffusion and print the results.

    The steps are, in order: read the graph (and optionally its attributes),
    read the label / rule / fact / IPL YAML files, build a ``Program``, call
    ``program.reason``, save the rule trace under ``./output`` and print the
    filtered result for every timestep.

    :param args: Object exposing the run options as attributes. The attributes
        read here are ``output_to_file``, ``output_file_name``, ``graph_path``,
        ``reverse_digraph``, ``graph_attribute_parsing``, ``static_graph_facts``,
        ``timesteps``, ``labels_yaml_path``, ``rules_yaml_path``,
        ``facts_yaml_path``, ``ipl_yaml_path``, ``atom_trace``,
        ``save_graph_attributes_to_trace``, ``canonical``,
        ``inconsistency_check``, ``convergence_threshold``,
        ``convergence_bound_threshold``, ``filter_labels``, ``filter_sort_by``
        and ``descending``.
    :return: None. Results are printed and written to files.
    """
    timestamp = time.strftime('%Y%m%d-%H%M%S')
    if args.output_to_file:
        # NOTE(review): stdout is replaced for the rest of the process and the
        # file is never explicitly closed; anything printed after main()
        # returns also ends up in this file.
        sys.stdout = open(f"./output/{args.output_file_name}_{timestamp}.txt", "w")

    # Initialize parsers
    graphml_parser = GraphmlParser()

    start = time.time()
    graph = graphml_parser.parse_graph(args.graph_path, args.reverse_digraph)
    end = time.time()
    print('Time to read graph:', end-start)

    if args.graph_attribute_parsing:
        start = time.time()
        non_fluent_facts_node, non_fluent_facts_edge, specific_node_labels, specific_edge_labels = graphml_parser.parse_graph_attributes(args.static_graph_facts)
        end = time.time()
        print('Time to read graph attributes:', end-start)
    else:
        non_fluent_facts_node = []
        non_fluent_facts_edge = []

    tmax = args.timesteps

    # Initialize labels
    node_labels, edge_labels, snl, sel = yaml_parser.parse_labels(args.labels_yaml_path)
    if args.graph_attribute_parsing:
        # Combine the labels found in the graph attributes with the ones from
        # the YAML file. A plain dict.update() would replace the graph-derived
        # entries of every label that appears in both sources, so the two
        # mappings are merged label by label instead.
        _merge_specific_labels(specific_node_labels, snl)
        _merge_specific_labels(specific_edge_labels, sel)
    else:
        specific_node_labels = snl
        specific_edge_labels = sel

    # Rules come here
    rules = yaml_parser.parse_rules(args.rules_yaml_path)

    # Facts come here. Add non fluent facts that came from the graph
    facts_node, facts_edge = yaml_parser.parse_facts(args.facts_yaml_path, args.reverse_digraph)
    facts_node += non_fluent_facts_node
    facts_edge += non_fluent_facts_edge

    # Inconsistent predicate list
    ipl = yaml_parser.parse_ipl(args.ipl_yaml_path)

    # Program comes here
    program = Program(graph, facts_node, facts_edge, rules, ipl, args.reverse_digraph, args.atom_trace, args.save_graph_attributes_to_trace, args.canonical, args.inconsistency_check)
    program.available_labels_node = node_labels
    program.available_labels_edge = edge_labels
    program.specific_node_labels = specific_node_labels
    program.specific_edge_labels = specific_edge_labels

    # Reasoning process
    print('Graph loaded successfully, rules, labels, facts and ipl parsed successfully')
    print('Starting diffusion')
    start = time.time()
    interpretation = program.reason(tmax, args.convergence_threshold, args.convergence_bound_threshold)
    end = time.time()
    print('Time to complete diffusion:', end-start)
    print('Finished diffusion')

    # Save the rule trace to a file
    output = Output(timestamp)
    output.save_rule_trace(interpretation, folder='./output')

    # This is how you filter the dataframe to show only nodes that have success in a certain interval
    print('Filtering data...')
    filterer = Filter(interpretation.time)
    filtered_df = filterer.filter_and_sort_nodes(interpretation, labels=args.filter_labels, bound=interval.closed(0, 1), sort_by=args.filter_sort_by, descending=args.descending)

    # You can index into filtered_df to get a particular timestep
    # This is for each timestep
    for t in range(interpretation.time+1):
        print(f'\n TIMESTEP - {t}')
        print(filtered_df[t])
        print()


if __name__ == "__main__":
    # Build the run configuration once; every mode below hands it to main().
    args = argparser()

    if args.profile:
        # Run main() under cProfile, then write the statistics (sorted by the
        # 'tottime' key) to a text file in ./profiling.
        prof = cProfile.Profile()
        prof.enable()
        main(args)
        prof.disable()

        report_buffer = io.StringIO()
        pstats.Stats(prof, stream=report_buffer).sort_stats('tottime').print_stats()
        with open('./profiling/' + args.profile_output + '.txt', 'w+') as report_file:
            report_file.write(report_buffer.getvalue())
    elif args.memory_profile:
        # Sample memory usage before the run and while main() executes, and
        # report the difference between the two readings.
        baseline_mem = mp.memory_usage(max_usage=True)
        peak_mem = mp.memory_usage((main, [args]), max_usage=True)
        print(f"\nProgram used {peak_mem-baseline_mem} MB of memory")
    else:
        # Plain run without any profiling.
        main(args)