Source code for cassandra.datastax.graph.fluent

# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import copy

from concurrent.futures import Future

HAVE_GREMLIN = False
try:
    import gremlin_python
    HAVE_GREMLIN = True
except ImportError:
    # gremlinpython is not installed.
    pass

if HAVE_GREMLIN:
    from gremlin_python.structure.graph import Graph
    from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal
    from gremlin_python.process.traversal import Traverser, TraversalSideEffects
    from gremlin_python.process.graph_traversal import GraphTraversal

    from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT
    from cassandra.datastax.graph import GraphOptions, GraphProtocol
    from cassandra.datastax.graph.query import _GraphSONContextRowFactory

    from cassandra.datastax.graph.fluent.serializers import (
        GremlinGraphSONReaderV2,
        GremlinGraphSONReaderV3,
        dse_graphson2_deserializers,
        gremlin_graphson2_deserializers,
        dse_graphson3_deserializers,
        gremlin_graphson3_deserializers
    )
    from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal

    # Module-level logger, named after this module.
    log = logging.getLogger(__name__)

    # Public names of this module. Like everything else in this `if HAVE_GREMLIN:`
    # branch, they are only defined when gremlinpython could be imported.
    __all__ = ['BaseGraphRowFactory', 'graph_traversal_row_factory',
               'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph']

    # Traversal result keys
    # Keys looked up by BaseGraphRowFactory in every decoded row: `_result_key`
    # holds the value to yield, `_bulk_key` the number of times it is yielded
    # (treated as 1 when the key is absent).
    _bulk_key = 'bulk'
    _result_key = 'result'


class BaseGraphRowFactory(_GraphSONContextRowFactory):
    """
    Common row factory for graph traversal results.

    It wraps the GraphSON reader held in ``self.graphson_reader`` so that the
    extra features of Gremlin/DSE results are handled, while remaining
    callable like any other row factory.

    Currently supported:
      - bulk results
    """

    def __call__(self, column_names, rows):
        """
        Lazily decode ``rows`` and yield the traversal results.

        :param column_names: Not used by this factory.
        :param rows: Iterable of rows; only the first column of each row is decoded.
        """
        for raw_row in rows:
            decoded = self.graphson_reader.readObject(raw_row[0])
            yield decoded[_result_key]

            # A bulk of N stands for N identical results: the first one was
            # yielded above, the remaining N - 1 are independent deep copies.
            bulk = decoded.get(_bulk_key, 1)
            for _ in range(1, bulk):
                yield copy.deepcopy(decoded[_result_key])
class _GremlinGraphSON2RowFactory(BaseGraphRowFactory):
    """Row Factory that returns the decoded graphson2."""
    graphson_reader_class = GremlinGraphSONReaderV2
    graphson_reader_kwargs = {'deserializer_map': gremlin_graphson2_deserializers}


class _DseGraphSON2RowFactory(BaseGraphRowFactory):
    """Row Factory that returns the decoded graphson2 as DSE types."""
    graphson_reader_class = GremlinGraphSONReaderV2
    graphson_reader_kwargs = {'deserializer_map': dse_graphson2_deserializers}


# Module-level names for the GraphSON2 row factories.
gremlin_graphson2_traversal_row_factory = _GremlinGraphSON2RowFactory
# TODO remove in next major
graph_traversal_row_factory = gremlin_graphson2_traversal_row_factory

dse_graphson2_traversal_row_factory = _DseGraphSON2RowFactory
# TODO remove in next major
graph_traversal_dse_object_row_factory = dse_graphson2_traversal_row_factory


class _GremlinGraphSON3RowFactory(BaseGraphRowFactory):
    """Row Factory that returns the decoded graphson3."""
    graphson_reader_class = GremlinGraphSONReaderV3
    graphson_reader_kwargs = {'deserializer_map': gremlin_graphson3_deserializers}


class _DseGraphSON3RowFactory(BaseGraphRowFactory):
    """Row Factory that returns the decoded graphson3 as DSE types."""
    graphson_reader_class = GremlinGraphSONReaderV3
    graphson_reader_kwargs = {'deserializer_map': dse_graphson3_deserializers}


# Module-level names for the GraphSON3 row factories.
gremlin_graphson3_traversal_row_factory = _GremlinGraphSON3RowFactory
dse_graphson3_traversal_row_factory = _DseGraphSON3RowFactory
class DSESessionRemoteGraphConnection(RemoteConnection):
    """
    TinkerPop ``RemoteConnection`` that executes traversal queries on DSE
    through a driver session.

    :param session: A DSE session
    :param graph_name: (Optional) DSE Graph name.
    :param execution_profile: (Optional) Execution profile for traversal queries.
        Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
    :raises ValueError: if ``session`` is not a ``Session`` instance.
    """

    session = None
    graph_name = None
    execution_profile = None

    def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT):
        super(DSESessionRemoteGraphConnection, self).__init__(None, None)

        if not isinstance(session, Session):
            raise ValueError('A DSE Session must be provided to execute graph traversal queries.')

        self.execution_profile = execution_profile
        self.graph_name = graph_name
        self.session = session

    @staticmethod
    def _traversers_generator(traversers):
        """Lazily wrap every item of ``traversers`` in a gremlin ``Traverser``."""
        for item in traversers:
            yield Traverser(item)

    def _prepare_query(self, bytecode):
        """
        Build the query and the execution profile used to run ``bytecode``.

        :returns: a ``(query, execution_profile)`` tuple.
        :raises ValueError: if the resolved graph protocol is neither
            GraphSON 2.0 nor GraphSON 3.0.
        """
        profile = self.session.execution_profile_clone_update(self.execution_profile)
        options = profile.graph_options
        options.graph_name = self.graph_name or options.graph_name
        options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE

        # Resolve the profile options now: the graph protocol read below
        # decides which gremlin row factory has to be set.
        self.session._resolve_execution_profile_options(profile)
        protocol = options.graph_protocol

        if protocol == GraphProtocol.GRAPHSON_2_0:
            factory, context = gremlin_graphson2_traversal_row_factory, None
        elif protocol == GraphProtocol.GRAPHSON_3_0:
            factory = gremlin_graphson3_traversal_row_factory
            context = {
                'cluster': self.session.cluster,
                'graph_name': options.graph_name.decode('utf-8'),
            }
        else:
            raise ValueError('Unknown graph protocol: {}'.format(protocol))

        profile.row_factory = factory
        return DseGraph.query_from_traversal(bytecode, protocol, context), profile

    @staticmethod
    def _handle_query_results(result_set, gremlin_future):
        """Success callback: complete ``gremlin_future`` with a ``RemoteTraversal``."""
        try:
            traversers = DSESessionRemoteGraphConnection._traversers_generator(result_set)
            remote_traversal = RemoteTraversal(traversers, TraversalSideEffects())
            gremlin_future.set_result(remote_traversal)
        except Exception as exc:
            gremlin_future.set_exception(exc)

    @staticmethod
    def _handle_query_error(response, gremlin_future):
        """Error callback: fail ``gremlin_future`` with ``response``."""
        gremlin_future.set_exception(response)

    def submit(self, bytecode):
        """Execute ``bytecode`` through ``session.execute_graph`` and return a ``RemoteTraversal``."""
        # Deliberately not built on top of submitAsync: that would only add
        # a needless future wrapper.
        query, profile = self._prepare_query(bytecode)
        results = self.session.execute_graph(query, execution_profile=profile)
        return RemoteTraversal(self._traversers_generator(results), TraversalSideEffects())

    def submitAsync(self, bytecode):
        """
        Execute ``bytecode`` through ``session.execute_graph_async``.

        :returns: a ``concurrent.futures.Future`` completed by the result and
            error callbacks registered on the driver response future.
        """
        query, profile = self._prepare_query(bytecode)

        # to be compatible with gremlinpython, we need to return a concurrent.futures.Future
        gremlin_future = Future()
        driver_future = self.session.execute_graph_async(query, execution_profile=profile)
        driver_future.add_callback(self._handle_query_results, gremlin_future)
        driver_future.add_errback(self._handle_query_error, gremlin_future)

        return gremlin_future

    def __str__(self):
        template = "<DSESessionRemoteGraphConnection: graph_name='{0}'>"
        return template.format(self.graph_name)

    __repr__ = __str__
class DseGraph(object):
    """
    Dse Graph utility class for GraphTraversal construction and execution.
    """

    DSE_GRAPH_QUERY_LANGUAGE = 'bytecode-json'
    """
    Graph query language, Default is 'bytecode-json' (GraphSON).
    """

    DSE_GRAPH_QUERY_PROTOCOL = GraphProtocol.GRAPHSON_2_0
    """
    Graph query protocol, Default is GraphProtocol.GRAPHSON_2_0.
    """

    @staticmethod
    def query_from_traversal(traversal, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, context=None):
        """
        From a GraphTraversal, return a query string based on the language
        specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`.

        A warning is logged for every traversal strategy bound to a
        :class:`DSESessionRemoteGraphConnection` that carries a session, a graph
        name or an execution profile.

        :param traversal: The GraphTraversal object
        :param graph_protocol: The graph protocol. Default is `DseGraph.DSE_GRAPH_QUERY_PROTOCOL`.
        :param context: The dict of the serialization context, needed for GraphSON3 (tuple, udt).
                        e.g: {'cluster': cluster, 'graph_name': name}
        """
        if isinstance(traversal, GraphTraversal):
            for strategy in traversal.traversal_strategies.traversal_strategies:
                # Use getattr so that a strategy without a `remote_connection`
                # attribute is skipped instead of raising AttributeError.
                rc = getattr(strategy, 'remote_connection', None)
                # The parentheses matter: `and` binds tighter than `or`, so
                # without them the attribute checks would also run on objects
                # that are not a DSESessionRemoteGraphConnection.
                if (isinstance(rc, DSESessionRemoteGraphConnection) and
                        (rc.session or rc.graph_name or rc.execution_profile)):
                    log.warning("GraphTraversal session, graph_name and execution_profile are "
                                "only taken into account when executed with TinkerPop.")

        return _query_from_traversal(traversal, graph_protocol, context)

    @staticmethod
    def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT,
                         traversal_class=None):
        """
        Returns a TinkerPop GraphTraversalSource bound to the session and graph_name if provided.

        :param session: (Optional) A DSE session
        :param graph_name: (Optional) DSE Graph name
        :param execution_profile: (Optional) Execution profile for traversal queries.
            Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
        :param traversal_class: (Optional) The GraphTraversalSource class to use (DSL).

        .. code-block:: python

            from cassandra.cluster import Cluster
            from cassandra.datastax.graph.fluent import DseGraph

            c = Cluster()
            session = c.connect()

            g = DseGraph.traversal_source(session, 'my_graph')
            print(g.V().valueMap().toList())

        """
        graph = Graph()
        traversal_source = graph.traversal(traversal_class)

        # Without a session the source is returned unbound (no remote connection).
        if session:
            traversal_source = traversal_source.withRemote(
                DSESessionRemoteGraphConnection(session, graph_name, execution_profile))

        return traversal_source

    @staticmethod
    def create_execution_profile(graph_name, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, **kwargs):
        """
        Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile to the
        cluster by using `cluster.add_execution_profile`.

        :param graph_name: The graph name
        :param graph_protocol: (Optional) The graph protocol, default is `DSE_GRAPH_QUERY_PROTOCOL`.
        :param kwargs: Extra keyword arguments forwarded to `GraphExecutionProfile`.
        :raises ValueError: if `graph_protocol` is neither GraphSON 2.0 nor GraphSON 3.0.
        """
        if graph_protocol == GraphProtocol.GRAPHSON_2_0:
            row_factory = dse_graphson2_traversal_row_factory
        elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
            row_factory = dse_graphson3_traversal_row_factory
        else:
            raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))

        ep = GraphExecutionProfile(row_factory=row_factory,
                                   graph_options=GraphOptions(graph_name=graph_name,
                                                              graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE,
                                                              graph_protocol=graph_protocol),
                                   **kwargs)
        return ep

    @staticmethod
    def batch(*args, **kwargs):
        """
        Returns the :class:`cassandra.datastax.graph.fluent.query.TraversalBatch` object allowing to
        execute multiple traversals in the same transaction.

        All positional and keyword arguments are forwarded unchanged.
        """
        return _DefaultTraversalBatch(*args, **kwargs)