Graph-Based Learning
from typing import Dict, List, Any, Optional, Tuple import sqlparse from collections import defaultdict, Counter import json from datetime import datetime
from ..modules.knowledge_graph import KnowledgeGraph from ..core.config import settings
class GraphLearningService: """Advanced graph-based learning for SQL query patterns and relationships."""
def __init__(self, workspace_id: str):
self.workspace_id = workspace_id
self.kg = KnowledgeGraph(workspace_id)
async def analyze_and_store_query(
self,
query: str,
execution_time: float,
result_count: int,
success: bool = True,
user_context: Optional[Dict] = None
):
"""Analyze query and store comprehensive graph relationships."""
# Parse SQL query
parsed = sqlparse.parse(query)[0]
analysis = self._analyze_query_structure(parsed)
# Generate unique query ID
query_id = f"query_{hash(query) % 100000}"
# Store basic query metadata
self.kg.add_relation(query_id, "has_sql", query)
self.kg.add_relation(query_id, "execution_time", str(execution_time))
self.kg.add_relation(query_id, "result_count", str(result_count))
self.kg.add_relation(query_id, "success", str(success))
self.kg.add_relation(query_id, "timestamp", datetime.utcnow().isoformat())
# Store query type and patterns
self.kg.add_relation(query_id, "query_type", analysis["type"])
for pattern in analysis["patterns"]:
self.kg.add_relation(query_id, "uses_pattern", pattern)
# Store table relationships
for table in analysis["tables"]:
self.kg.add_relation(query_id, "uses_table", table)
self.kg.add_relation(table, "used_by_query", query_id)
# Store column relationships
for column in analysis["columns"]:
self.kg.add_relation(query_id, "uses_column", column)
self.kg.add_relation(column, "used_by_query", query_id)
# Store join relationships
for join_info in analysis["joins"]:
self.kg.add_relation(
join_info["left_table"],
f"joined_with_{join_info['type']}",
join_info["right_table"]
)
self.kg.add_relation(query_id, "performs_join", f"{join_info['left_table']}_{join_info['right_table']}")
# Store aggregation patterns
for agg in analysis["aggregations"]:
self.kg.add_relation(query_id, "uses_aggregation", agg["function"])
self.kg.add_relation(query_id, "aggregates_column", agg["column"])
# Store filter patterns
for filter_info in analysis["filters"]:
self.kg.add_relation(query_id, "filters_on", filter_info["column"])
self.kg.add_relation(query_id, "uses_operator", filter_info["operator"])
# Store user context if provided
if user_context:
user_id = user_context.get("user_id", "unknown")
department = user_context.get("department")
self.kg.add_relation(query_id, "written_by", user_id)
if department:
self.kg.add_relation(user_id, "works_in", department)
self.kg.add_relation(query_id, "department_context", department)
# Learn performance patterns
self._learn_performance_patterns(query_id, analysis, execution_time, result_count)
# Learn co-occurrence patterns
self._learn_co_occurrence_patterns(analysis)
def _analyze_query_structure(self, parsed_query) -> Dict[str, Any]:
"""Deep analysis of SQL query structure."""
analysis = {
"type": self._get_query_type(parsed_query),
"patterns": [],
"tables": set(),
"columns": set(),
"joins": [],
"aggregations": [],
"filters": [],
"subqueries": 0,
"complexity_score": 0
}
# Walk through all tokens
for token in parsed_query.flatten():
if token.ttype is sqlparse.tokens.Keyword:
keyword = token.value.upper()
# Identify patterns
if keyword in ["JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN", "FULL JOIN"]:
analysis["patterns"].append(f"{keyword.replace(' ', '_')}_PATTERN")
elif keyword in ["GROUP BY", "ORDER BY", "HAVING"]:
analysis["patterns"].append(f"{keyword.replace(' ', '_')}_PATTERN")
elif keyword in ["UNION", "INTERSECT", "EXCEPT"]:
analysis["patterns"].append(f"{keyword}_PATTERN")
elif keyword == "WITH":
analysis["patterns"].append("CTE_PATTERN")
elif token.ttype is sqlparse.tokens.Name:
# Extract table and column names
name = token.value.lower()
# Simple heuristic: if it contains a dot, it's likely table.column
if "." in name:
table, column = name.split(".", 1)
analysis["tables"].add(table)
analysis["columns"].add(f"{table}.{column}")
else:
# Could be table or column - context-dependent
analysis["tables"].add(name)
analysis["columns"].add(name)
# Detect aggregation functions
query_str = str(parsed_query).upper()
agg_functions = ["COUNT", "SUM", "AVG", "MAX", "MIN", "STDDEV", "VARIANCE"]
for func in agg_functions:
if func in query_str:
analysis["patterns"].append(f"{func}_AGGREGATION")
# Try to extract the column being aggregated
import re
pattern = f"{func}\\s*\\(\\s*([^)]+)\\s*\\)"
matches = re.findall(pattern, query_str)
for match in matches:
analysis["aggregations"].append({
"function": func,
"column": match.strip()
})
# Calculate complexity score
analysis["complexity_score"] = self._calculate_complexity(analysis, query_str)
return analysis
def _get_query_type(self, parsed_query) -> str:
"""Determine the primary query type."""
first_keyword = None
for token in parsed_query.tokens:
if token.ttype is sqlparse.tokens.Keyword:
first_keyword = token.value.upper()
break
return first_keyword or "UNKNOWN"
def _calculate_complexity(self, analysis: Dict, query_str: str) -> int:
"""Calculate query complexity score."""
score = 0
# Base score for query type
if analysis["type"] == "SELECT":
score += 1
elif analysis["type"] in ["INSERT", "UPDATE", "DELETE"]:
score += 2
# Add points for patterns
score += len(analysis["patterns"]) * 2
score += len(analysis["tables"]) * 1
score += len(analysis["joins"]) * 3
score += len(analysis["aggregations"]) * 2
# Subqueries add significant complexity
subquery_count = query_str.count("SELECT") - 1 # Subtract main query
score += subquery_count * 5
# Window functions
if "OVER" in query_str:
score += 4
return score
def _learn_performance_patterns(
self,
query_id: str,
analysis: Dict,
execution_time: float,
result_count: int
):
"""Learn relationships between query patterns and performance."""
# Classify performance
if execution_time < 100: # Fast queries
performance_class = "FAST"
elif execution_time < 1000: # Medium queries
performance_class = "MEDIUM"
else: # Slow queries
performance_class = "SLOW"
self.kg.add_relation(query_id, "performance_class", performance_class)
# Learn pattern-performance relationships
for pattern in analysis["patterns"]:
self.kg.add_relation(pattern, "observed_performance", f"{performance_class}_{execution_time}")
# Learn table-performance relationships
for table in analysis["tables"]:
self.kg.add_relation(table, "query_performance", f"{performance_class}_{execution_time}")
def _learn_co_occurrence_patterns(self, analysis: Dict):
"""Learn which tables, columns, and patterns frequently appear together."""
tables = list(analysis["tables"])
patterns = analysis["patterns"]
# Table co-occurrence
for i, table1 in enumerate(tables):
for table2 in tables[i+1:]:
self.kg.add_relation(table1, "co_occurs_with", table2)
self.kg.add_relation(table2, "co_occurs_with", table1)
# Pattern-table relationships
for pattern in patterns:
for table in tables:
self.kg.add_relation(pattern, "commonly_used_with", table)
async def get_intelligent_suggestions(
self,
partial_query: str,
context: Optional[Dict] = None
) -> List[Dict[str, Any]]:
"""Get intelligent suggestions based on graph learning."""
# Parse what we have so far
try:
parsed = sqlparse.parse(partial_query)[0]
current_analysis = self._analyze_query_structure(parsed)
except:
current_analysis = {"tables": set(), "patterns": [], "columns": set()}
suggestions = []
# 1. Table completion suggestions
if current_analysis["tables"]:
table_suggestions = self._suggest_related_tables(current_analysis["tables"])
suggestions.extend(table_suggestions)
# 2. Join suggestions
join_suggestions = self._suggest_joins(current_analysis["tables"])
suggestions.extend(join_suggestions)
# 3. Column suggestions
column_suggestions = self._suggest_columns(current_analysis["tables"])
suggestions.extend(column_suggestions)
# 4. Pattern-based suggestions
pattern_suggestions = self._suggest_patterns(current_analysis)
suggestions.extend(pattern_suggestions)
# 5. Performance-based suggestions
perf_suggestions = self._suggest_performance_optimizations(current_analysis)
suggestions.extend(perf_suggestions)
return suggestions[:10] # Return top 10 suggestions
def _suggest_related_tables(self, current_tables: set) -> List[Dict[str, Any]]:
"""Suggest tables that are frequently used with current tables."""
suggestions = []
for table in current_tables:
# Find tables that co-occur with this one
related_tables = self.kg.get_relations(table, "co_occurs_with")
for related_table in related_tables[:3]: # Top 3
if related_table not in current_tables:
suggestions.append({
"type": "table",
"suggestion": f"JOIN {related_table}",
"reason": f"Frequently used with {table}",
"confidence": 0.8
})
return suggestions
def _suggest_joins(self, current_tables: set) -> List[Dict[str, Any]]:
"""Suggest join patterns based on learned relationships."""
suggestions = []
tables_list = list(current_tables)
for i, table1 in enumerate(tables_list):
for table2 in tables_list[i+1:]:
# Check if there are known join patterns
join_patterns = self.kg.get_relations(table1, f"joined_with_LEFT")
if table2 in join_patterns:
suggestions.append({
"type": "join",
"suggestion": f"LEFT JOIN {table2} ON {table1}.id = {table2}.{table1}_id",
"reason": f"Common join pattern observed",
"confidence": 0.9
})
return suggestions
def _suggest_columns(self, current_tables: set) -> List[Dict[str, Any]]:
"""Suggest columns frequently used with current tables."""
suggestions = []
for table in current_tables:
# Find commonly used columns for this table
common_columns = self.kg.get_relations(table, "commonly_selected")
for column in common_columns[:5]: # Top 5
suggestions.append({
"type": "column",
"suggestion": f"{table}.{column}",
"reason": f"Frequently selected from {table}",
"confidence": 0.7
})
return suggestions
def _suggest_patterns(self, current_analysis: Dict) -> List[Dict[str, Any]]:
"""Suggest query patterns based on current context."""
suggestions = []
current_patterns = set(current_analysis["patterns"])
# If they're using GROUP BY, suggest ORDER BY
if "GROUP_BY_PATTERN" in current_patterns:
if "ORDER_BY_PATTERN" not in current_patterns:
suggestions.append({
"type": "pattern",
"suggestion": "ORDER BY column_name DESC",
"reason": "Commonly used with GROUP BY",
"confidence": 0.8
})
# If they have aggregations, suggest HAVING
if any("AGGREGATION" in p for p in current_patterns):
if "HAVING" not in current_patterns:
suggestions.append({
"type": "pattern",
"suggestion": "HAVING COUNT(*) > 10",
"reason": "Filter aggregated results",
"confidence": 0.7
})
return suggestions
def _suggest_performance_optimizations(self, current_analysis: Dict) -> List[Dict[str, Any]]:
"""Suggest performance optimizations based on learned patterns."""
suggestions = []
# Check if similar queries were slow
for table in current_analysis["tables"]:
slow_queries = self.kg.get_relations(table, "query_performance")
slow_count = sum(1 for perf in slow_queries if "SLOW" in perf)
if slow_count > 2: # If multiple slow queries on this table
suggestions.append({
"type": "performance",
"suggestion": f"Consider adding LIMIT clause for {table}",
"reason": f"Large result sets observed for {table}",
"confidence": 0.6
})
return suggestions
def get_workspace_insights(self) -> Dict[str, Any]:
"""Get comprehensive insights about the workspace."""
# Get all relations for analysis
all_relations = self.kg.get_all_relations()
# Analyze patterns
table_usage = Counter()
pattern_usage = Counter()
performance_stats = defaultdict(list)
for source, relation, target in all_relations:
if relation == "uses_table":
table_usage[target] += 1
elif relation == "uses_pattern":
pattern_usage[target] += 1
elif relation == "performance_class":
performance_stats[target].append(source)
return {
"most_used_tables": dict(table_usage.most_common(10)),
"common_patterns": dict(pattern_usage.most_common(10)),
"performance_distribution": {
"fast": len(performance_stats["FAST"]),
"medium": len(performance_stats["MEDIUM"]),
"slow": len(performance_stats["SLOW"])
},
"total_queries_analyzed": len([r for r in all_relations if r[1] == "has_sql"]),
"unique_tables": len(table_usage),
"complexity_trends": self._analyze_complexity_trends(all_relations)
}
def _analyze_complexity_trends(self, relations: List) -> Dict[str, Any]:
"""Analyze how query complexity changes over time."""
# This would analyze complexity scores over time
# For now, return a simple summary
return {
"avg_complexity": 5.2,
"trend": "increasing", # Could be "increasing", "stable", "decreasing"
"most_complex_pattern": "JOIN with multiple aggregations"
}