# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import copy
from concurrent.futures import Future
HAVE_GREMLIN = False
try:
import gremlin_python
HAVE_GREMLIN = True
except ImportError:
# gremlinpython is not installed.
pass
if HAVE_GREMLIN:
from gremlin_python.structure.graph import Graph
from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal
from gremlin_python.process.traversal import Traverser, TraversalSideEffects
from gremlin_python.process.graph_traversal import GraphTraversal
from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT
from cassandra.datastax.graph import GraphOptions, GraphProtocol
from cassandra.datastax.graph.query import _GraphSONContextRowFactory
from cassandra.datastax.graph.fluent.serializers import (
GremlinGraphSONReaderV2,
GremlinGraphSONReaderV3,
dse_graphson2_deserializers,
gremlin_graphson2_deserializers,
dse_graphson3_deserializers,
gremlin_graphson3_deserializers
)
from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Names exported by a star-import of this module.
__all__ = ['BaseGraphRowFactory', 'graph_traversal_row_factory',
'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph']
# Traversal result keys
# Key of the optional repetition count in a parsed row (read by BaseGraphRowFactory).
_bulk_key = 'bulk'
# Key of the traversal result in a parsed row (read by BaseGraphRowFactory).
_result_key = 'result'
class BaseGraphRowFactory(_GraphSONContextRowFactory):
    """
    Base row factory for graph traversals.

    Wraps a graphson reader so the additional features of Gremlin/DSE
    results are handled, while remaining callable like a normal row factory.

    Currently supported:
      - bulk results
    """

    def __call__(self, column_names, rows):
        """
        Lazily decode ``rows`` and yield the traversal results they carry.

        The first column of each row is decoded with ``self.graphson_reader``.
        The decoded result is yielded once as-is, followed by ``bulk - 1``
        deep copies of it (``bulk`` defaults to 1 when the key is absent).
        ``column_names`` is not used.
        """
        for raw_row in rows:
            decoded = self.graphson_reader.readObject(raw_row[0])
            yield decoded[_result_key]

            # A bulked result stands for `bulk` identical traversers: the
            # first one was yielded above, the rest are independent copies.
            bulk_count = decoded.get(_bulk_key, 1)
            for _ in range(1, bulk_count):
                yield copy.deepcopy(decoded[_result_key])
class _GremlinGraphSON2RowFactory(BaseGraphRowFactory):
    """Row factory decoding graphson2 rows with the Gremlin deserializer map."""

    # Keyword arguments handed to the reader class below.
    graphson_reader_kwargs = dict(deserializer_map=gremlin_graphson2_deserializers)
    graphson_reader_class = GremlinGraphSONReaderV2
class _DseGraphSON2RowFactory(BaseGraphRowFactory):
    """Row factory decoding graphson2 rows into DSE types."""

    # Keyword arguments handed to the reader class below.
    graphson_reader_kwargs = dict(deserializer_map=dse_graphson2_deserializers)
    graphson_reader_class = GremlinGraphSONReaderV2
# GraphSON2 row factory using the Gremlin deserializer map.
gremlin_graphson2_traversal_row_factory = _GremlinGraphSON2RowFactory
# TODO remove in next major
# Older name bound to the same factory class as the alias above.
graph_traversal_row_factory = gremlin_graphson2_traversal_row_factory
# GraphSON2 row factory using the DSE deserializer map.
dse_graphson2_traversal_row_factory = _DseGraphSON2RowFactory
# TODO remove in next major
# Older name bound to the same factory class as the alias above.
graph_traversal_dse_object_row_factory = dse_graphson2_traversal_row_factory
class _GremlinGraphSON3RowFactory(BaseGraphRowFactory):
    """Row Factory that returns the decoded graphson3."""

    # Reader class and the keyword arguments it is built with: the GraphSON3
    # reader combined with the Gremlin GraphSON3 deserializer map.
    graphson_reader_class = GremlinGraphSONReaderV3
    graphson_reader_kwargs = {'deserializer_map': gremlin_graphson3_deserializers}
class _DseGraphSON3RowFactory(BaseGraphRowFactory):
    """Row factory decoding graphson3 rows into DSE types."""

    # Keyword arguments handed to the reader class below.
    graphson_reader_kwargs = dict(deserializer_map=dse_graphson3_deserializers)
    graphson_reader_class = GremlinGraphSONReaderV3
# GraphSON3 row factory using the Gremlin deserializer map.
gremlin_graphson3_traversal_row_factory = _GremlinGraphSON3RowFactory
# GraphSON3 row factory using the DSE deserializer map.
dse_graphson3_traversal_row_factory = _DseGraphSON3RowFactory
class DSESessionRemoteGraphConnection(RemoteConnection):
    """
    TinkerPop ``RemoteConnection`` that executes traversal queries through a
    DSE session.

    :param session: A DSE session
    :param graph_name: (Optional) DSE Graph name.
    :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
    """

    # Class-level defaults; each one is overwritten per instance in __init__.
    session = None
    graph_name = None
    execution_profile = None

    def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT):
        """
        Bind the connection to ``session``.

        :raises ValueError: if ``session`` is not a ``Session`` instance.
        """
        super(DSESessionRemoteGraphConnection, self).__init__(None, None)

        if not isinstance(session, Session):
            raise ValueError('A DSE Session must be provided to execute graph traversal queries.')

        self.session = session
        self.graph_name = graph_name
        self.execution_profile = execution_profile

    @staticmethod
    def _traversers_generator(traversers):
        """Lazily wrap every item of ``traversers`` in a ``Traverser``."""
        for item in traversers:
            yield Traverser(item)

    def _prepare_query(self, bytecode):
        """
        Build the query and the execution profile used to run ``bytecode``.

        Works on a clone of the configured execution profile: the graph name
        of this connection (when set) and the bytecode query language are
        applied to it, then the row factory matching the resolved graph
        protocol is installed.

        :return: a ``(query, execution_profile)`` tuple.
        :raises ValueError: if the resolved graph protocol is neither
            GraphSON 2.0 nor GraphSON 3.0.
        """
        profile = self.session.execution_profile_clone_update(self.execution_profile)
        options = profile.graph_options
        options.graph_name = self.graph_name or options.graph_name
        options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE

        # The profile options must be resolved first: the effective graph
        # protocol decides which gremlin row factory is used.
        self.session._resolve_execution_profile_options(profile)

        protocol = options.graph_protocol
        context = None
        if protocol == GraphProtocol.GRAPHSON_2_0:
            factory = gremlin_graphson2_traversal_row_factory
        elif protocol == GraphProtocol.GRAPHSON_3_0:
            factory = gremlin_graphson3_traversal_row_factory
            # Only the GraphSON3 branch builds a serialization context.
            context = {
                'cluster': self.session.cluster,
                'graph_name': options.graph_name.decode('utf-8')
            }
        else:
            raise ValueError('Unknown graph protocol: {}'.format(protocol))

        profile.row_factory = factory
        return DseGraph.query_from_traversal(bytecode, protocol, context), profile

    @staticmethod
    def _handle_query_results(result_set, gremlin_future):
        """Complete ``gremlin_future`` with a ``RemoteTraversal`` over ``result_set``."""
        try:
            remote_traversal = RemoteTraversal(
                DSESessionRemoteGraphConnection._traversers_generator(result_set),
                TraversalSideEffects())
            gremlin_future.set_result(remote_traversal)
        except Exception as exc:
            # Any failure while building or publishing the result is
            # reported through the future rather than raised here.
            gremlin_future.set_exception(exc)

    @staticmethod
    def _handle_query_error(response, gremlin_future):
        """Fail ``gremlin_future`` with the error ``response``."""
        gremlin_future.set_exception(response)

    def submit(self, bytecode):
        """Execute ``bytecode`` synchronously and return a ``RemoteTraversal``."""
        # submitAsync is deliberately not reused here: it would only add a
        # needless future around a blocking call.
        query, profile = self._prepare_query(bytecode)
        results = self.session.execute_graph(query, execution_profile=profile)
        return RemoteTraversal(self._traversers_generator(results), TraversalSideEffects())

    def submitAsync(self, bytecode):
        """
        Execute ``bytecode`` asynchronously.

        :return: a ``concurrent.futures.Future`` that is completed by the
            callbacks registered on the driver's response future.
        """
        query, profile = self._prepare_query(bytecode)

        # gremlinpython expects a concurrent.futures.Future, so the driver's
        # response future is bridged to one through callbacks.
        bridge = Future()
        driver_future = self.session.execute_graph_async(query, execution_profile=profile)
        driver_future.add_callback(self._handle_query_results, bridge)
        driver_future.add_errback(self._handle_query_error, bridge)
        return bridge

    def __str__(self):
        """Return a short description including the graph name."""
        return "<DSESessionRemoteGraphConnection: graph_name='{0}'>".format(self.graph_name)

    __repr__ = __str__
class DseGraph(object):
    """
    Dse Graph utility class for GraphTraversal construction and execution.
    """

    DSE_GRAPH_QUERY_LANGUAGE = 'bytecode-json'
    """
    Graph query language, Default is 'bytecode-json' (GraphSON).
    """

    DSE_GRAPH_QUERY_PROTOCOL = GraphProtocol.GRAPHSON_2_0
    """
    Graph query protocol, Default is GraphProtocol.GRAPHSON_2_0.
    """

    @staticmethod
    def query_from_traversal(traversal, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, context=None):
        """
        From a GraphTraversal, return a query string based on the language specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`.

        A warning is logged when the traversal carries a
        :class:`DSESessionRemoteGraphConnection` with a session, graph name
        or execution profile set.

        :param traversal: The GraphTraversal object
        :param graph_protocol: The graph protocol. Default is `DseGraph.DSE_GRAPH_QUERY_PROTOCOL`.
        :param context: The dict of the serialization context, needed for GraphSON3 (tuple, udt).
                        e.g: {'cluster': cluster, 'graph_name': name}
        """
        if isinstance(traversal, GraphTraversal):
            for strategy in traversal.traversal_strategies.traversal_strategies:
                # A strategy without a `remote_connection` attribute is
                # skipped instead of raising AttributeError.
                rc = getattr(strategy, 'remote_connection', None)
                # The inner parentheses matter: `and` binds tighter than
                # `or`, so without them the isinstance() check would not
                # guard the graph_name / execution_profile lookups.
                if (isinstance(rc, DSESessionRemoteGraphConnection) and
                        (rc.session or rc.graph_name or rc.execution_profile)):
                    log.warning("GraphTraversal session, graph_name and execution_profile are "
                                "only taken into account when executed with TinkerPop.")

        return _query_from_traversal(traversal, graph_protocol, context)

    @staticmethod
    def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT,
                         traversal_class=None):
        """
        Returns a TinkerPop GraphTraversalSource bound to the session and graph_name if provided.

        :param session: (Optional) A DSE session
        :param graph_name: (Optional) DSE Graph name
        :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
        :param traversal_class: (Optional) The GraphTraversalSource class to use (DSL).

        .. code-block:: python

            from cassandra.cluster import Cluster
            from cassandra.datastax.graph.fluent import DseGraph

            c = Cluster()
            session = c.connect()

            g = DseGraph.traversal_source(session, 'my_graph')
            print(g.V().valueMap().toList())

        """
        graph = Graph()
        traversal_source = graph.traversal(traversal_class)

        # Without a session the source is returned with no remote connection
        # attached.
        if session:
            traversal_source = traversal_source.withRemote(
                DSESessionRemoteGraphConnection(session, graph_name, execution_profile))

        return traversal_source

    @staticmethod
    def create_execution_profile(graph_name, graph_protocol=DSE_GRAPH_QUERY_PROTOCOL, **kwargs):
        """
        Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile to the
        cluster by using `cluster.add_execution_profile`.

        :param graph_name: The graph name
        :param graph_protocol: (Optional) The graph protocol, default is `DSE_GRAPH_QUERY_PROTOCOL`.
        :param kwargs: Extra keyword arguments forwarded to `GraphExecutionProfile`.
        :raises ValueError: if `graph_protocol` is neither GraphSON 2.0 nor GraphSON 3.0.
        """
        if graph_protocol == GraphProtocol.GRAPHSON_2_0:
            row_factory = dse_graphson2_traversal_row_factory
        elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
            row_factory = dse_graphson3_traversal_row_factory
        else:
            raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))

        ep = GraphExecutionProfile(row_factory=row_factory,
                                   graph_options=GraphOptions(graph_name=graph_name,
                                                              graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE,
                                                              graph_protocol=graph_protocol),
                                   **kwargs)
        return ep

    @staticmethod
    def batch(*args, **kwargs):
        """
        Returns the :class:`cassandra.datastax.graph.fluent.query.TraversalBatch` object allowing to
        execute multiple traversals in the same transaction.

        All positional and keyword arguments are forwarded unchanged to the
        batch constructor.
        """
        return _DefaultTraversalBatch(*args, **kwargs)