# Source code for emmaa.priors.literature_prior

"""This module implements the LiteraturePrior class which automates
some of the steps involved in starting a model around a set of
literature searches. Example:

.. code:: python

    lp = LiteraturePrior('some_disease', 'Some Disease',
                         'This is a self-updating model of Some Disease',
                         search_strings=['some disease'],
                         assembly_config_template='nf')
    estmts = lp.get_statements()
    model = lp.make_model(estmts, upload_to_s3=True)

"""

from typing import List

import tqdm
import logging
import datetime
from collections import defaultdict
from indra.util import batch_iter
from indra_db import get_db
from indra_db.util import distill_stmts
from indra_db.client.principal import get_raw_stmt_jsons_from_papers
from indra.databases import mesh_client
from indra.statements import stmts_from_json
from . import SearchTerm
from emmaa.model import EmmaaModel
from emmaa.statements import EmmaaStatement


logger = logging.getLogger(__name__)


class LiteraturePrior:
    """Automate building an EMMAA model prior from literature searches."""

    def __init__(self, name, human_readable_name, description,
                 search_strings=None, mesh_ids=None,
                 assembly_config_template=None):
        """A class to construct a literature-based prior for an EMMAA model.

        Parameters
        ----------
        name : str
            The model name by which the model will be identified on S3.
        human_readable_name : str
            The human-readable display name for the model.
        description : str
            A human-readable description for the model.
        search_strings : list of str
            A list of search strings e.g., "diabetes" to find papers in the
            literature.
        mesh_ids : list of str
            A list of MeSH IDs that are used to search the literature as
            headings associated with papers.
        assembly_config_template : Optional[str]
            The name of another model from which the initial assembly
            configuration should be adopted.
        """
        self.name = name
        self.human_readable_name = human_readable_name
        self.description = description
        self.search_terms = \
            make_search_terms(search_strings or [], mesh_ids or [])
        if assembly_config_template:
            self.assembly_config = \
                self.get_config_from(assembly_config_template)
        else:
            self.assembly_config = {}
        # Cache of EmmaaStatements filled by get_statements.
        self.stmts = []

    def get_statements(self, mode='all', batch_size=100):
        """Return EMMAA Statements for this prior's literature set.

        The result is cached on ``self.stmts``. Once that list is non-empty,
        later calls return it unchanged and ignore ``mode`` and
        ``batch_size``.

        Parameters
        ----------
        mode : 'all' or 'distilled'
            The 'distilled' mode makes sure that the "best", non-redundant
            set of raw statements are found across potentially redundant text
            contents and reader versions. The 'all' mode doesn't do such
            distillation but is significantly faster.
        batch_size : Optional[int]
            Determines how many PMIDs to fetch statements for in each
            iteration. Default: 100.

        Returns
        -------
        list of EmmaaStatement
            A list of EMMAA Statements corresponding to extractions from
            the subset of literature defined by this prior's search terms.
        """
        if self.stmts:
            return self.stmts
        terms_to_pmids = \
            EmmaaModel.search_pubmed(search_terms=self.search_terms,
                                     date_limit=None)
        # Invert term -> PMIDs so that each paper records every search term
        # that found it.
        pmids_to_terms = defaultdict(list)
        for term, pmids in terms_to_pmids.items():
            for pmid in pmids:
                pmids_to_terms[pmid].append(term)
        pmids_to_terms = dict(pmids_to_terms)
        all_pmids = set(pmids_to_terms.keys())
        raw_statements_by_pmid = \
            get_raw_statements_for_pmids(all_pmids, mode=mode,
                                         batch_size=batch_size)
        # A single timestamp is shared by every statement from this fetch.
        timestamp = datetime.datetime.now()
        # Build the complete list before touching the cache. If anything
        # raises part-way, self.stmts is left untouched instead of holding
        # a partial list that a later call would return as if complete.
        new_stmts = []
        for pmid, stmts in raw_statements_by_pmid.items():
            terms = pmids_to_terms[pmid]
            for stmt in stmts:
                new_stmts.append(EmmaaStatement(stmt, timestamp, terms,
                                                {'internal': True}))
        self.stmts.extend(new_stmts)
        return self.stmts

    def get_config_from(self, assembly_config_template):
        """Return assembly config given a template model's name.

        Parameters
        ----------
        assembly_config_template : str
            The name of a model whose assembly config should be adopted.

        Returns
        -------
        dict or None
            The 'assembly' section of the template model's config, or None
            if that config has no 'assembly' entry.
        """
        from emmaa.model import load_config_from_s3
        config = load_config_from_s3(assembly_config_template)
        return config.get('assembly')

    def make_config(self, upload_to_s3=False):
        """Return a config dict for the model, optionally upload to S3.

        Parameters
        ----------
        upload_to_s3 : Optional[bool]
            If True, the config is uploaded to S3 in the EMMAA bucket.
            Default: False

        Returns
        -------
        dict
            A config data structure. It contains an 'assembly' key only if
            an assembly config was adopted from a template.
        """
        config = {
            # These are provided by the user upon initialization
            'name': self.name,
            'human_readable_name': self.human_readable_name,
            'description': self.description,
            # We don't make tests by default
            'make_tests': False,
            # We run daily updates by default
            'run_daily_update': True,
            # We first show the model just on dev
            'dev_only': True,
            # These are the search terms constructed upon
            # initialization
            'search_terms': [st.to_json()
                             for st in self.search_terms],
            # We configure the large corpus tests by default
            'test': {
                'statement_checking': {
                    'max_path_length': 10,
                    'max_paths': 1
                },
                'mc_types': [
                    'signed_graph', 'unsigned_graph'
                ],
                'make_links': True,
                'test_corpus': ['large_corpus_tests'],
                'default_test_corpus': 'large_corpus_tests',
                'filters': {
                    'large_corpus_tests': 'filter_chem_mesh_go'
                }
            }
        }
        # This is adopted from the template specified upon
        # initialization (may be {} or None, in which case it is omitted)
        if self.assembly_config:
            config["assembly"] = self.assembly_config

        if upload_to_s3:
            from emmaa.model import save_config_to_s3
            save_config_to_s3(self.name, config)
        return config

    def make_model(self, estmts, upload_to_s3=False):
        """Return, and optionally upload to S3 an initial EMMAA Model.

        Parameters
        ----------
        estmts : list of emmaa.statements.EmmaaStatement
            A list of prior EMMAA Statements to initialize the model with.
        upload_to_s3 : Optional[bool]
            If True, the model and the config are uploaded to S3, otherwise
            the model object is just returned without upload. Default: False

        Returns
        -------
        emmaa.model.EmmaaModel
            The EMMAA Model object constructed from the generated config
            and the given EMMAA Statements.
        """
        config = self.make_config(upload_to_s3=upload_to_s3)
        model = EmmaaModel(name=self.name, config=config)
        model.add_statements(estmts)
        if upload_to_s3:
            model.save_to_s3()
        return model


def get_raw_statements_for_pmids(pmids, mode='all', batch_size=100):
    """Return INDRA Statements based on extractions from given PMIDs.

    Parameters
    ----------
    pmids : set or list of str
        A set of PMIDs to find raw INDRA Statements for in the INDRA DB.
    mode : 'all' or 'distilled'
        The 'distilled' mode makes sure that the "best", non-redundant
        set of raw statements are found across potentially redundant text
        contents and reader versions. The 'all' mode doesn't do such
        distillation but is significantly faster. Any value other than
        'distilled' is handled like 'all'.
    batch_size : Optional[int]
        Determines how many PMIDs to fetch statements for in each
        iteration. Default: 100.

    Returns
    -------
    dict
        A dict keyed by PMID with values INDRA Statements obtained from the
        given PMID.
    """
    db = get_db('primary')
    logger.info(f'Getting raw statements for {len(pmids)} PMIDs')
    all_stmts = defaultdict(list)
    # Ceiling division gives tqdm an integer total that includes the final,
    # possibly partial, batch (len/batch_size was a float that under-counted).
    n_batches = (len(pmids) + batch_size - 1) // batch_size
    for pmid_batch in tqdm.tqdm(batch_iter(pmids, return_func=set,
                                           batch_size=batch_size),
                                total=n_batches):
        if mode == 'distilled':
            # Join raw statements back to the text refs of this PMID batch.
            clauses = [
                db.TextRef.pmid.in_(pmid_batch),
                db.TextContent.text_ref_id == db.TextRef.id,
                db.Reading.text_content_id == db.TextContent.id,
                db.RawStatements.reading_id == db.Reading.id]
            distilled_stmts = distill_stmts(db, get_full_stmts=True,
                                            clauses=clauses)
            # Group by the PMID recorded on each statement's first evidence.
            for stmt in distilled_stmts:
                all_stmts[stmt.evidence[0].pmid].append(stmt)
        else:
            id_stmts = \
                get_raw_stmt_jsons_from_papers(pmid_batch, id_type='pmid',
                                               db=db)
            for pmid, stmt_jsons in id_stmts.items():
                all_stmts[pmid] += stmts_from_json(stmt_jsons)
    all_stmts = dict(all_stmts)
    return all_stmts
def make_search_terms(
    search_strings: List[str],
    mesh_ids: List[str],
) -> List[SearchTerm]:
    """Return EMMAA SearchTerms based on search strings and MeSH IDs.

    Parameters
    ----------
    search_strings :
        A list of search strings e.g., "diabetes" to find papers in the
        literature. None is treated as an empty list.
    mesh_ids :
        A list of MeSH IDs that are used to search the literature as
        headings associated with papers. None is treated as an empty list.

    Returns
    -------
    :
        A list of EMMAA SearchTerm objects constructed from the search strings
        and the MeSH IDs.

    Raises
    ------
    ValueError
        If neither search strings nor MeSH IDs are given, or if a MeSH ID
        cannot be resolved to a name.
    """
    search_strings = search_strings or []
    mesh_ids = mesh_ids or []
    if not search_strings and not mesh_ids:
        raise ValueError("Need at least one of search_strings or mesh_ids")
    search_terms = [
        SearchTerm(type='other', name=search_string, db_refs={},
                   search_term=search_string)
        for search_string in search_strings
    ]
    for mesh_id in mesh_ids:
        mesh_name = mesh_client.get_mesh_name(mesh_id)
        # Without this check an unresolved ID produces a literal
        # "None [mh]" query that silently returns unrelated or no papers.
        if not mesh_name:
            raise ValueError(f'Could not find a MeSH name for ID {mesh_id}')
        # IDs starting with D are searched with the [mh] tag; all others
        # (presumably supplementary concept records) with the [nm] tag.
        suffix = 'mh' if mesh_id.startswith('D') else 'nm'
        search_term = SearchTerm(type='mesh', name=mesh_name,
                                 db_refs={'MESH': mesh_id},
                                 search_term=f'{mesh_name} [{suffix}]')
        search_terms.append(search_term)
    return search_terms