Persisting and sharing machine learning models
Follow along with a notebook in Google Colab.
This example shows how to train, save, publish, and drop a machine learning model using the Model Catalog.
Setup
For more information on how to get started using Python, refer to the Connecting with Python tutorial.
# Install the Graph Data Science Python client
pip install graphdatascience
import os

# Import the client
from graphdatascience import GraphDataScience

# Replace with the actual URI, username, and password, or export them as the
# environment variables AURA_CONNECTION_URI, AURA_USERNAME and AURA_PASSWORD
# (the same names used by the Cypher Shell example). This avoids hard-coding
# credentials in a notebook that may be shared; the defaults below are used
# only when a variable is not set.
AURA_CONNECTION_URI = os.environ.get(
    "AURA_CONNECTION_URI", "neo4j+s://xxxxxxxx.databases.neo4j.io"
)
AURA_USERNAME = os.environ.get("AURA_USERNAME", "neo4j")
AURA_PASSWORD = os.environ.get("AURA_PASSWORD", "")

# Configure the client with AuraDS-recommended settings
gds = GraphDataScience(
    AURA_CONNECTION_URI,
    auth=(AURA_USERNAME, AURA_PASSWORD),
    aura_ds=True,
)
In the following code examples we use the `print` function to print Pandas `DataFrame` and `Series` objects. You can try different ways to print a Pandas object, for instance via the `to_string` and `to_json` methods; if you use a JSON representation, in some cases you may need to include a default handler to handle Neo4j `DateTime` objects. Check the Python connection section for some examples.
For more information on how to get started using the Cypher Shell, refer to the Neo4j Cypher Shell tutorial.
Run the following commands from the directory where the Cypher shell is installed.
# Replace with the actual URI, username, and password
export AURA_CONNECTION_URI="neo4j+s://xxxxxxxx.databases.neo4j.io"
export AURA_USERNAME="neo4j"
export AURA_PASSWORD=""
# Quote the expansions so each value reaches cypher-shell as exactly one
# argument: an unquoted empty password would leave "-p" without a value, and
# values containing spaces or shell metacharacters would be split or expanded.
./cypher-shell -a "$AURA_CONNECTION_URI" -u "$AURA_USERNAME" -p "$AURA_PASSWORD"
For more information on how to get started using Python, refer to the Connecting with Python tutorial.
# Install the official Neo4j Python driver
pip install neo4j
import os

# Import the driver
from neo4j import GraphDatabase

# Replace with the actual URI, username, and password, or export them as the
# environment variables AURA_CONNECTION_URI, AURA_USERNAME and AURA_PASSWORD
# (the same names used by the Cypher Shell example). This avoids hard-coding
# credentials in a notebook that may be shared; the defaults below are used
# only when a variable is not set.
AURA_CONNECTION_URI = os.environ.get(
    "AURA_CONNECTION_URI", "neo4j+s://xxxxxxxx.databases.neo4j.io"
)
AURA_USERNAME = os.environ.get("AURA_USERNAME", "neo4j")
AURA_PASSWORD = os.environ.get("AURA_PASSWORD", "")

# Instantiate the driver
driver = GraphDatabase.driver(
    AURA_CONNECTION_URI,
    auth=(AURA_USERNAME, AURA_PASSWORD),
)
# Import to prettify results
import json

# Import for the JSON helper function
from neo4j.time import DateTime


# Helper function for serializing Neo4j DateTime in JSON dumps
def default(o):
    """Convert values that ``json.dumps`` cannot serialize natively.

    Intended to be passed as ``json.dumps(..., default=default)``.

    Args:
        o: The object the ``json`` encoder could not serialize.

    Returns:
        The ISO-8601 string representation of a Neo4j ``DateTime``.

    Raises:
        TypeError: If ``o`` is any other unsupported type.
    """
    if isinstance(o, DateTime):
        return o.isoformat()
    # The json module expects ``default`` to raise TypeError for types it
    # cannot handle; falling through and returning None would silently
    # serialize the value as ``null`` and hide the problem.
    raise TypeError(
        f"Object of type {type(o).__name__} is not JSON serializable"
    )
Create an example graph
We start by creating some basic graph data first.
# Create the example :Person nodes and the weighted :KNOWS relationships
# between them. MERGE only creates a pattern that does not already exist,
# so re-running this cell does not duplicate the data.
gds.run_cypher("""
MERGE (dan:Person:ExampleData {name: 'Dan', age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData {name: 'Matt', age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData {name: 'Jeff', age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData {name: 'Brie', age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData {name: 'Elsa', age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData {name: 'John', age: 35, heightAndWeight: [172, 76]})
MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)
RETURN True AS exampleDataCreated
""")
// Create the example :Person nodes and the weighted :KNOWS relationships
// between them. MERGE only creates a pattern that does not already exist,
// so re-running this statement does not duplicate the data.
MERGE (dan:Person:ExampleData {name: 'Dan', age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData {name: 'Matt', age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData {name: 'Jeff', age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData {name: 'Brie', age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData {name: 'Elsa', age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData {name: 'John', age: 35, heightAndWeight: [172, 76]})
MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)
RETURN True AS exampleDataCreated
# MERGE statements that build the example :Person nodes and the weighted
# :KNOWS relationships; existing patterns are matched rather than duplicated.
create_example_graph_on_disk_query = """
MERGE (dan:Person:ExampleData {name: 'Dan', age: 20, heightAndWeight: [185, 75]})
MERGE (annie:Person:ExampleData {name: 'Annie', age: 12, heightAndWeight: [124, 42]})
MERGE (matt:Person:ExampleData {name: 'Matt', age: 67, heightAndWeight: [170, 80]})
MERGE (jeff:Person:ExampleData {name: 'Jeff', age: 45, heightAndWeight: [192, 85]})
MERGE (brie:Person:ExampleData {name: 'Brie', age: 27, heightAndWeight: [176, 57]})
MERGE (elsa:Person:ExampleData {name: 'Elsa', age: 32, heightAndWeight: [158, 55]})
MERGE (john:Person:ExampleData {name: 'John', age: 35, heightAndWeight: [172, 76]})
MERGE (dan)-[:KNOWS {relWeight: 1.0}]->(annie)
MERGE (dan)-[:KNOWS {relWeight: 1.6}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 0.1}]->(matt)
MERGE (annie)-[:KNOWS {relWeight: 3.0}]->(jeff)
MERGE (annie)-[:KNOWS {relWeight: 1.2}]->(brie)
MERGE (matt)-[:KNOWS {relWeight: 10.0}]->(brie)
MERGE (brie)-[:KNOWS {relWeight: 1.0}]->(elsa)
MERGE (brie)-[:KNOWS {relWeight: 2.2}]->(jeff)
MERGE (john)-[:KNOWS {relWeight: 5.0}]->(jeff)
RETURN True AS exampleDataCreated
"""

with driver.session() as session:
    # Turn every returned record into a plain dict while the session is open
    records = session.run(create_example_graph_on_disk_query)
    result = [record.data() for record in records]

# Pretty-print with stable key ordering
print(json.dumps(result, sort_keys=True, indent=2))
We then project an in-memory graph from the data just created.
# Project the :ExampleData nodes (under the "Person" node label, with their
# "age" and "heightAndWeight" properties) and the KNOWS relationships
# (treated as undirected, with their "relWeight" property) into an in-memory
# graph. `g` is the graph object used by later cells; `result` holds the
# projection summary.
g, result = gds.graph.project(
    "example_graph_for_graphsage",
    {
        "Person": {
            "label": "ExampleData",
            "properties": ["age", "heightAndWeight"]
        }
    },
    {
        "KNOWS": {
            "type": "KNOWS",
            "orientation": "UNDIRECTED",
            "properties": ["relWeight"]
        }
    }
)
print(result)
// Project the :ExampleData nodes (under the "Person" node label, with their
// age and heightAndWeight properties) and the KNOWS relationships (treated
// as undirected, with relWeight) into the in-memory graph
// 'example_graph_for_graphsage'.
CALL gds.graph.project(
'example_graph_for_graphsage',
{
Person: {
label: 'ExampleData',
properties: ['age', 'heightAndWeight']
}
},
{
KNOWS: {
type: 'KNOWS',
orientation: 'UNDIRECTED',
properties: ['relWeight']
}
}
)
# Project the example data into the in-memory graph
# 'example_graph_for_graphsage' (same projection as the GDS client call).
create_example_graph_in_memory_query = """
CALL gds.graph.project(
'example_graph_for_graphsage',
{
Person: {
label: 'ExampleData',
properties: ['age', 'heightAndWeight']
}
},
{
KNOWS: {
type: 'KNOWS',
orientation: 'UNDIRECTED',
properties: ['relWeight']
}
}
)
"""

with driver.session() as session:
    # Collect each returned record as a dict before the session closes
    records = session.run(create_example_graph_in_memory_query)
    result = [record.data() for record in records]

# Pretty-print with stable key ordering
print(json.dumps(result, sort_keys=True, indent=2))
Train a model
Machine learning algorithms that support the `train` mode produce trained models which are stored in the Model Catalog.
Similarly, `predict` procedures can use such trained models to produce predictions.
In this example we train a model for the GraphSAGE algorithm using the `train` mode.
# Train a GraphSAGE model on the projected graph `g`; the trained model is
# stored in the Model Catalog under `modelName`. `model` is the handle used
# by later cells, and `result` holds the training summary.
model, result = gds.beta.graphSage.train(
    g,
    modelName="example_graph_model_for_graphsage",
    featureProperties=["age", "heightAndWeight"],
    aggregator="mean",
    activationFunction="sigmoid",
    sampleSizes=[25, 10],
)
# Show the training summary, as the other snippets (and the Cypher/driver
# versions of this step) do for their results
print(result)
// Train a GraphSAGE model on the projected graph and store it in the model
// catalog as 'example_graph_model_for_graphsage', then return its name and
// training metrics taken from the yielded modelInfo map.
CALL gds.beta.graphSage.train(
'example_graph_for_graphsage',
{
modelName: 'example_graph_model_for_graphsage',
featureProperties: ['age', 'heightAndWeight'],
aggregator: 'mean',
activationFunction: 'sigmoid',
sampleSizes: [25, 10]
}
)
YIELD modelInfo as info
RETURN
info.name as modelName,
info.metrics.didConverge as didConverge,
info.metrics.ranEpochs as ranEpochs,
info.metrics.epochLosses as epochLosses
# Train the GraphSAGE model and return its name plus convergence metrics.
train_graph_sage_on_in_memory_graph_query = """
CALL gds.beta.graphSage.train(
'example_graph_for_graphsage',
{
modelName: 'example_graph_model_for_graphsage',
featureProperties: ['age', 'heightAndWeight'],
aggregator: 'mean',
activationFunction: 'sigmoid',
sampleSizes: [25, 10]
}
)
YIELD modelInfo as info
RETURN
info.name as modelName,
info.metrics.didConverge as didConverge,
info.metrics.ranEpochs as ranEpochs,
info.metrics.epochLosses as epochLosses
"""

with driver.session() as session:
    # One dict per returned row, built while the session is still open
    records = session.run(train_graph_sage_on_in_memory_graph_query)
    result = [record.data() for record in records]

# Pretty-print with stable key ordering
print(json.dumps(result, sort_keys=True, indent=2))
View the model catalog
We can use the `gds.beta.model.list` procedure to get information on all the models currently available in the catalog.
Along with information on the graph schema, the model name, and the training configuration, the result of the call contains the following fields:
- `loaded`: flag denoting whether the model is in memory (`true`) or available on disk (`false`)
- `stored`: flag denoting whether the model has been persisted to disk
- `shared`: flag denoting whether the model has been published, making it accessible to all users
# List every model in the catalog, including its loaded/stored/shared flags
results = gds.beta.model.list()
print(results)
// List every model in the catalog, including its loaded/stored/shared flags
CALL gds.beta.model.list()
# Catalog listing query; it is reused later in this notebook.
list_model_catalog_query = """
CALL gds.beta.model.list()
"""

with driver.session() as session:
    # Build one dict per catalog entry while the session is open
    records = session.run(list_model_catalog_query)
    results = [record.data() for record in records]

# `default` converts Neo4j DateTime values so json.dumps can print them
print(json.dumps(results, sort_keys=True, indent=2, default=default))
Save a model to disk
The `gds.alpha.model.store` procedure can be used to persist a model to disk.
This is useful both to keep models for later reuse and to free up memory.
Not all models can be saved to disk. A list of the supported models can be found in the GDS manual. If a model cannot be saved to disk, it will be lost when the AuraDS instance is restarted.
# Persist the trained model to disk, so it can be kept for later reuse and
# survives an AuraDS restart
result = gds.alpha.model.store(model)
print(result)
// Persist the trained model to disk
CALL gds.alpha.model.store("example_graph_model_for_graphsage")
# Persist the trained GraphSAGE model to disk.
save_graph_sage_model_to_disk_query = """
CALL gds.alpha.model.store("example_graph_model_for_graphsage")
"""

with driver.session() as session:
    # Collect the procedure's output rows as dicts before the session closes
    records = session.run(save_graph_sage_model_to_disk_query)
    result = [record.data() for record in records]

# Pretty-print with stable key ordering
print(json.dumps(result, sort_keys=True, indent=2))
If we list the model catalog again after persisting a model, we can see that the `stored` flag for that model has been set to `true`.
# List the catalog again: the "stored" flag of the persisted model should
# now be true
results = gds.beta.model.list()
print(results)
// List the catalog again: the stored flag of the persisted model should now be true
CALL gds.beta.model.list()
# Same catalog listing as before; the persisted model should now report
# stored = true.
list_model_catalog_query = """
CALL gds.beta.model.list()
"""

with driver.session() as session:
    # Build one dict per catalog entry while the session is open
    records = session.run(list_model_catalog_query)
    results = [record.data() for record in records]

# `default` converts Neo4j DateTime values so json.dumps can print them
print(json.dumps(results, sort_keys=True, indent=2, default=default))
Share a model with other users
After a model has been created, it can be useful to make it available to other users for different use cases.
A model can only be shared with other users of the same AuraDS instance.
Create a new user
To see how this works in practice on AuraDS, we first need to create another user to share the model with.
# Switch to the "system" database to run the
# "CREATE USER" admin command.
# NOTE: the client stays on "system" until it is explicitly switched back
# with gds.set_database("neo4j").
gds.set_database("system")
# Create the user the model will be shared with; IF NOT EXISTS makes the
# cell safe to re-run
gds.run_cypher("""
CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
""")
// Switch to the "system" database: user administration commands such as
// CREATE USER can only be run there. ":use" switches the current database
// of the session.
:use system
CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
# Administrative statement creating the user the model will be shared with.
create_a_new_user_query = """
CREATE USER testUser IF NOT EXISTS
SET PASSWORD 'password'
SET PASSWORD CHANGE NOT REQUIRED
"""

# User management must run against the "system" database
with driver.session(database="system") as session:
    records = session.run(create_a_new_user_query)
    result = [record.data() for record in records]

# Pretty-print with stable key ordering
print(json.dumps(result, sort_keys=True, indent=2))
Publish the model
A model can be published (made accessible to other users) using the `gds.alpha.model.publish` procedure.
Upon publication, the model name is updated by appending `_public` to its original name.
# Switch back to the default "neo4j" database
# to publish the model
gds.set_database("neo4j")
# Publishing returns a handle to the shared model, whose name has "_public"
# appended; the cleanup cell drops and deletes this handle.
model_public = gds.alpha.model.publish(model)
print(model_public)
// Switch back to the default "neo4j" database (":use" changes the current
// database of the session) to publish the model
:use neo4j
CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
# Publish the model so other users of the instance can see it; the shared
# copy's name gets "_public" appended.
publish_graph_sage_model_to_disk_query = """
CALL gds.alpha.model.publish('example_graph_model_for_graphsage')
"""

with driver.session() as session:
    # Collect the procedure's output rows as dicts before the session closes
    records = session.run(publish_graph_sage_model_to_disk_query)
    result = [record.data() for record in records]

# `default` converts Neo4j DateTime values so json.dumps can print them
print(json.dumps(result, sort_keys=True, indent=2, default=default))
View the model as a different user
In order to verify that the published model is visible to the user we have just created, we need to create a new client (or driver) session.
We can then use it to run the `gds.beta.model.list` procedure again under the new user and verify that the model is included in the list.
# Open a second client that authenticates as the test user, so the model
# catalog can be listed from that user's point of view
test_user_gds = GraphDataScience(
    AURA_CONNECTION_URI,
    auth=("testUser", "password"),
    aura_ds=True
)
# The published model should be included in this user's list
results = test_user_gds.beta.model.list()
print(results)
// First, open a new Cypher shell as the test user with the following command
// (quote the variable so its value is passed as a single argument):
//
// ./cypher-shell -a "$AURA_CONNECTION_URI" -u testUser -p password
//
// Then list the models visible to that user; the published model should appear.
CALL gds.beta.model.list()
# Second driver that authenticates as the test user, used to check which
# models that user can see
test_user_driver = GraphDatabase.driver(
    AURA_CONNECTION_URI,
    auth=("testUser", "password")
)

with test_user_driver.session() as session:
    # Reuse the catalog-listing query defined earlier in this notebook
    records = session.run(list_model_catalog_query)
    results = [record.data() for record in records]

# `default` converts Neo4j DateTime values so json.dumps can print them
print(json.dumps(results, sort_keys=True, indent=2, default=default))
Cleanup
The in-memory graphs, the data in the Neo4j database, the models, and the test user can now all be deleted.
# Delete the example dataset
gds.run_cypher("""
MATCH (example:ExampleData)
DETACH DELETE example
""")
# Delete the projected graph from memory
gds.graph.drop(g)
# Drop the model from memory
gds.beta.model.drop(model_public)
# Delete the model from disk
gds.alpha.model.delete(model_public)
# Close the client opened as the test user. Nothing else closes it, and it
# should be released before the account it authenticates as is dropped.
test_user_gds.close()
# Switch to the "system" database to delete the example user
gds.set_database("system")
gds.run_cypher("""
DROP USER testUser
""")
// Delete the example dataset from the database
MATCH (example:ExampleData)
DETACH DELETE example;
// Delete the projected graph from memory
CALL gds.graph.drop("example_graph_for_graphsage");
// Drop the model from memory
CALL gds.beta.model.drop("example_graph_model_for_graphsage_public");
// Delete the model from disk
CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public");
// Delete the example user. DROP USER is an administration command and must
// run against the "system" database, so switch to it first.
:use system
DROP USER testUser;
# Delete the example dataset from the database
delete_example_graph_query = """
MATCH (example:ExampleData)
DETACH DELETE example
"""
# Delete the projected graph from memory
drop_in_memory_graph_query = """
CALL gds.graph.drop("example_graph_for_graphsage")
"""
# Drop the model from memory
drop_example_models_query = """
CALL gds.beta.model.drop("example_graph_model_for_graphsage_public")
"""
# Delete the model from disk
delete_example_models_query = """
CALL gds.alpha.model.delete("example_graph_model_for_graphsage_public")
"""
# Delete the example user
drop_example_user_query = """
DROP USER testUser
"""
# Create the driver session
with driver.session() as session:
    # Run queries
    print(session.run(delete_example_graph_query).data())
    print(session.run(drop_in_memory_graph_query).data())
    print(session.run(drop_example_models_query).data())
    print(session.run(delete_example_models_query).data())
# Close the test user's driver before that user is dropped, so no
# connection stays open for an account that is about to be deleted
test_user_driver.close()
# Create another driver session on the system database
# to drop the test user
with driver.session(database='system') as session:
    print(session.run(drop_example_user_query).data())
driver.close()
Closing the connection
The connection should always be closed when no longer needed.
Although the GDS client automatically closes the connection when the object is deleted, it is good practice to close it explicitly.
# Release the connections opened during setup; do not use these objects
# after this point.
# Close the client connection
gds.close()
# Close the driver connection
driver.close()
References
Cypher
- Learn more about the Cypher syntax.
- You can use the Cypher Cheat Sheet as a reference of all available Cypher features.