"""This module implements the object model for EMMAA model testing."""
import logging
import itertools
import jsonpickle
import os
import sys
from collections import defaultdict
from fnvhash import fnv1a_32
from urllib import parse
from copy import deepcopy
from indra.explanation.model_checker import PysbModelChecker, \
PybelModelChecker, SignedGraphModelChecker, UnsignedGraphModelChecker
from indra.explanation.reporting import stmts_from_pysb_path, \
stmts_from_pybel_path, stmts_from_indranet_path, PybelEdge, \
pybel_edge_to_english, RefEdge
from indra.explanation.pathfinding import bfs_search_multiple_nodes
from indra.assemblers.english.assembler import EnglishAssembler
from indra.statements import Statement, Agent, stmts_to_json
from indra.util.statement_presentation import group_and_sort_statements, \
make_string_from_relation_key
from indra.ontology.bio import bio_ontology
from bioagents.tra.tra import TRA, MissingMonomerError, MissingMonomerSiteError
from emmaa.model import EmmaaModel, get_assembled_statements, \
load_config_from_s3
from emmaa.statements import filter_indra_stmts_by_metadata
from emmaa.queries import PathProperty, DynamicProperty, OpenSearchQuery, \
SimpleInterventionProperty
from emmaa.util import make_date_str, get_s3_client, \
EMMAA_BUCKET_NAME, find_latest_s3_file, load_pickle_from_s3, \
save_pickle_to_s3, load_json_from_s3, save_json_to_s3, strip_out_date, \
save_gzip_json_to_s3
from emmaa.filter_functions import node_filter_functions, edge_filter_functions
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Raise the interpreter's recursion limit for the whole process.
# NOTE(review): presumably needed to pickle/deepcopy deeply nested model
# objects -- confirm the actual motivation.
sys.setrecursionlimit(50000)
# URL of the documentation page describing the result codes listed below.
result_codes_link = ('https://emmaa.readthedocs.io/en/latest/dashboard/'
                     'response_codes.html')
# Human-readable descriptions keyed by result code name. They are looked up
# from a result's `result_code` and are also used as failure reasons for
# queries that cannot be answered.
RESULT_CODES = {
    'STATEMENT_TYPE_NOT_HANDLED': 'Statement type not handled',
    'SUBJECT_MONOMERS_NOT_FOUND': 'Statement subject not in model',
    'SUBJECT_NOT_FOUND': 'Statement subject not in model',
    'OBSERVABLES_NOT_FOUND': 'Statement object state not in model',
    'OBJECT_NOT_FOUND': 'Statement object state not in model',
    'NO_PATHS_FOUND': 'No path found that satisfies the test statement',
    'MAX_PATH_LENGTH_EXCEEDED': 'Path found but exceeds search depth',
    'PATHS_FOUND': 'Path found which satisfies the test statement',
    'INPUT_RULES_NOT_FOUND': 'No rules with test statement subject',
    'MAX_PATHS_ZERO': 'Path found but not reconstructed',
    'QUERY_NOT_APPLICABLE': 'Query is not applicable for this model',
    'NODE_NOT_FOUND': 'Node not in model'
}
# Symbols placed between agent names when an edge of the given statement type
# is rendered as text; statement types not listed here get a right arrow
# (U+2192) instead.
ARROW_DICT = {'Complex': u"\u2194",
              'Inhibition': u"\u22A3",
              'DecreaseAmount': u"\u22A3"}
# This mapping configures the use of different model types
# path: regular tests and path-based queries will be run against these models
# simulation: priority list for running simulation based queries (only first
# available model type will be used)
MODEL_TYPES = {'path': ['pysb', 'pybel', 'signed_graph', 'unsigned_graph'],
               'simulation': ['dynamic', 'pysb']}
[docs]class ModelManager(object):
"""Manager to generate and store properties of a model and relevant tests.
Parameters
----------
model : emmaa.model.EmmaaModel
EMMAA model
mode : str
If 'local' (default), does not save any exports/images to S3. It is
only set to 's3' mode in update_model_manager.py script.
Attributes
----------
mc_mapping : dict
A dictionary mapping a ModelChecker type to a corresponding method
for assembling the model and a ModelChecker class.
mc_types : dict
A dictionary in which each key is a type of a ModelChecker and value is
a dictionary containing an instance of a model, an instance of a
ModelChecker and a list of test results.
entities : list[indra.statements.agent.Agent]
A list of entities of EMMAA model.
applicable_tests : list[emmaa.model_tests.EmmaaTest]
A list of EMMAA tests applicable for given EMMAA model.
date_str : str
Time when this object was created.
path_stmt_counts : dict
A dictionary mapping statement hashes to a count of paths they are in.
"""
def __init__(self, model, mode='local'):
    """Assemble every configured model type and set up empty test state.

    Parameters
    ----------
    model : emmaa.model.EmmaaModel
        EMMAA model to manage.
    mode : str
        Passed on to the model assembly methods and stored on the manager.
    """
    self.model = model
    self.mode = mode
    # Each value is a tuple of (assembly method, ModelChecker class,
    # function turning a path into statements/edges for reporting).
    self.mc_mapping = {
        'pysb': (self.model.assemble_pysb, PysbModelChecker,
                 stmts_from_pysb_path),
        'pybel': (self.model.assemble_pybel, PybelModelChecker,
                  stmts_from_pybel_path),
        'signed_graph': (self.model.assemble_signed_graph,
                         SignedGraphModelChecker,
                         stmts_from_indranet_path),
        'unsigned_graph': (self.model.assemble_unsigned_graph,
                           UnsignedGraphModelChecker,
                           stmts_from_indranet_path),
        'dynamic': (self.model.assemble_dynamic_pysb, None, None)}
    self.mc_types = {}
    configured_types = model.test_config.get('mc_types', ['pysb'])
    for mc_type in configured_types:
        type_entry = {}
        self.mc_types[mc_type] = type_entry
        assemble = self.mc_mapping[mc_type][0]
        assembled_model = assemble(mode=mode)
        type_entry['model'] = assembled_model
        # Only path-based model types get a ModelChecker and result list
        if mc_type in MODEL_TYPES['path']:
            checker_class = self.mc_mapping[mc_type][1]
            type_entry['model_checker'] = checker_class(assembled_model)
            type_entry['test_results'] = []
    self.entities = self.model.get_assembled_entities()
    self.applicable_tests = []
    self.date_str = self.model.date_str
    self.path_stmt_counts = defaultdict(int)
@classmethod
def load_from_statements(cls, model_name, mode='local', date=None,
                         bucket=EMMAA_BUCKET_NAME):
    """Create a ModelManager from assembled statements stored on S3.

    Parameters
    ----------
    model_name : str
        Name of the model whose config, paper ids and assembled statements
        are loaded.
    mode : str
        Mode passed to the ModelManager constructor.
    date : str or None
        If given, restricts the paper ids file prefix to this date and is
        passed on when loading the assembled statements.
    bucket : str
        S3 bucket to load from.

    Returns
    -------
    ModelManager
        A manager built around the loaded model.
    """
    config = load_config_from_s3(model_name, bucket=bucket)
    prefix = f'papers/{model_name}/paper_ids_'
    if date:
        prefix = f'papers/{model_name}/paper_ids_{date}'
    paper_key = find_latest_s3_file(bucket, prefix, 'json')
    paper_ids = load_json_from_s3(bucket, paper_key) if paper_key else None
    model = EmmaaModel(model_name, config, paper_ids)
    # Loading assembled statements to avoid reassembly
    assembled, fname = get_assembled_statements(model_name, date, bucket)
    model.assembled_stmts = assembled
    model.date_str = strip_out_date(fname, 'datetime')
    return cls(model, mode=mode)
[docs] def get_updated_mc(self, mc_type, stmts, add_ns=False,
edge_filter_func=None):
"""Update the ModelChecker and graph with stmts for tests/queries."""
mc = self.mc_types[mc_type]['model_checker']
mc.statements = stmts
if mc_type == 'pysb':
mc.graph = None
mc.model_stmts = self.model.assembled_stmts
mc.get_graph(prune_im=True, prune_im_degrade=True,
add_namespaces=add_ns,
edge_filter_func=edge_filter_func)
else:
mc.graph = None
mc.get_graph(edge_filter_func=edge_filter_func)
if mc_type in ('signed_graph', 'unsigned_graph'):
mc.nodes_to_agents = {ag.name: ag for ag in self.entities}
return mc
[docs] def add_test(self, test):
"""Add a test to a list of applicable tests."""
self.applicable_tests.append(test)
[docs] def add_result(self, mc_type, result):
"""Add a result to a list of results."""
self.mc_types[mc_type]['test_results'].append(result)
def run_all_tests(self, filter_func=None, edge_filter_func=None):
    """Run the applicable tests against every path-based model type.

    The search limits are read once from the test configuration and reused
    for each model type; model types not listed under
    ``MODEL_TYPES['path']`` are ignored.
    """
    max_path_length, max_paths = self._get_test_configs()
    for mc_type in self.mc_types:
        if mc_type in MODEL_TYPES['path']:
            self.run_tests_per_mc(mc_type, max_path_length, max_paths,
                                  filter_func, edge_filter_func)
def run_tests_per_mc(self, mc_type, max_path_length, max_paths,
                     filter_func=None, edge_filter_func=None):
    """Check all applicable test statements with a single ModelChecker.

    Every result returned by the checker is stored, in order, in the test
    results of `mc_type`.
    """
    test_stmts = [test.stmt for test in self.applicable_tests]
    checker = self.get_updated_mc(
        mc_type, test_stmts, edge_filter_func=edge_filter_func)
    logger.info(f'Running the tests with {mc_type} ModelChecker.')
    if filter_func:
        logger.info(f'Applying {filter_func.__name__}')
    checked = checker.check_model(
        max_path_length=max_path_length, max_paths=max_paths,
        agent_filter_func=filter_func, edge_filter_func=edge_filter_func)
    for _, result in checked:
        self.add_result(mc_type, result)
def make_path_json(self, mc_type, result_paths):
    """Convert result paths into display JSON and one-line path records.

    Parameters
    ----------
    mc_type : str
        Model type the paths were found with; selects the reporting function
        from `self.mc_mapping` and the model from `self.mc_types`.
    result_paths : list
        Paths to convert.

    Returns
    -------
    paths : list[dict]
        One dict per path with a 'path' string and an 'edge_list' of dicts,
        each holding the 'edge' string and its 'stmts' sentences.
    json_lines : list[dict]
        One dict per path with 'nodes', 'edges' and 'graph_type' keys.

    Notes
    -----
    As a side effect, `self.path_stmt_counts` is incremented for the hash of
    every statement encountered on a path.
    """
    paths = []
    json_lines = []
    for path in result_paths:
        path_nodes = []
        edge_list = []
        path_node_list = []
        hashes = []
        report_function = self.mc_mapping[mc_type][2]
        model = self.mc_types[mc_type]['model']
        stmts = self.model.assembled_stmts
        # The reporting functions take different positional arguments per
        # model type. NOTE(review): the meaning of the positional booleans is
        # defined by the INDRA reporting functions, not visible here.
        # `path_stmts` and `merge` are only assigned for these four types.
        if mc_type == 'pysb':
            report_stmts = report_function(path, model, stmts)
            # Wrap each statement so every step is a list, like other types
            path_stmts = [[st] for st in report_stmts]
            merge = False
        elif mc_type == 'pybel':
            path_stmts = report_function(path, model, False, stmts)
            merge = False
        elif mc_type == 'signed_graph':
            path_stmts = report_function(path, model, True, False, stmts)
            merge = True
        elif mc_type == 'unsigned_graph':
            path_stmts = report_function(path, model, False, False, stmts)
            merge = True
        for i, step in enumerate(path_stmts):
            edge_nodes = []
            # Steps without any statements/edges are not reported
            if len(step) < 1:
                continue
            stmt_type = type(step[0]).__name__
            # Skip reporting has component edges
            if stmt_type == 'PybelEdge' and step[0].relation == 'partOf' \
                    and step[0].reverse:
                continue
            elif stmt_type in ('PybelEdge', 'RefEdge'):
                # Non-statement edges are shown as source -> target and are
                # recorded by type only (they have no statement hashes)
                source, target = step[0].source, step[0].target
                edge_nodes.append(source.name)
                edge_nodes.append(u"\u2192")
                edge_nodes.append(target.name)
                hashes.append({'type': stmt_type})
            else:
                step_hashes = []
                for stmt in step:
                    self.path_stmt_counts[stmt.get_hash()] += 1
                    step_hashes.append(stmt.get_hash())
                hashes.append({'type': 'statements',
                               'hashes': step_hashes})
                # Agent names are taken from the first statement of the step
                agents = [ag.name if ag is not None else None
                          for ag in step[0].agent_list()]
                # For complexes make sure that the agent from the
                # previous edge goes first
                if stmt_type == 'Complex' and len(path_nodes) > 0:
                    agents = sorted(
                        [ag for ag in agents if ag is not None],
                        key=lambda x: x != path_nodes[-1])
                # Interleave agent names with an arrow symbol; no arrow is
                # added after the last agent
                for j, ag in enumerate(agents):
                    if ag is not None:
                        edge_nodes.append(ag)
                    if j == (len(agents) - 1):
                        break
                    if stmt_type in ARROW_DICT:
                        edge_nodes.append(ARROW_DICT[stmt_type])
                    else:
                        edge_nodes.append(u"\u2192")
            # The step at index 0 contributes all of its nodes; later steps
            # skip their first node, which repeats the previous step's last
            # node. Note that this checks the enumeration index, so a skipped
            # step 0 never takes this branch.
            if i == 0:
                for n in edge_nodes:
                    path_nodes.append(n)
                path_node_list.append(edge_nodes[0])
                path_node_list.append(edge_nodes[-1])
            else:
                for n in edge_nodes[1:]:
                    path_nodes.append(n)
                path_node_list.append(edge_nodes[-1])
            step_sentences = self._make_path_stmts(step, merge=merge)
            edge_dict = {'edge': ' '.join(edge_nodes),
                         'stmts': step_sentences}
            edge_list.append(edge_dict)
        path_json = {'path': ' '.join(path_nodes),
                     'edge_list': edge_list}
        one_line_path_json = {'nodes': path_node_list, 'edges': hashes,
                              'graph_type': mc_type}
        paths.append(path_json)
        json_lines.append(one_line_path_json)
    return paths, json_lines
def _make_path_stmts(self, stmts, merge=False):
    """Return (link, sentence, '') tuples describing one step of a path.

    When `merge` is set and the step holds INDRA Statements, the statements
    are grouped by relation and one sentence with a shared evidence link is
    produced per group. Otherwise each item gets its own sentence; PybelEdge
    and RefEdge items have an empty link.
    """
    date = strip_out_date(self.date_str, 'date')

    def evidence_link(stmt_hashes):
        # Build the evidence page link for a list of statement hashes
        url_param = parse.urlencode(
            {'stmt_hash': stmt_hashes, 'source': 'model_statement',
             'model': self.model.name, 'date': date}, doseq=True)
        return f'/evidence?{url_param}'

    sentences = []
    if merge and isinstance(stmts[0], Statement):
        groups = group_and_sort_statements(stmts, grouping_level='relation')
        for _, rel_key, group_stmts, _ in groups:
            sentence = make_string_from_relation_key(rel_key) + '.'
            group_hashes = [gr_st.get_hash()
                            for _, _, gr_st, _ in group_stmts]
            sentences.append((evidence_link(group_hashes), sentence, ''))
        return sentences
    for stmt in stmts:
        if isinstance(stmt, PybelEdge):
            sentences.append(('', pybel_edge_to_english(stmt), ''))
        elif isinstance(stmt, RefEdge):
            sentences.append(('', stmt.to_english(), ''))
        else:
            sentence = EnglishAssembler([stmt]).make_model()
            link = evidence_link([stmt.get_hash()])
            sentences.append((link, sentence, ''))
    return sentences
def make_result_code(self, result):
    """Return the human-readable description of a result's result code."""
    return RESULT_CODES[result.result_code]
def answer_query(self, query, **kwargs):
    """Answer a single query with the handler matching its type.

    Parameters
    ----------
    query : emmaa.queries.Query
        The query to answer.
    **kwargs
        Extra keyword arguments (e.g. `bucket`) forwarded to the simulation
        based handlers (dynamic and intervention queries), matching what
        `answer_queries` does. They are ignored for path and open search
        queries.

    Returns
    -------
    list or None
        The handler's list of result tuples, or None if the query is of
        none of the supported types.
    """
    if isinstance(query, DynamicProperty):
        return self.answer_dynamic_query(query, **kwargs)
    if isinstance(query, PathProperty):
        return self.answer_path_query(query)
    if isinstance(query, OpenSearchQuery):
        return self.answer_open_query(query)
    if isinstance(query, SimpleInterventionProperty):
        # Forward kwargs here too so that e.g. a custom bucket is honored
        return self.answer_intervention_query(query, **kwargs)
    return None
def answer_path_query(self, query):
    """Answer a path query with every path-based model type.

    Returns
    -------
    list[tuple]
        One (mc_type, hashed response, path lines) tuple per path-based
        model type, or a single tuple with an empty mc_type and a
        "not applicable" message if the query entities are not in the model.
    """
    if not ScopeTestConnector.applicable(self, query):
        return [('', self.hash_response_list(
            RESULT_CODES['QUERY_NOT_APPLICABLE']),
            [{'fail_reason': RESULT_CODES['QUERY_NOT_APPLICABLE']}])]
    answers = []
    for mc_type in self.mc_types:
        if mc_type in MODEL_TYPES['path']:
            checker = self.get_updated_mc(mc_type, [query.path_stmt])
            max_path_length, max_paths = self._get_test_configs(
                mode='query', mc_type=mc_type, default_paths=5)
            path_result = checker.check_statement(
                query.path_stmt, max_paths, max_path_length)
            hashed_res, path_lines = self.process_response(
                mc_type, path_result)
            answers.append((mc_type, hashed_res, path_lines))
    return answers
def answer_dynamic_query(self, query, bucket=EMMAA_BUCKET_NAME):
    """Answer a dynamic query by checking a temporal pattern in simulation.

    In 's3' mode the figure produced by the check is uploaded to `bucket`
    and its S3 URL is reported instead of the local path.

    Returns
    -------
    list[tuple]
        A single ('pysb', hashed response, response json) tuple. If the
        model lacks a required monomer or site, the response json holds a
        'fail_reason' instead.
    """
    components = self._get_dynamic_components('dynamic')
    pysb_model, use_kappa, time_limit, num_times, num_sim, hyp_tester = \
        components
    tra = TRA(use_kappa=use_kappa)
    pattern = query.get_temporal_pattern(time_limit)
    try:
        sat_rate, num_sim, kpat, pat_obj, fig_path = tra.check_property(
            pysb_model, pattern, num_times=num_times, num_sim=num_sim,
            hypothesis_tester=hyp_tester)
        if self.mode == 's3':
            fig_name, ext = os.path.splitext(os.path.basename(fig_path))
            date_str = make_date_str()
            s3_key = (f'query_images/{self.model.name}/{fig_name}_'
                      f'{date_str}{ext}')
            s3_path = f'https://{bucket}.s3.amazonaws.com/{s3_key}'
            s3_client = get_s3_client(unsigned=False)
            logger.info(f'Uploading image to {s3_path}')
            s3_client.upload_file(fig_path, Bucket=bucket, Key=s3_key)
            fig_path = s3_path
        resp_json = {'sat_rate': sat_rate, 'num_sim': num_sim,
                     'kpat': kpat, 'fig_path': fig_path}
        hashed = self.hash_response_list(resp_json)
        return [('pysb', hashed, resp_json)]
    except (MissingMonomerError, MissingMonomerSiteError):
        not_applicable = RESULT_CODES['QUERY_NOT_APPLICABLE']
        return [('pysb', self.hash_response_list(not_applicable),
                 {'fail_reason': RESULT_CODES['QUERY_NOT_APPLICABLE']})]
def answer_intervention_query(self, query, bucket=EMMAA_BUCKET_NAME):
    """Answer an intervention query by comparing simulated conditions.

    In 's3' mode the figure produced by the comparison is uploaded to
    `bucket` and its S3 URL is reported instead of the local path.

    Returns
    -------
    list[tuple]
        A single ('pysb', hashed response, response json) tuple. If the
        model lacks a required monomer or site, the response json holds a
        'fail_reason' instead.
    """
    components = self._get_dynamic_components('intervention')
    pysb_model, use_kappa, time_limit, num_times, num_sim, _ = components
    tra = TRA(use_kappa=use_kappa)
    try:
        res, fig_path = tra.compare_conditions(
            pysb_model, query.condition_entity, query.target_entity,
            query.direction, time_limit, num_times)
        if self.mode == 's3':
            fig_name, ext = os.path.splitext(os.path.basename(fig_path))
            date_str = make_date_str()
            s3_key = (f'query_images/{self.model.name}/{fig_name}_'
                      f'{date_str}{ext}')
            s3_path = f'https://{bucket}.s3.amazonaws.com/{s3_key}'
            s3_client = get_s3_client(unsigned=False)
            logger.info(f'Uploading image to {s3_path}')
            s3_client.upload_file(fig_path, Bucket=bucket, Key=s3_key)
            fig_path = s3_path
        resp_json = {'result': res, 'fig_path': fig_path}
        hashed = self.hash_response_list(resp_json)
        return [('pysb', hashed, resp_json)]
    except (MissingMonomerError, MissingMonomerSiteError):
        not_applicable = RESULT_CODES['QUERY_NOT_APPLICABLE']
        return [('pysb', self.hash_response_list(not_applicable),
                 {'fail_reason': RESULT_CODES['QUERY_NOT_APPLICABLE']})]
def answer_open_query(self, query):
    """Answer an open search query with every path-based model type.

    Returns
    -------
    list[tuple]
        One (mc_type, hashed response, path lines) tuple per path-based
        model type, or a single tuple with an empty mc_type and a
        "not applicable" message if the query entities are not in the model.
    """
    if not ScopeTestConnector.applicable(self, query):
        return [('', self.hash_response_list(
            RESULT_CODES['QUERY_NOT_APPLICABLE']),
            [{'fail_reason': RESULT_CODES['QUERY_NOT_APPLICABLE']}])]
    answers = []
    for mc_type in self.mc_types:
        if mc_type in MODEL_TYPES['path']:
            max_path_length, max_paths = self._get_test_configs(
                mode='query', qtype='open_search', mc_type=mc_type,
                default_paths=50, default_length=2)
            # Namespaces are only added to the graph if the query
            # restricts the terminal namespaces
            add_ns = True if query.terminal_ns else False
            checker = self.get_updated_mc(
                mc_type, [query.path_stmt], add_ns)
            res, paths = self.open_query_per_mc(
                mc_type, checker, query, max_path_length, max_paths)
            answers.append((mc_type, res, paths))
    return answers
def open_query_per_mc(self, mc_type, mc, query, max_path_length,
                      max_paths):
    """Run an open search for one query on one ModelChecker's graph.

    The search starts from the subject nodes (downstream search) or from the
    object nodes (upstream search, in which case the found paths are
    reversed before reporting), depending on the query's entity role.

    Returns
    -------
    tuple
        (hashed response, path lines), or a hashed failure message with a
        'fail_reason' entry if the statement could not be processed.
    """
    graph = mc.get_graph()
    subj_nodes, obj_nodes, res_code = mc.process_statement(query.path_stmt)
    if res_code:
        return self.hash_response_list(RESULT_CODES[res_code]), \
            [{'fail_reason': RESULT_CODES[res_code]}]
    reverse = query.entity_role != 'subject'
    if reverse:
        assert obj_nodes
        start_nodes = obj_nodes.all_nodes
    else:
        assert subj_nodes.all_nodes
        start_nodes = subj_nodes.all_nodes
    sign = query.get_sign(mc_type)
    # Terminal namespaces are not used with the PySB graph
    terminal_ns = None if mc_type == 'pysb' else query.terminal_ns
    found = bfs_search_multiple_nodes(
        graph, start_nodes, reverse=reverse, terminal_ns=terminal_ns,
        depth_limit=max_path_length, path_limit=max_paths, sign=sign)
    paths = [p[::-1] if reverse else p for p in found]
    return self.process_open_query_response(mc_type, paths)
def answer_queries(self, queries, **kwargs):
    """Answer all queries registered for this model.

    Parameters
    ----------
    queries : list[emmaa.queries.Query]
        A list of queries to run.
    **kwargs
        Extra keyword arguments forwarded to the simulation based handlers
        (dynamic and intervention queries).

    Returns
    -------
    responses : list[tuple]
        A list of (query, mc_type, result json) tuples sorted by the
        queries' matches key. Queries that are not applicable to this model
        get an empty mc_type.
    """
    responses = []
    applicable_queries = []
    applicable_stmts = []
    applicable_open_queries = []
    applicable_open_stmts = []
    for query in queries:
        # Dynamic queries need to be answered individually, while for
        # path and open queries some parts can be shared.
        # The branches form a single exclusive chain so that each query is
        # handled by exactly one of them.
        if isinstance(query, DynamicProperty):
            mc_type, response, _ = self.answer_dynamic_query(
                query, **kwargs)[0]
            responses.append((query, mc_type, response))
        elif isinstance(query, SimpleInterventionProperty):
            mc_type, response, _ = self.answer_intervention_query(
                query, **kwargs)[0]
            responses.append((query, mc_type, response))
        elif isinstance(query, PathProperty):
            if ScopeTestConnector.applicable(self, query):
                applicable_queries.append(query)
                applicable_stmts.append(query.path_stmt)
            else:
                responses.append(
                    (query, '', self.hash_response_list(
                        RESULT_CODES['QUERY_NOT_APPLICABLE'])))
        elif isinstance(query, OpenSearchQuery):
            if ScopeTestConnector.applicable(self, query):
                applicable_open_queries.append(query)
                applicable_open_stmts.append(query.path_stmt)
            else:
                responses.append(
                    (query, '', self.hash_response_list(
                        RESULT_CODES['QUERY_NOT_APPLICABLE'])))
    # Only do the following steps if there are applicable queries
    # Path queries: all statements are checked together per model type
    if applicable_queries:
        for mc_type in self.mc_types:
            if mc_type not in MODEL_TYPES['path']:
                continue
            mc = self.get_updated_mc(mc_type, applicable_stmts)
            max_path_length, max_paths = self._get_test_configs(
                mode='query', mc_type=mc_type, default_paths=5)
            results = mc.check_model(
                max_path_length=max_path_length, max_paths=max_paths)
            # Results come back in the order of the checked statements
            for ix, (_, result) in enumerate(results):
                resp, _ = self.process_response(mc_type, result)
                responses.append(
                    (applicable_queries[ix], mc_type, resp))
    # Open queries: the graph is shared, the search is run per query
    if applicable_open_queries:
        for mc_type in self.mc_types:
            if mc_type not in MODEL_TYPES['path']:
                continue
            max_path_length, max_paths = self._get_test_configs(
                mode='query', qtype='open_search', mc_type=mc_type,
                default_paths=50, default_length=2)
            mc = self.get_updated_mc(mc_type, applicable_open_stmts, True)
            for query in applicable_open_queries:
                res, _ = self.open_query_per_mc(
                    mc_type, mc, query, max_path_length, max_paths)
                responses.append((query, mc_type, res))
    return sorted(responses, key=lambda x: x[0].matches_key())
def _get_test_configs(self, mode='test', qtype='statement_checking',
mc_type=None, default_length=5, default_paths=1):
if mode == 'test':
config = self.model.test_config
elif mode == 'query':
config = self.model.query_config
try:
max_path_length = \
config[qtype][mc_type]['max_path_length']
except KeyError:
try:
max_path_length = \
config[qtype]['max_path_length']
except KeyError:
max_path_length = default_length
try:
max_paths = \
config[qtype][mc_type]['max_paths']
except KeyError:
try:
max_paths = \
config[qtype]['max_paths']
except KeyError:
max_paths = default_paths
logger.info('Parameters for model checking: %d, %d' %
(max_path_length, max_paths))
return (max_path_length, max_paths)
def _get_dynamic_components(self, qtype='dynamic'):
    """Collect the model and settings needed to run a simulation query.

    Parameters
    ----------
    qtype : str
        Key of the section of the model's query config to read the
        simulation settings from. If the section is missing, defaults are
        used (ODE simulation, no time limit, 100 time points, 2 fixed
        simulations, no hypothesis tester).

    Returns
    -------
    tuple
        (pysb_model, use_kappa, time_limit, num_times, num_sim, hyp_tester)
        where pysb_model is a deep copy of the first available model type
        listed in ``MODEL_TYPES['simulation']``.

    Raises
    ------
    ValueError
        If none of the simulation model types is available for this model.
    """
    # Get simulation mode (kappa or ODE) from query config
    use_kappa = False
    time_limit = None
    num_times = 100
    # Without any specification we fall back on 2 fixed simulations
    num_sim = 2
    hyp_tester = None
    if qtype in self.model.query_config:
        qc = self.model.query_config[qtype]
        use_kappa = qc.get('use_kappa', False)
        time_limit = qc.get('time_limit')
        num_times = qc.get('num_times', 100)
        # If we have a fixed number of simulations, we use that
        if 'num_sim' in qc:
            num_sim = qc['num_sim']
        # If we have parameters for a hypothesis tester, we use that
        elif 'hypothesis_tester' in qc:
            from bioagents.tra.model_checker import HypothesisTester
            num_sim = 0
            hyp_tester = HypothesisTester(**qc['hypothesis_tester'])
    # Either use specially assembled or regular PySB depending on model
    for mc_type in MODEL_TYPES['simulation']:
        if mc_type in self.mc_types:
            logger.info(f'Using {mc_type} model for simulation')
            # Copy so that the stored model is not affected by simulation
            pysb_model = deepcopy(self.mc_types[mc_type]['model'])
            break
    else:
        # Previously this surfaced as an UnboundLocalError on return
        raise ValueError(
            'No model available for simulation, expected one of '
            f'{MODEL_TYPES["simulation"]} among the model types.')
    return pysb_model, use_kappa, time_limit, num_times, num_sim, hyp_tester
[docs] def process_response(self, mc_type, result):
"""Return a dictionary in which every key is a hash and value is a list
of tuples. Each tuple contains a sentence describing either a step in a
path (if it was found) or result code (if a path was not found) and a
link leading to a webpage with more information about corresponding
sentence.
"""
if result.paths:
response, path_lines = self.make_path_json(mc_type, result.paths)
return self.hash_response_list(response), path_lines
else:
response = self.make_result_code(result)
return self.hash_response_list(response), [{'fail_reason': response}]
def process_open_query_response(self, mc_type, paths):
if paths:
response, path_lines = self.make_path_json(mc_type, paths)
return self.hash_response_list(response), path_lines
else:
response = 'No paths found that satisfy this query'
return self.hash_response_list(response), [{'fail_reason': response}]
[docs] def hash_response_list(self, response):
"""Return a dictionary mapping a hash with a response in a response
list.
"""
response_dict = {}
if isinstance(response, str):
response_hash = str(fnv1a_32(response.encode('utf-8')))
response_dict[response_hash] = response
elif isinstance(response, list):
for resp in response:
sentences = []
for edge in resp['edge_list']:
for (_, sentence, _) in edge['stmts']:
sentences.append(sentence)
response_str = ' '.join(sentences)
response_hash = str(fnv1a_32(response_str.encode('utf-8')))
response_dict[response_hash] = resp
elif isinstance(response, dict):
if 'sat_rate' in response:
results = [str(response.get('sat_rate')),
str(response.get('num_sim'))]
response_str = ' '.join(results)
else:
response_str = response.get('result')
response_hash = str(fnv1a_32(response_str.encode('utf-8')))
response_dict[response_hash] = response
else:
raise TypeError('Response should be a string or a list.')
return response_dict
def results_to_json(self, test_data=None):
    """Serialize the stored test results.

    Returns
    -------
    results_json : list[dict]
        The first element describes the model and run (name, path-based
        model types, path statement counts, date, test data); it is followed
        by one dict per applicable test holding the test JSON and, per
        path-based model type, the flattened result, path JSON and result
        code description.
    json_lines : list[dict]
        One-line path records of all tests, each tagged with the hash of the
        test statement under 'test'.
    """
    pickler = jsonpickle.pickler.Pickler()
    path_mc_types = [mc_type for mc_type in self.mc_types
                     if mc_type in MODEL_TYPES['path']]
    # The counts dict is referenced (not copied), so counts added while the
    # paths are processed below are reflected in the header.
    header = {
        'model_name': self.model.name,
        'mc_types': path_mc_types,
        'path_stmt_counts': self.path_stmt_counts,
        'date_str': self.date_str,
        'test_data': test_data}
    results_json = [header]
    json_lines = []
    for ix, test in enumerate(self.applicable_tests):
        test_entry = {'test_type': test.__class__.__name__,
                      'test_json': test.to_json()}
        for mc_type in self.mc_types:
            if mc_type in MODEL_TYPES['path']:
                result = self.mc_types[mc_type]['test_results'][ix]
                path_json, path_lines = self.make_path_json(
                    mc_type, result.paths)
                test_entry[mc_type] = {
                    'result_json': pickler.flatten(result),
                    'path_json': path_json,
                    'result_code': self.make_result_code(result)}
                # Only include lines with paths
                for line in path_lines:
                    if line:
                        line['test'] = test.stmt.get_hash()
                        json_lines.append(line)
        results_json.append(test_entry)
    return results_json, json_lines
def upload_results(self, test_corpus='large_corpus_tests',
                   test_data=None, bucket=EMMAA_BUCKET_NAME):
    """Upload test results and paths to the S3 bucket.

    The results are saved under a dated key; the paths are saved in jsonl
    format under both a dated key and a "latest" key.
    """
    json_dict, json_lines = self.results_to_json(test_data)
    model_name = self.model.name
    result_key = (f'results/{model_name}/results_'
                  f'{test_corpus}_{self.date_str}.json')
    paths_key = (f'paths/{model_name}/paths_{test_corpus}_'
                 f'{self.date_str}.jsonl')
    latest_paths_key = (f'paths/{model_name}/{test_corpus}'
                        '_latest_paths.jsonl')
    logger.info(f'Uploading test results to {result_key}')
    save_json_to_s3(json_dict, bucket, result_key)
    logger.info(f'Uploading test paths to {paths_key}')
    for key in (paths_key, latest_paths_key):
        save_json_to_s3(json_lines, bucket, key, 'jsonl')
def save_assembled_statements(self, bucket=EMMAA_BUCKET_NAME):
    """Upload the assembled statements as JSON files to the S3 bucket.

    The regular assembled statements are always saved; the dynamically
    assembled ones are saved under '<model name>_dynamic' if the model has
    any.
    """
    def save_stmts(stmts, model_name):
        stmts_json = stmts_to_json(stmts)
        # Save a timestamped version and a generic latest version of files
        latest_key = f'assembled/{model_name}/' \
                     f'latest_statements_{model_name}'
        dated_key = f'assembled/{model_name}/statements_{self.date_str}'
        for ext in ('json', 'jsonl'):
            latest_obj_key = latest_key + '.' + ext
            logger.info('Uploading assembled statements to '
                        f'{latest_obj_key}')
            save_json_to_s3(stmts_json, bucket, latest_obj_key, ext)
        dated_jsonl = dated_key + '.jsonl'
        logger.info(f'Uploading assembled statements to {dated_jsonl}')
        save_json_to_s3(stmts_json, bucket, dated_jsonl, 'jsonl')
        dated_zip = dated_key + '.gz'
        logger.info(f'Uploading assembled statements to {dated_zip}')
        save_gzip_json_to_s3(stmts_json, bucket, dated_zip, 'json')

    save_stmts(self.model.assembled_stmts, self.model.name)
    dynamic_stmts = getattr(self.model, 'dynamic_assembled_stmts', None)
    if dynamic_stmts:
        save_stmts(dynamic_stmts, f'{self.model.name}_dynamic')
class TestManager(object):
    """Manager to generate and run a set of tests on a set of models.

    Parameters
    ----------
    model_managers : list[emmaa.model_tests.ModelManager]
        A list of ModelManager objects
    tests : list[emmaa.model_tests.EmmaaTest]
        A list of EMMAA tests
    """
    def __init__(self, model_managers, tests):
        self.model_managers = model_managers
        self.tests = tests

    def make_tests(self, test_connector):
        """Register with each model manager the tests applicable to it.

        Parameters
        ----------
        test_connector : emmaa.model_tests.TestConnector
            A TestConnector object deciding, for each model/test pair,
            whether the test is applicable to the model.
        """
        num_models = len(self.model_managers)
        logger.info(f'Checking applicability of {len(self.tests)} tests to '
                    f'{num_models} models')
        pairs = itertools.product(self.model_managers, self.tests)
        for manager, test in pairs:
            is_applicable = test_connector.applicable(manager, test)
            if not is_applicable:
                logger.debug(f'Test {test.stmt} is not applicable')
                continue
            manager.add_test(test)
            logger.debug(f'Test {test.stmt} is applicable')
        logger.info(f'Created tests for {len(self.model_managers)} models.')
        for manager in self.model_managers:
            logger.info(f'Created {len(manager.applicable_tests)} tests '
                        f'for {manager.model.name} model.')

    def run_tests(self, filter_func=None, edge_filter_func=None):
        """Run all applicable tests of every managed model."""
        for manager in self.model_managers:
            manager.run_all_tests(filter_func, edge_filter_func)
class TestConnector(object):
    """Base connector deciding if a test is applicable to a model.

    This base implementation considers every test applicable.
    """
    def __init__(self):
        pass

    @staticmethod
    def applicable(model, test):
        """Return True for any model and test."""
        is_applicable = True
        return is_applicable
class ScopeTestConnector(TestConnector):
    """Determines applicability of a test to a model by overlap in scope."""
    @staticmethod
    def applicable(model, test):
        """Return True if all test entities are among the model entities."""
        return ScopeTestConnector._overlap(model.entities,
                                           test.get_entities())

    @staticmethod
    def _overlap(model_entities, test_entities):
        """Return True if every test entity name is a model entity name."""
        model_names = {entity.name for entity in model_entities}
        test_names = {entity.name for entity in test_entities}
        # The test is in scope when no test entity name is missing from
        # the model
        return test_names.issubset(model_names)
class RefinementTestConnector(TestConnector):
    """Determines applicability of a test to a model by checking if test
    entities or their refinements are in the model.
    """
    @staticmethod
    def applicable(model, test):
        """Return True if each test entity, or one of its ontology
        children, is among the model entities.
        """
        entity_groups = []
        for test_entity in test.get_entities():
            # A group holds the entity itself plus an Agent per child
            group = [test_entity]
            ns, gr = test_entity.get_grounding()
            for child_ns, child_gr in bio_ontology.get_children(ns, gr):
                child_name = bio_ontology.get_name(child_ns, child_gr)
                group.append(Agent(child_name,
                                   db_refs={child_ns: child_gr}))
            entity_groups.append(group)
        return RefinementTestConnector._overlap(model.entities,
                                                entity_groups)

    @staticmethod
    def _ref_group_overlap(model_entities, test_entity_group):
        """Return the set of names shared by the model and the group."""
        model_names = {entity.name for entity in model_entities}
        group_names = {entity.name for entity in test_entity_group}
        # We need at least one intersection between these groups
        return model_names & group_names

    @staticmethod
    def _overlap(model_entities, test_entity_groups):
        """Return True if every group shares a name with the model."""
        # We need to get overlap with each test entity group
        group_overlaps = [
            RefinementTestConnector._ref_group_overlap(model_entities, group)
            for group in test_entity_groups]
        return all(group_overlaps)
class EmmaaTest(object):
    """Represent an EMMAA test condition"""
    def get_entities(self):
        """Return a list of entities that the test checks for.

        Subclasses have to implement this; the base class always raises
        NotImplementedError.
        """
        raise NotImplementedError()
class StatementCheckingTest(EmmaaTest):
    """Represent an EMMAA test condition that checks a PySB-assembled model
    against an INDRA Statement."""
    def __init__(self, stmt, configs=None):
        self.stmt = stmt
        self.configs = configs if configs else {}
        # TODO
        # Add entities as a property if we can reload tests on s3.
        # self.entities = self.get_entities()

    def check(self, model_checker, pysb_model):
        """Use a model checker to check if a given model satisfies the test.

        The search limits are read from the test's configs ('max_path_length'
        defaults to 5, 'max_paths' to 1). `pysb_model` is not used by this
        method.
        """
        search_limits = {
            'max_path_length': self.configs.get('max_path_length', 5),
            'max_paths': self.configs.get('max_paths', 1)}
        logger.info('Parameters for model checking: %s, %d' %
                    (search_limits['max_path_length'],
                     search_limits['max_paths']))
        return model_checker.check_statement(self.stmt, **search_limits)

    def get_entities(self):
        """Return a list of entities that the test checks for."""
        entities = self.stmt.agent_list()
        return entities

    def to_json(self):
        """Return the JSON representation of the test statement."""
        stmt_json = self.stmt.to_json()
        return stmt_json

    def __repr__(self):
        class_name = self.__class__.__name__
        return "%s(stmt=%s)" % (class_name, repr(self.stmt))
def load_tests_from_s3(test_name, bucket=EMMAA_BUCKET_NAME):
    """Load Emmaa Tests with the given name from S3.

    Parameters
    ----------
    test_name : str
        Looks for the latest test file in the bucket whose key starts with
        'tests/{test_name}'. If no such versioned file can be found, the
        key 'tests/{test_name}.pkl' is used.
    bucket : str
        Name of the S3 bucket to load the tests from.

    Returns
    -------
    tests : object
        The unpickled content of the test file.
    test_key : str
        The S3 key the tests were loaded from.
    """
    prefix = f'tests/{test_name}'
    try:
        test_key = find_latest_s3_file(bucket, prefix, '.pkl')
    except ValueError:
        test_key = None
    # find_latest_s3_file may also report a missing file by returning None
    # (as handled elsewhere in this module), so fall back in both cases
    if test_key is None:
        test_key = f'tests/{test_name}.pkl'
    logger.info(f'Loading tests from {test_key}')
    tests = load_pickle_from_s3(bucket, test_key)
    return tests, test_key
def save_model_manager_to_s3(model_name, model_manager,
                             bucket=EMMAA_BUCKET_NAME):
    """Pickle a model manager to S3 under a dated key.

    Note that the statement lists of the manager's model (`stmts`,
    `assembled_stmts` and `dynamic_assembled_stmts`) are emptied on the
    passed object before it is saved.
    """
    logger.info(f'Saving a model manager for {model_name} model to S3.')
    date_str = model_manager.date_str
    model = model_manager.model
    model.stmts = []
    model.assembled_stmts = []
    model.dynamic_assembled_stmts = []
    manager_key = f'results/{model_name}/model_manager_{date_str}.pkl'
    save_pickle_to_s3(model_manager, bucket, manager_key)
def load_model_manager_from_s3(model_name=None, key=None,
                               bucket=EMMAA_BUCKET_NAME):
    """Load a model manager from S3 by key or by model name.

    If a key is given, the pickled manager is loaded from it (and its
    assembled statements are reloaded if they are empty). If that fails, a
    new manager is built from the assembled statements of the model and date
    derived from the key. Without a key, the latest manager key of the given
    model is looked up and loaded in the same way.

    Returns
    -------
    ModelManager or None
        The loaded manager, or None if it could not be loaded or neither a
        key nor a model name was given.
    """
    # First try find the file from specified key
    if key:
        try:
            model_manager = load_pickle_from_s3(bucket, key)
            if not model_manager.model.assembled_stmts:
                manager_date = strip_out_date(model_manager.date_str, 'date')
                stmts, _ = get_assembled_statements(
                    model_manager.model.name, manager_date, bucket=bucket)
                model_manager.model.assembled_stmts = stmts
            return model_manager
        except Exception as e:
            logger.info('Could not load the model manager directly')
            logger.info(e)
            if not model_name:
                # The model name is the second component of the key
                model_name = key.split('/')[1]
            date = strip_out_date(key, 'date')
            logger.info('Trying to load model manager from statements')
            try:
                return ModelManager.load_from_statements(
                    model_name, date=date, bucket=bucket)
            except Exception as e:
                logger.info('Could not load the model manager from '
                            'statements')
                logger.info(e)
                return None
    # Could not find either from key or from model name.
    if not model_name:
        logger.info('Could not find the model manager.')
        return None
    # Now try find the latest key for given model
    # Versioned
    latest_key = find_latest_s3_file(
        bucket, f'results/{model_name}/model_manager_', '.pkl')
    if latest_key is None:
        # Non-versioned
        latest_key = f'results/{model_name}/latest_model_manager.pkl'
    return load_model_manager_from_s3(model_name=model_name, key=latest_key,
                                      bucket=bucket)
def update_model_manager_on_s3(model_name, bucket=EMMAA_BUCKET_NAME):
    """Build a new ModelManager for a model loaded from S3 and save it.

    Parameters
    ----------
    model_name : str
        Name of the EmmaaModel to load from S3.
    bucket : Optional[str]
        Name of the S3 bucket to load from and save to.
        Default: EMMAA_BUCKET_NAME

    Returns
    -------
    emmaa.model_tests.ModelManager
        The newly created model manager, as it is after having been passed
        to save_model_manager_to_s3.
    """
    manager = ModelManager(EmmaaModel.load_from_s3(model_name, bucket=bucket))
    save_model_manager_to_s3(model_name, manager, bucket=bucket)
    return manager
def model_to_tests(model_name, upload=True, bucket=EMMAA_BUCKET_NAME):
    """Create StatementCheckingTests from model statements.

    Parameters
    ----------
    model_name : str
        Name of the model whose assembled statements and config are loaded
        from S3.
    upload : Optional[bool]
        Whether to save the generated tests to S3 as a pickle.
        Default: True
    bucket : Optional[str]
        Name of the S3 bucket to use. Default: EMMAA_BUCKET_NAME

    Returns
    -------
    dict
        Dictionary with a 'test_data' entry (description and name of the
        corpus) and a 'tests' entry (list of StatementCheckingTest).
    """
    stmts, _ = get_assembled_statements(model_name, bucket=bucket)
    config = load_config_from_s3(model_name, bucket=bucket)
    # Filter statements only if the config carries a make_tests dictionary
    make_tests_config = config.get('make_tests')
    if isinstance(make_tests_config, dict):
        filter_config = make_tests_config['filter']
        stmts = filter_indra_stmts_by_metadata(
            stmts, filter_config['conditions'], filter_config['evid_policy'])
    # Keep only statements in which every agent slot is filled
    tests = []
    for stmt in stmts:
        if all(stmt.agent_list()):
            tests.append(StatementCheckingTest(stmt))
    date_str = make_date_str()
    readable_name = config.get("human_readable_name")
    test_data = {
        'description': (f'These tests were generated from the '
                        f'{readable_name} on {date_str[:10]}'),
        'name': f'{readable_name} model test corpus',
    }
    test_dict = {'test_data': test_data, 'tests': tests}
    if upload:
        save_tests_to_s3(test_dict, bucket,
                         f'tests/{model_name}_tests_{date_str}.pkl', 'pkl')
    return test_dict
def save_tests_to_s3(tests, bucket, key, save_format='pkl'):
    """Save tests in pkl, json or jsonl format.

    Parameters
    ----------
    tests : list or dict
        Either a list of test objects or a dictionary holding such a list
        under the 'tests' key. In pkl format the object is pickled as is;
        in json and jsonl formats only the ``stmt`` attribute of each test
        is serialized.
    bucket : str
        Name of the S3 bucket to save to.
    key : str
        S3 key to save the tests under.
    save_format : Optional[str]
        One of 'pkl', 'json' or 'jsonl'. Any other value saves nothing and
        logs a warning. Default: 'pkl'

    Raises
    ------
    TypeError
        If a json or jsonl format is requested and `tests` is neither a
        list nor a dict.
    """
    if save_format == 'pkl':
        save_pickle_to_s3(tests, bucket, key)
    elif save_format in ('json', 'jsonl'):
        if isinstance(tests, list):
            test_list = tests
        elif isinstance(tests, dict):
            test_list = tests['tests']
        else:
            # Fail with a clear message instead of an unbound local error
            raise TypeError(
                f'tests must be a list or a dict to be saved as '
                f'{save_format}, got {type(tests).__name__}')
        stmts_json = stmts_to_json([test.stmt for test in test_list])
        save_json_to_s3(stmts_json, bucket, key, save_format)
    else:
        # Nothing is saved for an unknown format; make that visible
        logger.warning(f'Unsupported save format {save_format}, '
                       f'tests were not saved to {key}')
def run_model_tests_from_s3(model_name, test_corpus='large_corpus_tests',
                            upload_results=True, bucket=EMMAA_BUCKET_NAME):
    """Run a given set of tests on a given model, both loaded from S3.

    After loading both the model manager and the set of tests, model/test
    overlap is determined using the test connector named under
    'test_connector' in the model's test config ('scope' or 'refinement',
    'refinement' if not set) and tests are run. Node and edge filter
    functions configured for this test corpus, if any, are passed to the
    test run.

    Parameters
    ----------
    model_name : str
        Name of EmmaaModel to load from S3.
    test_corpus : str
        Name of the file containing tests on S3.
    upload_results : Optional[bool]
        Whether to upload test results to S3 in JSON format. Can be set
        to False when running tests. Default: True
    bucket : Optional[str]
        Name of the S3 bucket to load from and upload to.
        Default: EMMAA_BUCKET_NAME

    Returns
    -------
    emmaa.model_tests.ModelManager
        Instance of ModelManager containing the model data, list of applied
        tests and the test results.

    Raises
    ------
    ValueError
        If the model manager could not be loaded, or if the configured
        test connector is neither 'scope' nor 'refinement'.
    TypeError
        If the loaded test corpus is neither a dict nor a list.
    """
    mm = load_model_manager_from_s3(model_name=model_name, bucket=bucket)
    if mm is None:
        # The loader returns None when neither the pickle nor the
        # statements could be loaded
        raise ValueError(
            f'Could not load a model manager for {model_name} model')
    test_dict, _ = load_tests_from_s3(test_corpus, bucket=bucket)
    if isinstance(test_dict, dict):
        tests = test_dict['tests']
        test_data = test_dict['test_data']
    elif isinstance(test_dict, list):
        tests = test_dict
        test_data = None
    else:
        raise TypeError(
            f'Test corpus {test_corpus} must be a dict or a list, '
            f'got {type(test_dict).__name__}')
    tm = TestManager([mm], tests)
    tc = mm.model.test_config.get('test_connector', 'refinement')
    if tc == 'scope':
        test_connector = ScopeTestConnector()
    elif tc == 'refinement':
        test_connector = RefinementTestConnector()
    else:
        raise ValueError(
            f'Unknown test connector {tc!r} configured for {model_name} '
            f'model, expected \'scope\' or \'refinement\'')
    tm.make_tests(test_connector)
    # Look up optional node and edge filters configured for this corpus
    filter_func = None
    edge_filter_func = None
    if mm.model.test_config.get('filters'):
        filter_func_name = mm.model.test_config['filters'].get(test_corpus)
        if filter_func_name:
            filter_func = node_filter_functions.get(filter_func_name)
    if mm.model.test_config.get('edge_filters'):
        edge_filter_func_name = mm.model.test_config['edge_filters'].get(
            test_corpus)
        edge_filter_func = edge_filter_functions.get(edge_filter_func_name)
    tm.run_tests(filter_func, edge_filter_func)
    # Optionally upload test results to S3
    if upload_results:
        mm.upload_results(test_corpus, test_data, bucket=bucket)
    return mm