ML Model
The ML Model entity represents trained machine learning models across various ML platforms and frameworks. ML Models can be trained using different algorithms and frameworks (TensorFlow, PyTorch, Scikit-learn, etc.) and deployed to various platforms (MLflow, SageMaker, Vertex AI, etc.).
Identity
ML Models are identified by three pieces of information:
- The platform where the model is registered or deployed: this is the specific ML platform that hosts or manages this model. Examples are mlflow, sagemaker, vertexai, databricks, etc. See dataplatform for more details.
- The name of the model: this is the unique identifier for the model within the platform. The naming convention varies by platform:
  - MLflow: typically uses the registered model name (e.g., recommendation-model)
  - SageMaker: uses the model name or model package group name (e.g., product-recommendation-v1)
  - Vertex AI: uses the model resource name (e.g., projects/123/locations/us-central1/models/456)
- The environment or origin where the model was trained: this is similar to the fabric concept for datasets, allowing you to distinguish between models in different environments (PROD, DEV, QA, etc.). The full list of supported environments is available in FabricType.pdl.
An example of an ML Model identifier is urn:li:mlModel:(urn:li:dataPlatform:mlflow,my-recommendation-model,PROD).
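If you construct these URNs programmatically, the Python SDK provides an MlModelUrn helper. The snippet below is a minimal sketch; it assumes the environment component defaults to PROD when omitted, matching how the SDK examples later on this page use the class.
Python SDK: Construct an ML Model URN
from datahub.metadata.urns import MlModelUrn

# Platform plus model name; the environment component defaults to PROD.
model_urn = MlModelUrn(platform="mlflow", name="my-recommendation-model")
print(model_urn)
# urn:li:mlModel:(urn:li:dataPlatform:mlflow,my-recommendation-model,PROD)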
Important Capabilities
Basic Model Information
The core information about an ML Model is captured in the mlModelProperties aspect. This includes:
- Name and Description: Human-readable name and description of what the model does
- Model Type: The algorithm or architecture used (e.g., "Convolutional Neural Network", "Random Forest", "BERT")
- Version: Version information using the versionProperties aspect
- Timestamps: Created and last modified timestamps
- Custom Properties: Flexible key-value pairs for platform-specific metadata (e.g., framework version, model format)
The following code snippet shows you how to create a basic ML Model:
Python SDK: Create an ML Model
from datahub.metadata.urns import MlModelGroupUrn
from datahub.sdk import DataHubClient
from datahub.sdk.mlmodel import MLModel

client = DataHubClient.from_env()

mlmodel = MLModel(
    id="customer-churn-predictor",
    name="Customer Churn Prediction Model",
    platform="mlflow",
    description="A gradient boosting model that predicts customer churn based on usage patterns and engagement metrics",
    custom_properties={
        "framework": "xgboost",
        "framework_version": "1.7.0",
        "model_format": "pickle",
    },
    model_group=MlModelGroupUrn(platform="mlflow", name="customer-churn-models"),
)

client.entities.upsert(mlmodel)
Hyperparameters and Metrics
ML Models can capture both the hyperparameters used during training and various metrics from training and production:
- Hyperparameters: Configuration values that control the training process (learning rate, batch size, number of epochs, etc.)
- Training Metrics: Performance metrics from the training process (accuracy, loss, F1 score, etc.)
- Online Metrics: Performance metrics from production deployment (latency, throughput, drift, etc.)
These are stored in the mlModelProperties aspect as structured lists of parameters and metrics.
Python SDK: Add hyperparameters and metrics to an ML Model
from datahub.metadata.urns import CorpUserUrn, DomainUrn, MlModelUrn, TagUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()

mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

mlmodel.set_hyper_params(
    {
        "learning_rate": "0.1",
        "max_depth": "6",
        "n_estimators": "100",
        "subsample": "0.8",
        "colsample_bytree": "0.8",
    }
)

mlmodel.set_training_metrics(
    {
        "accuracy": "0.87",
        "precision": "0.84",
        "recall": "0.82",
        "f1_score": "0.83",
        "auc_roc": "0.91",
    }
)

mlmodel.add_owner(CorpUserUrn("data_science_team"))
mlmodel.add_tag(TagUrn("production"))
mlmodel.add_tag(TagUrn("classification"))
mlmodel.set_domain(DomainUrn("urn:li:domain:customer-analytics"))

client.entities.update(mlmodel)
Intended Use and Ethical Considerations
DataHub supports comprehensive model documentation following ML model card best practices. These aspects help stakeholders understand the appropriate use cases and ethical implications of using the model:
- Intended Use (intendedUse aspect): Documents primary use cases, intended users, and out-of-scope applications
- Ethical Considerations (mlModelEthicalConsiderations aspect): Documents use of sensitive data, risks and harms, and mitigation strategies
- Caveats and Recommendations (mlModelCaveatsAndRecommendations aspect): Additional considerations, ideal dataset characteristics, and usage recommendations
These aspects align with responsible AI practices and help ensure models are used appropriately.
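The intendedUse and mlModelEthicalConsiderations aspects are written out in the relationship example later on this page. The sketch below covers the remaining caveats aspect; the CaveatsAndRecommendationsClass and CaveatDetailsClass names and their fields are assumptions to verify against the generated schema classes, and it reuses the private _set_aspect helper from that later example.
Python SDK: Add caveats and recommendations to an ML Model
import datahub.metadata.schema_classes as models
from datahub.metadata.urns import MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()
mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

# Caveats and usage recommendations, following the model card structure.
# Class and field names here are assumptions - check the generated schema classes.
caveats = models.CaveatsAndRecommendationsClass(
    caveats=models.CaveatDetailsClass(
        needsFurtherTesting=True,
        caveatDescription="Performance has not been validated on customers with fewer than 30 days of history",
        groupsNotRepresented=["international customers"],
    ),
    recommendations="Retrain quarterly and re-evaluate decision thresholds before each promotion",
    idealDatasetCharacteristics=["At least 12 months of engagement history per customer"],
)
mlmodel._set_aspect(caveats)

client.entities.update(mlmodel)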
Training and Evaluation Data
ML Models can document their training and evaluation datasets in two complementary ways:
Direct Dataset References
- Training Data (mlModelTrainingData aspect): Datasets used to train the model, including preprocessing information and motivation for dataset selection
- Evaluation Data (mlModelEvaluationData aspect): Datasets used for model evaluation and testing
Each dataset reference includes the dataset URN, motivation for using that dataset, and any preprocessing steps applied. This creates direct lineage relationships between models and their training data.
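The training-data half of this pattern is shown in the relationship example later on this page; a corresponding sketch for evaluation data follows. The EvaluationDataClass name, its evaluationData field, and the holdout dataset name are assumptions for illustration.
Python SDK: Document evaluation data for an ML Model
import datahub.metadata.schema_classes as models
from datahub.metadata.urns import DatasetUrn, MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()
mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

# Each entry carries the dataset URN, the motivation for using it,
# and the preprocessing steps applied before evaluation.
evaluation_data = models.EvaluationDataClass(
    evaluationData=[
        models.BaseDataClass(
            dataset=str(
                DatasetUrn(platform="snowflake", name="prod.analytics.customer_holdout")
            ),
            motivation="Held-out customers from the most recent quarter",
            preProcessing=["Applied the same feature scaling as the training set"],
        )
    ]
)
mlmodel._set_aspect(evaluation_data)

client.entities.update(mlmodel)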
Lineage via Training Runs
Training runs (dataProcessInstance entities) provide an alternative and often more detailed way to capture training lineage:
- Training runs declare their input datasets via the dataProcessInstanceInput aspect
- Training runs declare their output datasets via the dataProcessInstanceOutput aspect
- Models reference training runs via the trainingJobs field
This creates indirect lineage: Dataset → Training Run → Model
When to use each approach:
- Use direct dataset references for simple documentation of what data was used
- Use training runs for complete lineage tracking including:
- Multiple training/validation/test datasets
- Metrics and hyperparameters from the training process
- Temporal tracking (when the training occurred)
- Connection to experiments for comparing multiple training attempts
Most production ML systems should use training runs for comprehensive lineage tracking.
Factor Prompts and Quantitative Analysis
For detailed model analysis and performance reporting:
- Factor Prompts (mlModelFactorPrompts aspect): Factors that may affect model performance (demographic groups, environmental conditions, etc.); see the sketch after this list
- Quantitative Analyses (mlModelQuantitativeAnalyses aspect): Links to dashboards or reports showing disaggregated performance metrics across different factors
- Metrics (mlModelMetrics aspect): Detailed metrics with descriptions beyond simple training/online metrics
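As a rough illustration of the factor prompts aspect, the sketch below records relevant and evaluation factors. The MLModelFactorPromptsClass and MLModelFactorsClass names and their fields are assumptions to check against the generated schema classes.
Python SDK: Add factor prompts to an ML Model
import datahub.metadata.schema_classes as models
from datahub.metadata.urns import MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()
mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

# Factors the model card calls out as relevant to performance, plus the
# subset actually used during evaluation. Field names are assumptions.
factor_prompts = models.MLModelFactorPromptsClass(
    relevantFactors=[
        models.MLModelFactorsClass(
            groups=["customer tenure", "subscription tier"],
            environment=["batch scoring"],
        )
    ],
    evaluationFactors=[
        models.MLModelFactorsClass(groups=["customer tenure"])
    ],
)
mlmodel._set_aspect(factor_prompts)

client.entities.update(mlmodel)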
Source Code and Cost
- Source Code (sourceCode aspect): Links to model training code, notebooks, or repositories (GitHub, GitLab, etc.)
- Cost (cost aspect): Cost attribution information for tracking model training and inference expenses
Training Runs and Experiments
ML Models in DataHub can be linked to their training runs and experiments, providing complete lineage from raw data through training to deployed models.
Training Runs
Training runs represent specific executions of model training jobs. In DataHub, training runs are modeled as dataProcessInstance entities with a specialized subtype:
- Entity Type: dataProcessInstance
- Subtype: MLAssetSubTypes.MLFLOW_TRAINING_RUN
- Key Aspects:
  - dataProcessInstanceProperties: Basic properties like name, timestamps, and custom properties
  - mlTrainingRunProperties: ML-specific properties including:
    - Training metrics (accuracy, loss, F1 score, etc.)
    - Hyperparameters (learning rate, batch size, epochs, etc.)
    - Output URLs (model artifacts, checkpoints)
    - External URLs (links to training dashboards)
  - dataProcessInstanceInput: Input datasets used for training
  - dataProcessInstanceOutput: Output datasets (predictions, feature importance, etc.)
  - dataProcessInstanceRunEvent: Start, completion, and failure events
Training runs create lineage relationships showing:
- Upstream: Which datasets were used for training
- Downstream: Which models were produced by the training run
Models reference their training runs through the trainingJobs field in mlModelProperties, and model groups can also reference training runs to track all training activity for a model family.
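A minimal sketch of that linkage, using the same add_training_job helper that appears in the full example below; the run ID here is hypothetical.
Python SDK: Link an ML Model to a training run
from datahub.metadata.urns import DataProcessInstanceUrn, MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()
mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

# Point the model at the training run that produced it (trainingJobs field).
mlmodel.add_training_job(DataProcessInstanceUrn("churn_training_run_42"))

client.entities.update(mlmodel)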
Experiments
Experiments organize related training runs into logical groups, typically representing a series of attempts to optimize a model or compare different approaches. In DataHub, experiments are modeled as container entities:
- Entity Type: container
- Subtype: MLAssetSubTypes.MLFLOW_EXPERIMENT
- Purpose: Group related training runs for organization and comparison
Training runs belong to experiments through the container aspect, creating a hierarchy:
Experiment: "Customer Churn Prediction"
├── Training Run 1: baseline model
├── Training Run 2: with feature engineering
├── Training Run 3: hyperparameter tuning
└── Training Run 4: final production model
This structure mirrors common ML platform patterns (like MLflow's experiment/run hierarchy) and enables:
- Comparing metrics across multiple training attempts
- Tracking the evolution of a model through iterations
- Understanding which approaches were tried and their results
- Organizing training work by project or objective
Python SDK: Create training runs and experiments
import argparse
from datetime import datetime

from dh_ai_client import DatahubAIClient

from datahub.emitter.mcp_builder import (
    ContainerKey,
)
from datahub.ingestion.source.common.subtypes import MLAssetSubTypes
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType
from datahub.metadata.schema_classes import (
    AuditStampClass,
    DataProcessInstancePropertiesClass,
    MLHyperParamClass,
    MLMetricClass,
    MLTrainingRunPropertiesClass,
)
from datahub.metadata.urns import (
    CorpUserUrn,
    DataProcessInstanceUrn,
    GlossaryTermUrn,
    TagUrn,
)
from datahub.sdk.container import Container
from datahub.sdk.dataset import Dataset
from datahub.sdk.mlmodel import MLModel
from datahub.sdk.mlmodelgroup import MLModelGroup

parser = argparse.ArgumentParser()
parser.add_argument("--token", required=False, help="DataHub access token")
parser.add_argument(
    "--server_url",
    required=False,
    default="http://localhost:8080",
    help="DataHub server URL (defaults to http://localhost:8080)",
)
args = parser.parse_args()

# Initialize client
client = DatahubAIClient(token=args.token, server_url=args.server_url)

# Use a unique prefix for all IDs to avoid conflicts
prefix = "test"

# Define all entity IDs upfront
# Basic entity IDs
basic_model_group_id = f"{prefix}_basic_group"
basic_model_id = f"{prefix}_basic_model"
basic_experiment_id = f"{prefix}_basic_experiment"
basic_run_id = f"{prefix}_basic_run"
basic_dataset_id = f"{prefix}_basic_dataset"

# Advanced entity IDs
advanced_model_group_id = f"{prefix}_airline_forecast_models_group"
advanced_model_id = f"{prefix}_arima_model"
advanced_experiment_id = f"{prefix}_airline_forecast_experiment"
advanced_run_id = f"{prefix}_simple_training_run"
advanced_input_dataset_id = f"{prefix}_iris_input"
advanced_output_dataset_id = f"{prefix}_iris_output"

# Display names with prefix
basic_model_group_name = f"{prefix} Basic Group"
basic_model_name = f"{prefix} Basic Model"
basic_experiment_name = f"{prefix} Basic Experiment"
basic_run_name = f"{prefix} Basic Run"
basic_dataset_name = f"{prefix} Basic Dataset"

advanced_model_group_name = f"{prefix} Airline Forecast Models Group"
advanced_model_name = f"{prefix} ARIMA Model"
advanced_experiment_name = f"{prefix} Airline Forecast Experiment"
advanced_run_name = f"{prefix} Simple Training Run"
advanced_input_dataset_name = f"{prefix} Iris Training Input Data"
advanced_output_dataset_name = f"{prefix} Iris Model Output Data"


def create_basic_model_group():
    """Create a basic model group."""
    print("Creating basic model group...")
    basic_model_group = MLModelGroup(
        id=basic_model_group_id,
        platform="mlflow",
        name=basic_model_group_name,
    )
    client._emit_mcps(basic_model_group.as_mcps())
    return basic_model_group


def create_advanced_model_group():
    """Create an advanced model group."""
    print("Creating advanced model group...")
    advanced_model_group = MLModelGroup(
        id=advanced_model_group_id,
        platform="mlflow",
        name=advanced_model_group_name,
        description="Group of models for airline passenger forecasting",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        external_url="https://www.linkedin.com/in/datahub",
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
    )
    client._emit_mcps(advanced_model_group.as_mcps())
    return advanced_model_group


def create_basic_model():
    """Create a basic model."""
    print("Creating basic model...")
    basic_model = MLModel(
        id=basic_model_id,
        platform="mlflow",
        name=basic_model_name,
    )
    client._emit_mcps(basic_model.as_mcps())
    return basic_model


def create_advanced_model():
    """Create an advanced model."""
    print("Creating advanced model...")
    advanced_model = MLModel(
        id=advanced_model_id,
        platform="mlflow",
        name=advanced_model_name,
        description="ARIMA model for airline passenger forecasting",
        created=datetime.now(),
        last_modified=datetime.now(),
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        external_url="https://www.linkedin.com/in/datahub",
        tags=["urn:li:tag:forecasting", "urn:li:tag:arima"],
        terms=["urn:li:glossaryTerm:forecasting"],
        custom_properties={"team": "forecasting"},
        version="1",
        aliases=["champion"],
        hyper_params={"learning_rate": "0.01"},
        training_metrics={"accuracy": "0.9"},
    )
    client._emit_mcps(advanced_model.as_mcps())
    return advanced_model


def create_basic_experiment():
    """Create a basic experiment."""
    print("Creating basic experiment...")
    basic_experiment = Container(
        container_key=ContainerKey(platform="mlflow", name=basic_experiment_id),
        display_name=basic_experiment_name,
    )
    client._emit_mcps(basic_experiment.as_mcps())
    return basic_experiment


def create_advanced_experiment():
    """Create an advanced experiment."""
    print("Creating advanced experiment...")
    advanced_experiment = Container(
        container_key=ContainerKey(platform="mlflow", name=advanced_experiment_id),
        display_name=advanced_experiment_name,
        description="Experiment to forecast airline passenger numbers",
        extra_properties={"team": "forecasting"},
        created=datetime(2025, 4, 9, 22, 30),
        last_modified=datetime(2025, 4, 9, 22, 30),
        subtype=MLAssetSubTypes.MLFLOW_EXPERIMENT,
    )
    client._emit_mcps(advanced_experiment.as_mcps())
    return advanced_experiment


def create_basic_training_run():
    """Create a basic training run."""
    print("Creating basic training run...")
    basic_run_urn = client.create_training_run(
        run_id=basic_run_id,
        run_name=basic_run_name,
    )
    return basic_run_urn


def create_advanced_training_run():
    """Create an advanced training run."""
    print("Creating advanced training run...")
    advanced_run_urn = client.create_training_run(
        run_id=advanced_run_id,
        properties=DataProcessInstancePropertiesClass(
            name=advanced_run_name,
            created=AuditStampClass(
                time=1628580000000, actor="urn:li:corpuser:datahub"
            ),
            customProperties={"team": "forecasting"},
        ),
        training_run_properties=MLTrainingRunPropertiesClass(
            id=advanced_run_id,
            outputUrls=["s3://my-bucket/output"],
            trainingMetrics=[MLMetricClass(name="accuracy", value="0.9")],
            hyperParams=[MLHyperParamClass(name="learning_rate", value="0.01")],
            externalUrl="http://localhost:5000",
        ),
        run_result=RunResultType.FAILURE,
        start_timestamp=1628580000000,
        end_timestamp=1628580001000,
    )
    return advanced_run_urn


def create_basic_dataset():
    """Create a basic dataset."""
    print("Creating basic dataset...")
    basic_input_dataset = Dataset(
        platform="snowflake",
        name=basic_dataset_id,
        display_name=basic_dataset_name,
    )
    client._emit_mcps(basic_input_dataset.as_mcps())
    return basic_input_dataset


def create_advanced_datasets():
    """Create advanced datasets."""
    print("Creating advanced datasets...")
    advanced_input_dataset = Dataset(
        platform="snowflake",
        name=advanced_input_dataset_id,
        description="Raw Iris dataset used for training ML models",
        schema=[("id", "number"), ("name", "string"), ("species", "string")],
        display_name=advanced_input_dataset_name,
        tags=["urn:li:tag:ml_data", "urn:li:tag:iris"],
        terms=["urn:li:glossaryTerm:raw_data"],
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        custom_properties={
            "data_source": "UCI Repository",
            "records": "150",
            "features": "4",
        },
    )
    client._emit_mcps(advanced_input_dataset.as_mcps())

    advanced_output_dataset = Dataset(
        platform="snowflake",
        name=advanced_output_dataset_id,
        description="Processed Iris dataset with model predictions",
        schema=[("id", "number"), ("name", "string"), ("species", "string")],
        display_name=advanced_output_dataset_name,
        tags=["urn:li:tag:ml_data", "urn:li:tag:predictions"],
        terms=["urn:li:glossaryTerm:model_output"],
        owners=[CorpUserUrn("urn:li:corpuser:datahub")],
        custom_properties={
            "model_version": "1.0",
            "records": "150",
            "accuracy": "0.95",
        },
    )
    client._emit_mcps(advanced_output_dataset.as_mcps())
    return advanced_input_dataset, advanced_output_dataset


# Split relationship functions into individual top-level functions
def add_model_to_model_group(model, model_group):
    """Add model to model group relationship."""
    print("Adding model to model group...")
    model.set_model_group(model_group.urn)
    client._emit_mcps(model.as_mcps())


def add_run_to_experiment(run_urn, experiment):
    """Add run to experiment relationship."""
    print("Adding run to experiment...")
    client.add_run_to_experiment(run_urn=run_urn, experiment_urn=str(experiment.urn))


def add_run_to_model(model, run_id):
    """Add run to model relationship."""
    print("Adding run to model...")
    model.add_training_job(DataProcessInstanceUrn(run_id))
    client._emit_mcps(model.as_mcps())


def add_run_to_model_group(model_group, run_id):
    """Add run to model group relationship."""
    print("Adding run to model group...")
    model_group.add_training_job(DataProcessInstanceUrn(run_id))
    client._emit_mcps(model_group.as_mcps())


def add_input_dataset_to_run(run_urn, input_dataset):
    """Add input dataset to run relationship."""
    print("Adding input dataset to run...")
    client.add_input_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(input_dataset.urn)]
    )


def add_output_dataset_to_run(run_urn, output_dataset):
    """Add output dataset to run relationship."""
    print("Adding output dataset to run...")
    client.add_output_datasets_to_run(
        run_urn=run_urn, dataset_urns=[str(output_dataset.urn)]
    )


def update_model_properties(model):
    """Update model properties."""
    print("Updating model properties...")
    # Update model version
    model.set_version("2")
    # Add tags and terms
    model.add_tag(TagUrn("marketing"))
    model.add_term(GlossaryTermUrn("marketing"))
    # Add version alias
    model.add_version_alias("challenger")
    # Save the changes
    client._emit_mcps(model.as_mcps())


def update_model_group_properties(model_group):
    """Update model group properties."""
    print("Updating model group properties...")
    # Update description
    model_group.set_description("Updated description for airline forecast models")
    # Add tags and terms
    model_group.add_tag(TagUrn("production"))
    model_group.add_term(GlossaryTermUrn("time-series"))
    # Update custom properties
    model_group.set_custom_properties(
        {"team": "forecasting", "business_unit": "operations", "status": "active"}
    )
    # Save the changes
    client._emit_mcps(model_group.as_mcps())


def update_experiment_properties():
    """Update experiment properties."""
    print("Updating experiment properties...")
    # Create a container object for the existing experiment
    existing_experiment = Container(
        container_key=ContainerKey(platform="mlflow", name=advanced_experiment_id),
        display_name=advanced_experiment_name,
    )
    # Update properties
    existing_experiment.set_description(
        "Updated experiment for forecasting passenger numbers"
    )
    existing_experiment.add_tag(TagUrn("time-series"))
    existing_experiment.add_term(GlossaryTermUrn("forecasting"))
    existing_experiment.set_custom_properties(
        {"team": "forecasting", "priority": "high", "status": "active"}
    )
    # Save the changes
    client._emit_mcps(existing_experiment.as_mcps())


def main():
    print("Creating AI assets...")

    # Comment in/out the functions you want to run
    # Create basic entities
    create_basic_model_group()
    create_basic_model()
    create_basic_experiment()
    create_basic_training_run()
    create_basic_dataset()

    # Create advanced entities
    advanced_model_group = create_advanced_model_group()
    advanced_model = create_advanced_model()
    advanced_experiment = create_advanced_experiment()
    advanced_run_urn = create_advanced_training_run()
    advanced_input_dataset, advanced_output_dataset = create_advanced_datasets()

    # Create relationships - each can be commented out independently
    add_model_to_model_group(advanced_model, advanced_model_group)
    add_run_to_experiment(advanced_run_urn, advanced_experiment)
    add_run_to_model(advanced_model, advanced_run_id)
    add_run_to_model_group(advanced_model_group, advanced_run_id)
    add_input_dataset_to_run(advanced_run_urn, advanced_input_dataset)
    add_output_dataset_to_run(advanced_run_urn, advanced_output_dataset)

    # Update properties - each can be commented out independently
    update_model_properties(advanced_model)
    update_model_group_properties(advanced_model_group)
    update_experiment_properties()

    print("All done! AI entities created successfully.")


if __name__ == "__main__":
    main()
Relationships and Lineage
ML Models support rich relationship modeling through various aspects and fields:
Core Relationships
- Model Groups (via the groups field in mlModelProperties): Models can belong to mlModelGroup entities, creating a MemberOf relationship. This organizes related models into logical families or collections.
- Training Runs (via the trainingJobs field in mlModelProperties): Models reference the dataProcessInstance entities (with the MLFLOW_TRAINING_RUN subtype) that produced them. This creates upstream lineage showing:
  - Which training run created this model
  - What datasets were used for training (via the training run's input datasets)
  - What hyperparameters and metrics were recorded
  - Which experiment the training run belonged to
- Features (via the mlFeatures field in mlModelProperties): Models can consume mlFeature entities, creating a Consumes relationship. This documents:
  - Which features are required for model inference
  - The complete feature set used during training
  - Dependencies on feature stores or feature tables
- Deployments (via the deployments field in mlModelProperties): Models can be deployed to mlModelDeployment entities, representing running model endpoints in various environments (production, staging, etc.)
- Training Datasets (via the mlModelTrainingData aspect): Direct references to datasets used for training, including preprocessing information and motivation for dataset selection
- Evaluation Datasets (via the mlModelEvaluationData aspect): References to datasets used for model evaluation and testing
Lineage Graph Structure
These relationships create a comprehensive lineage graph:
Training Datasets → Training Run → ML Model → ML Model Deployment
                         ↓
                     Experiment

Feature Tables → ML Features → ML Model

ML Model Group ← ML Model
This enables powerful queries such as:
- "Show me all datasets that influenced this model's predictions"
- "Which models will be affected if this dataset schema changes?"
- "What's the full history of training runs that created versions of this model?"
- "Which production endpoints are serving this model?"
Python SDK: Update model-specific aspects
import datahub.metadata.schema_classes as models
from datahub.metadata.urns import DatasetUrn, MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()

model_urn = MlModelUrn(platform="mlflow", name="customer-churn-predictor")
mlmodel = client.entities.get(model_urn)

intended_use = models.IntendedUseClass(
    primaryUses=[
        "Predict customer churn to enable proactive retention campaigns",
        "Identify high-risk customers for targeted interventions",
    ],
    primaryUsers=[models.IntendedUserTypeClass.ENTERPRISE],
    outOfScopeUses=[
        "Not suitable for real-time predictions (batch inference only)",
        "Not trained on international markets outside North America",
    ],
)
mlmodel._set_aspect(intended_use)

training_data = models.TrainingDataClass(
    trainingData=[
        models.BaseDataClass(
            dataset=str(
                DatasetUrn(
                    platform="snowflake", name="prod.analytics.customer_features"
                )
            ),
            motivation="Historical customer data with confirmed churn labels",
            preProcessing=[
                "Removed customers with less than 30 days of history",
                "Standardized numerical features using StandardScaler",
                "One-hot encoded categorical variables",
            ],
        )
    ]
)
mlmodel._set_aspect(training_data)

source_code = models.SourceCodeClass(
    sourceCode=[
        models.SourceCodeUrlClass(
            type=models.SourceCodeUrlTypeClass.ML_MODEL_SOURCE_CODE,
            sourceCodeUrl="https://github.com/example/ml-models/tree/main/churn-predictor",
        )
    ]
)
mlmodel._set_aspect(source_code)

ethical_considerations = models.EthicalConsiderationsClass(
    data=["Model uses demographic data (age, location) which may be sensitive"],
    risksAndHarms=[
        "Predictions may disproportionately affect certain customer segments",
        "False positives could lead to unnecessary retention spending",
    ],
    mitigations=[
        "Regular bias audits conducted quarterly",
        "Human review required for high-value customer interventions",
    ],
)
mlmodel._set_aspect(ethical_considerations)

client.entities.update(mlmodel)
print(f"Updated aspects for model: {model_urn}")
Tags, Terms, and Ownership
Like other DataHub entities, ML Models support:
- Tags (globalTags aspect): Flexible categorization (e.g., "pii-model", "production-ready", "experimental")
- Glossary Terms (glossaryTerms aspect): Business concepts (e.g., "Customer Churn", "Fraud Detection")
- Ownership (ownership aspect): Individuals or teams responsible for the model (data scientists, ML engineers, etc.)
- Domains (domains aspect): Organizational grouping (e.g., "Recommendations", "Risk Management")
Complete ML Workflow Example
The following example demonstrates a complete ML model lifecycle in DataHub, showing how all the pieces work together:
1. Create Model Group
↓
2. Create Experiment (Container)
↓
3. Create Training Run (DataProcessInstance)
├── Link input datasets
├── Link output datasets
└── Add metrics and hyperparameters
↓
4. Create Model
├── Set version and aliases
├── Link to model group
├── Link to training run
├── Add hyperparameters and metrics
└── Add ownership and tags
↓
5. Link Training Run to Experiment
↓
6. Update Model properties as needed
├── Change version aliases (champion → challenger)
├── Add additional tags/terms
└── Update metrics from production
This workflow creates rich lineage showing:
- Which datasets trained the model
- What experiments and training runs were involved
- How the model evolved through versions
- Which version is deployed (via aliases)
- Who owns and maintains the model
Complete Python Example: Full ML Workflow
See the comprehensive example in /metadata-ingestion/examples/ai/dh_ai_docs_demo.py which demonstrates:
- Creating model groups with metadata
- Creating experiments to organize training runs
- Creating training runs with metrics, hyperparameters, and dataset lineage
- Creating models with versions and aliases
- Linking all entities together to form complete lineage
- Updating properties and managing the model lifecycle
The example shows both basic patterns for getting started and advanced patterns for production ML systems.
Code Examples
Querying ML Model Information
The standard REST APIs can be used to retrieve ML Model entities and their aspects:
Python: Query an ML Model via REST API
import urllib.parse

import requests

gms_server = "http://localhost:8080"
model_urn = "urn:li:mlModel:(urn:li:dataPlatform:mlflow,customer-churn-predictor,PROD)"
encoded_urn = urllib.parse.quote(model_urn, safe="")

response = requests.get(f"{gms_server}/entities/{encoded_urn}")

if response.status_code == 200:
    entity = response.json()
    print(f"Entity URN: {entity['urn']}")
    print("\nAspects:")

    if "mlModelProperties" in entity["aspects"]:
        props = entity["aspects"]["mlModelProperties"]
        print(f"  Name: {props.get('name')}")
        print(f"  Description: {props.get('description')}")
        print(f"  Type: {props.get('type')}")

        if props.get("hyperParams"):
            print("\n  Hyperparameters:")
            for param in props["hyperParams"]:
                print(f"    - {param['name']}: {param['value']}")

        if props.get("trainingMetrics"):
            print("\n  Training Metrics:")
            for metric in props["trainingMetrics"]:
                print(f"    - {metric['name']}: {metric['value']}")

    if "globalTags" in entity["aspects"]:
        tags = entity["aspects"]["globalTags"]["tags"]
        print(f"\n  Tags: {[tag['tag'] for tag in tags]}")

    if "ownership" in entity["aspects"]:
        owners = entity["aspects"]["ownership"]["owners"]
        print(f"\n  Owners: {[owner['owner'] for owner in owners]}")

    if "intendedUse" in entity["aspects"]:
        intended = entity["aspects"]["intendedUse"]
        print(f"\n  Primary Uses: {intended.get('primaryUses')}")
        print(f"  Out of Scope Uses: {intended.get('outOfScopeUses')}")
else:
    print(f"Failed to fetch entity: {response.status_code}")
    print(response.text)
Integration Points
Related Entities
ML Models integrate with several other entities in the DataHub metadata model:
- mlModelGroup: Logical grouping of related model versions (e.g., all versions of a recommendation model)
- mlModelDeployment: Running instances of deployed models with status, endpoint URLs, and deployment metadata
- mlFeature: Individual features consumed by the model for inference
- mlFeatureTable: Collections of features, often from feature stores
- dataset: Training and evaluation datasets used by the model
- dataProcessInstance (with MLFLOW_TRAINING_RUN subtype): Specific training runs that created model versions, including metrics, hyperparameters, and lineage to input/output datasets
- container (with MLFLOW_EXPERIMENT subtype): Experiments that organize related training runs for a model or project
- versionSet: Groups all versions of a model together for version management
GraphQL Resolvers
The GraphQL API provides rich querying capabilities for ML Models through resolvers in datahub-graphql-core/src/main/java/com/linkedin/datahub/graphql/types/mlmodel/. These resolvers support:
- Fetching model details with all aspects
- Navigating relationships to features, groups, and deployments
- Searching and filtering models by tags, terms, platform, etc.
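As a rough sketch of what calling these resolvers looks like, the snippet below posts a query to the GraphQL endpoint over HTTP. The endpoint path and the exact field selections are assumptions to verify against your deployment's GraphQL schema.
Python: Query an ML Model via GraphQL
import requests

# GraphQL endpoint exposed by the DataHub backend; adjust host, port,
# and authentication headers for your deployment.
graphql_url = "http://localhost:8080/api/graphql"

query = """
query getModel($urn: String!) {
  mlModel(urn: $urn) {
    urn
    name
    description
    properties {
      type
      hyperParams { name value }
      trainingMetrics { name value }
    }
  }
}
"""

variables = {
    "urn": "urn:li:mlModel:(urn:li:dataPlatform:mlflow,customer-churn-predictor,PROD)"
}

response = requests.post(graphql_url, json={"query": query, "variables": variables})
response.raise_for_status()
print(response.json())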
Ingestion Sources
Several ingestion sources automatically extract ML Model metadata:
- MLflow: Extracts registered models, versions, metrics, parameters, and lineage from MLflow tracking servers
- SageMaker: Ingests models, model packages, and endpoints from AWS SageMaker
- Vertex AI: Extracts models and endpoints from Google Cloud Vertex AI
- Databricks: Ingests MLflow models from Databricks workspaces
- Unity Catalog: Extracts ML models registered in Unity Catalog
These sources are located in /metadata-ingestion/src/datahub/ingestion/source/ and automatically populate model properties, relationships, and lineage.
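Ingestion is normally configured with a YAML recipe, but the same recipe can also be run programmatically. The sketch below assumes the mlflow source accepts a tracking_uri option; consult the connector documentation for the full config schema.
Python: Run MLflow ingestion programmatically
from datahub.ingestion.run.pipeline import Pipeline

# A minimal recipe run in-process; the mlflow source config shown here
# (tracking_uri) is an assumption - check the connector docs.
pipeline = Pipeline.create(
    {
        "source": {
            "type": "mlflow",
            "config": {
                "tracking_uri": "http://localhost:5000",
            },
        },
        "sink": {
            "type": "datahub-rest",
            "config": {"server": "http://localhost:8080"},
        },
    }
)
pipeline.run()
pipeline.raise_from_status()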
Notable Exceptions
Model Versioning
ML Model versioning in DataHub uses the versionProperties aspect, which provides a robust framework for tracking model versions across their lifecycle. This is the standard approach demonstrated in production ML platforms.
Version Properties Aspect
Every ML Model should use the versionProperties aspect, which includes:
- version: A VersionTagClass containing the version identifier (e.g., "1", "2", "v1.0.0")
- versionSet: A URN that groups all versions of a model together (e.g., urn:li:versionSet:(mlModel,mlmodel_my-model_versions))
- sortId: A string used for ordering versions (typically the version number zero-padded)
- aliases: Optional array of VersionTagClass objects for named version references (see the sketch after this list)
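A minimal sketch of setting these fields through the Python SDK, reusing the set_version and add_version_alias helpers that appear in the workflow example above; the model URN is hypothetical.
Python SDK: Set version and aliases on an ML Model
from datahub.metadata.urns import MlModelUrn
from datahub.sdk import DataHubClient

client = DataHubClient.from_env()
mlmodel = client.entities.get(
    MlModelUrn(platform="mlflow", name="customer-churn-predictor")
)

# Set the version identifier and attach a named alias (versionProperties aspect).
mlmodel.set_version("2")
mlmodel.add_version_alias("challenger")

client.entities.update(mlmodel)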
Version Aliases for A/B Testing
Version aliases enable flexible model lifecycle management and A/B testing workflows. Common aliases include:
- "champion": The currently deployed production model
- "challenger": A candidate model being tested or evaluated
- "baseline": A reference model for performance comparison
- "latest": The most recently trained version
These aliases allow you to reference models by their role rather than specific version numbers, enabling smooth model promotion workflows:
Model v1 (alias: "champion") # Currently in production
Model v2 (alias: "challenger") # Being tested in canary deployment
Model v3 (alias: "latest") # Just completed training
When v2 proves superior, you can update aliases without changing infrastructure:
Model v1 (no alias) # Retired
Model v2 (alias: "champion") # Promoted to production
Model v3 (alias: "challenger") # Now being tested
Model Groups and Versioning
Model groups (mlModelGroup entities) serve as logical containers for organizing related models. While model groups can contain multiple versions of the same model, versioning is handled through the versionProperties aspect on individual models, not through the group structure itself. Model groups are used for:
- Organizing all versions of a model family
- Grouping experimental variants or different architectures solving the same problem
- Managing lineage and metadata common across multiple related models
The relationship between models and model groups is through the groups field in mlModelProperties, creating a MemberOf relationship.
Platform-Specific Naming
Different ML platforms have different naming conventions:
- MLflow: Uses a two-level hierarchy (registered model name + version number). In DataHub, each version can be a separate entity, or versions can be tracked in a single entity.
- SageMaker: Has multiple model concepts (model, model package, model package group). DataHub can model these as separate entities or consolidate them.
- Vertex AI: Uses fully qualified resource names. These should be simplified to human-readable names when possible.
When ingesting from these platforms, connectors handle platform-specific naming and convert it to appropriate DataHub URNs.
Model Cards
The various aspects (intendedUse, mlModelFactorPrompts, mlModelEthicalConsiderations, etc.) follow the Model Cards for Model Reporting framework (Mitchell et al., 2019). While these aspects are optional, they are strongly recommended for production models to ensure responsible AI practices and transparent model documentation.
Technical Reference
For technical details about fields, searchability, and relationships, view the Columns tab in DataHub.