Custom PySpark Graph: All cust_data Nodes
This example shows a custom PySpark GraphReduce graph that uses all tables in
tests/data/cust_data:
cust, orders, order_products, notifications, notification_interactions, and notification_interaction_types.
It includes custom definitions across do_annotate, do_filters,
do_normalize, do_reduce, and parent post-join logic.
Key behaviors:
- Customer name-length annotation via length(coalesce(name, ''))
- Order-level amount casting and order spend rollups
- Order-product distinct product count rollups
- Notification interaction engagement rollups driven by interaction-type signals
- Parent-level post-join activity features
Complete Example
#!/usr/bin/env python
"""Custom PySpark GraphReduce example using all cust_data nodes.
Usage:
python examples/custom_pyspark_all_nodes.py
"""
from __future__ import annotations
import datetime
import os
try:
from pyspark.sql import SparkSession, functions as F
except Exception as exc: # pragma: no cover
print(f"pyspark not available: {exc}", flush=True)
raise SystemExit(0)
from graphreduce.enum import ComputeLayerEnum
from graphreduce.graph_reduce import GraphReduce
from graphreduce.node import GraphReduceNode
# Relative path (from the repo root) to the sample CSV tables used below.
DATA_PATH = "tests/data/cust_data"
def _safe_numeric(df, candidates: list[str], default: float = 0.0):
    """Return a null-safe double Column for the first candidate present in *df*.

    Scans ``candidates`` in priority order; for the first name found in
    ``df.columns`` returns ``coalesce(cast(col, 'double'), lit(default))``.
    If none of the candidates exist, returns ``lit(default)`` so downstream
    aggregations still evaluate instead of failing on a missing column.

    Args:
        df: Spark DataFrame whose ``columns`` attribute is inspected.
        candidates: Column names to try, in priority order.
        default: Fallback value for missing columns or null cells
            (defaults to 0.0, preserving the original behavior).

    Returns:
        A pyspark ``Column`` expression.
    """
    present = set(df.columns)  # O(1) membership instead of a list scan per candidate
    for name in candidates:
        if name in present:
            return F.coalesce(F.col(name).cast("double"), F.lit(float(default)))
    return F.lit(float(default))
class CustNode(GraphReduceNode):
    """Parent customer node.

    Annotates a name-length feature, normalizes customer names, and — after
    child rollups are joined in — derives total-event and activity-tier
    features at the customer level.
    """

    def do_annotate(self):
        # Length of the null-safe customer name.
        safe_name = F.coalesce(F.col(self.colabbr("name")), F.lit(""))
        self.df = self.df.withColumn(self.colabbr("name_length"), F.length(safe_name))
        return self.df

    def do_filters(self):
        # Drop rows missing a primary key.
        self.df = self.df.filter(F.col(self.colabbr("id")).isNotNull())
        return self.df

    def do_normalize(self):
        # Canonical name form: null-safe, trimmed, lower-cased.
        canonical = F.lower(F.trim(F.coalesce(F.col(self.colabbr("name")), F.lit(""))))
        self.df = self.df.withColumn(self.colabbr("name"), canonical)
        return self.df

    def do_reduce(self, reduce_key):
        # Parent node performs no aggregation of its own.
        return self.df

    def do_labels(self, reduce_key):
        return self.df

    def do_post_join_annotate(self):
        # Sum child rollups (nulls treated as 0) into a total-event count,
        # then bucket that count into a coarse activity tier.
        order_count = F.coalesce(F.col("ord_num_orders"), F.lit(0))
        notif_count = F.coalesce(F.col("not_num_notifications"), F.lit(0))
        engaged_count = F.coalesce(F.col("not_num_engaged_interactions"), F.lit(0))
        self.df = self.df.withColumn(
            "cust_total_events", order_count + notif_count + engaged_count
        )
        tier = (
            F.when(F.col("cust_total_events") >= 5, F.lit("high"))
            .when(F.col("cust_total_events") >= 2, F.lit("medium"))
            .otherwise(F.lit("low"))
        )
        self.df = self.df.withColumn("cust_activity_tier", tier)
        return self.df

    def do_post_join_filters(self):
        # Keep rows with a non-negative event count (always true for computed
        # totals; guards against unexpected nulls/negatives).
        self.df = self.df.filter(F.col("cust_total_events") >= 0)
        return self.df
class OrderNode(GraphReduceNode):
    """Order node: casts amounts to double and rolls up per-key spend stats."""

    def do_annotate(self):
        # Numeric (double) view of the raw amount column.
        self.df = self.df.withColumn(
            self.colabbr("amount_dbl"),
            F.col(self.colabbr("amount")).cast("double"),
        )
        return self.df

    def do_filters(self):
        # Restrict to orders on or after 2022-01-01.
        self.df = self.df.filter(F.col(self.colabbr("ts")) >= F.lit("2022-01-01"))
        return self.df

    def do_normalize(self):
        # Null amounts become 0.0 so sums and averages stay defined.
        self.df = self.df.withColumn(
            self.colabbr("amount_dbl"),
            F.coalesce(F.col(self.colabbr("amount_dbl")), F.lit(0.0)),
        )
        return self.df

    def do_reduce(self, reduce_key):
        # Per-key order count plus total and mean spend.
        grouped = self.prep_for_features().groupBy(self.colabbr(reduce_key))
        return grouped.agg(
            F.countDistinct(F.col(self.colabbr(self.pk))).alias(self.colabbr("num_orders")),
            F.sum(F.col(self.colabbr("amount_dbl"))).alias(self.colabbr("sum_amount")),
            F.avg(F.col(self.colabbr("amount_dbl"))).alias(self.colabbr("avg_amount")),
        )

    def do_labels(self, reduce_key):
        return self.df

    def do_post_join_annotate(self):
        return self.df

    def do_post_join_filters(self):
        return self.df
class OrderProductsNode(GraphReduceNode):
    """Order-products node: counts line items and distinct products per order."""

    def do_annotate(self):
        # Integer view of product_id for distinct counting.
        self.df = self.df.withColumn(
            self.colabbr("product_id_int"),
            F.col(self.colabbr("product_id")).cast("int"),
        )
        return self.df

    def do_filters(self):
        # Drop rows whose product id failed the int cast (or was null).
        self.df = self.df.filter(F.col(self.colabbr("product_id_int")).isNotNull())
        return self.df

    def do_normalize(self):
        return self.df

    def do_reduce(self, reduce_key):
        # Per-key line-item count and distinct-product count.
        grouped = self.prep_for_features().groupBy(self.colabbr(reduce_key))
        return grouped.agg(
            F.count(F.col(self.colabbr(self.pk))).alias(
                self.colabbr("num_order_products")
            ),
            F.countDistinct(F.col(self.colabbr("product_id_int"))).alias(
                self.colabbr("num_distinct_products")
            ),
        )

    def do_labels(self, reduce_key):
        return self.df

    def do_post_join_annotate(self):
        return self.df

    def do_post_join_filters(self):
        return self.df
class NotificationNode(GraphReduceNode):
    """Notification node.

    Rolls up notification counts, latest timestamp, and interaction totals
    propagated from the child notification-interactions rollup.
    """

    def do_annotate(self):
        # Two-digit month extracted from the notification timestamp.
        self.df = self.df.withColumn(
            self.colabbr("ts_month"),
            F.date_format(F.col(self.colabbr("ts")), "MM"),
        )
        return self.df

    def do_filters(self):
        # Restrict to notifications on or after 2022-01-01.
        self.df = self.df.filter(F.col(self.colabbr("ts")) >= F.lit("2022-01-01"))
        return self.df

    def do_normalize(self):
        return self.df

    def do_reduce(self, reduce_key):
        prepped = self.prep_for_features()
        # Child-rollup columns may be absent depending on join order;
        # _safe_numeric falls back to 0.0 when they are missing.
        interaction_total = _safe_numeric(prepped, ["ni_num_interactions"])
        engaged_total = _safe_numeric(prepped, ["ni_num_engaged_interactions"])
        grouped = prepped.groupBy(self.colabbr(reduce_key))
        return grouped.agg(
            F.countDistinct(F.col(self.colabbr(self.pk))).alias(
                self.colabbr("num_notifications")
            ),
            F.max(F.col(self.colabbr("ts"))).alias(self.colabbr("max_notification_ts")),
            F.sum(interaction_total).alias(self.colabbr("num_interactions")),
            F.sum(engaged_total).alias(self.colabbr("num_engaged_interactions")),
        )

    def do_labels(self, reduce_key):
        return self.df

    def do_post_join_annotate(self):
        return self.df

    def do_post_join_filters(self):
        return self.df
class NotificationInteractionsNode(GraphReduceNode):
    """Notification-interactions node.

    Rolls up interaction counts, distinct interaction types, and engaged
    interactions driven by the interaction-type engagement flag.
    """

    def do_annotate(self):
        # Calendar-day string derived from the interaction timestamp.
        self.df = self.df.withColumn(
            self.colabbr("ts_day"),
            F.date_format(F.col(self.colabbr("ts")), "yyyy-MM-dd"),
        )
        return self.df

    def do_filters(self):
        # Restrict to interactions on or after 2022-01-01.
        self.df = self.df.filter(F.col(self.colabbr("ts")) >= F.lit("2022-01-01"))
        return self.df

    def do_normalize(self):
        return self.df

    def do_reduce(self, reduce_key):
        prepped = self.prep_for_features()
        # The engagement flag column's prefix depends on how the
        # interaction-types rollup was joined; try the known spellings.
        engagement_flag = _safe_numeric(
            prepped,
            [
                "ni_nit_is_engagement_type",
                "nit_is_engagement_type",
                "ni_is_engagement_type",
            ],
        )
        grouped = prepped.groupBy(self.colabbr(reduce_key))
        return grouped.agg(
            F.count(F.col(self.colabbr(self.pk))).alias(self.colabbr("num_interactions")),
            F.countDistinct(F.col(self.colabbr("interaction_type_id"))).alias(
                self.colabbr("num_interaction_types")
            ),
            F.sum(engagement_flag).alias(self.colabbr("num_engaged_interactions")),
        )

    def do_labels(self, reduce_key):
        return self.df

    def do_post_join_annotate(self):
        return self.df

    def do_post_join_filters(self):
        return self.df
class NotificationInteractionTypeNode(GraphReduceNode):
    """Interaction-type dimension node.

    Normalizes type names and emits per-key engagement/view signal flags
    consumed by the interactions rollup upstream.
    """

    def do_annotate(self):
        # Canonical type name: trimmed and lower-cased.
        self.df = self.df.withColumn(
            self.colabbr("name"),
            F.lower(F.trim(F.col(self.colabbr("name")))),
        )
        return self.df

    def do_filters(self):
        # Drop rows missing a primary key.
        self.df = self.df.filter(F.col(self.colabbr("id")).isNotNull())
        return self.df

    def do_normalize(self):
        return self.df

    def do_reduce(self, reduce_key):
        name_col = F.col(self.colabbr("name"))
        # 1 when the type denotes engagement ("clicked"/"dismissed").
        engagement_flag = F.when(
            name_col.isin("clicked", "dismissed"), F.lit(1)
        ).otherwise(F.lit(0))
        # 1 when the type name mentions a view event.
        view_flag = F.when(name_col.contains("view"), F.lit(1)).otherwise(F.lit(0))
        grouped = self.prep_for_features().groupBy(self.colabbr(reduce_key))
        return grouped.agg(
            F.first(name_col, ignorenulls=True).alias(self.colabbr("name")),
            F.max(engagement_flag).alias(self.colabbr("is_engagement_type")),
            F.max(view_flag).alias(self.colabbr("is_view_event")),
        )

    def do_labels(self, reduce_key):
        return self.df

    def do_post_join_annotate(self):
        return self.df

    def do_post_join_filters(self):
        return self.df
def build_custom_pyspark_graph(spark: SparkSession) -> GraphReduce:
    """Register every cust_data CSV as a temp view and assemble the graph.

    Args:
        spark: Active SparkSession used to read the CSVs and run the graph.

    Returns:
        A GraphReduce instance rooted at the customer node, with all six
        cust_data tables wired in as reducing entity edges.
    """
    tables = (
        "cust",
        "orders",
        "order_products",
        "notifications",
        "notification_interactions",
        "notification_interaction_types",
    )
    for table in tables:
        frame = (
            spark.read.option("header", True)
            .option("inferSchema", True)
            .csv(os.path.join(DATA_PATH, f"{table}.csv"))
        )
        frame.createOrReplaceTempView(table)

    # Constructor kwargs shared by every node.
    shared = dict(fmt="sql", compute_layer=ComputeLayerEnum.spark, spark_sqlctx=spark)

    cust = CustNode(
        fpath="cust",
        prefix="cust",
        pk="id",
        columns=["id", "name"],
        **shared,
    )
    orders = OrderNode(
        fpath="orders",
        prefix="ord",
        pk="id",
        date_key="ts",
        columns=["id", "customer_id", "ts", "amount"],
        **shared,
    )
    order_products = OrderProductsNode(
        fpath="order_products",
        prefix="op",
        pk="id",
        columns=["id", "order_id", "product_id"],
        **shared,
    )
    notifications = NotificationNode(
        fpath="notifications",
        prefix="not",
        pk="id",
        date_key="ts",
        columns=["id", "customer_id", "ts"],
        **shared,
    )
    notification_interactions = NotificationInteractionsNode(
        fpath="notification_interactions",
        prefix="ni",
        pk="id",
        date_key="ts",
        columns=["id", "notification_id", "interaction_type_id", "ts"],
        **shared,
    )
    interaction_types = NotificationInteractionTypeNode(
        fpath="notification_interaction_types",
        prefix="nit",
        pk="id",
        columns=["id", "name"],
        **shared,
    )

    gr = GraphReduce(
        name="custom_pyspark_all_nodes",
        parent_node=cust,
        compute_layer=ComputeLayerEnum.spark,
        spark_sqlctx=spark,
        cut_date=datetime.datetime(2023, 6, 30),
    )
    all_nodes = [
        cust,
        orders,
        order_products,
        notifications,
        notification_interactions,
        interaction_types,
    ]
    for node in all_nodes:
        gr.add_node(node)

    # Relationship edges; every child rollup is reduced into its parent.
    gr.add_entity_edge(cust, orders, parent_key="id", relation_key="customer_id", reduce=True)
    gr.add_entity_edge(orders, order_products, parent_key="id", relation_key="order_id", reduce=True)
    gr.add_entity_edge(cust, notifications, parent_key="id", relation_key="customer_id", reduce=True)
    gr.add_entity_edge(
        notifications,
        notification_interactions,
        parent_key="id",
        relation_key="notification_id",
        reduce=True,
    )
    gr.add_entity_edge(
        notification_interactions,
        interaction_types,
        parent_key="interaction_type_id",
        relation_key="id",
        reduce=True,
    )
    return gr
def run_custom_pyspark_all_nodes(spark: SparkSession | None = None):
    """Build the all-nodes graph, run it, and return the parent DataFrame.

    Args:
        spark: Optional existing SparkSession; a local one is created when
            omitted.

    Returns:
        The customer-level Spark DataFrame after all transformations.
    """
    if spark is None:
        spark = SparkSession.builder.appName("graphreduce-custom-all-nodes").getOrCreate()
    graph = build_custom_pyspark_graph(spark)
    graph.do_transformations()
    return graph.parent_node.df
def main() -> None:
    """Run the example and print a short summary of the resulting frame."""
    result = run_custom_pyspark_all_nodes()
    print("rows:", result.count(), flush=True)
    print("columns:", len(result.columns), flush=True)
    print("column_names:", result.columns, flush=True)
# Script entry point: run the full example when executed directly.
if __name__ == "__main__":
    main()
Full runnable script:
examples/custom_pyspark_all_nodes.py
Run Interactive
Idle