Sharding and Partitioning

As applications grow and data volumes explode, single database instances eventually hit performance and storage limits. Database sharding and partitioning are fundamental techniques for breaking through these barriers by distributing data across multiple systems.

Partitioning refers to dividing a database into smaller, more manageable pieces within the same database instance. Think of it as organizing a massive library into different sections: fiction, non-fiction, and reference materials, all within the same building.

Sharding takes this concept further by distributing data across completely separate database instances, often on different servers. This is like creating multiple library branches across different cities, each holding a portion of the total collection.

Why Sharding and Partitioning Matter

Scalability Challenges:

  • Single servers have finite CPU, memory, and storage capacity
  • Network bandwidth becomes a bottleneck with high traffic
  • Backup and recovery times grow with data size

Performance Benefits:

  • Parallel processing across multiple systems
  • Reduced contention for system resources
  • Improved query response times through smaller data sets
  • Geographic distribution for reduced latency

Availability Improvements:

  • Fault isolation - a failed shard affects only its own data while the other shards keep serving
  • Easier maintenance with smaller, independent systems
  • Reduced blast radius during incidents

1. Understanding Partitioning vs Sharding

Vertical Partitioning

Splits tables by columns, placing different columns on different storage systems:

Original Table: Users
┌─────────┬──────────┬─────────┬──────────────┬─────────────┐
│ user_id │   name   │  email  │   profile    │ preferences │
├─────────┼──────────┼─────────┼──────────────┼─────────────┤
│    1    │   Alice  │ a@x.com │ {large blob} │ {settings}  │
│    2    │    Bob   │ b@x.com │ {large blob} │ {settings}  │
└─────────┴──────────┴─────────┴──────────────┴─────────────┘

After Vertical Partitioning:
Primary Store:                    Secondary Store:
┌─────────┬──────────┬─────────┐  ┌─────────┬──────────────┬─────────────┐
│ user_id │   name   │  email  │  │ user_id │   profile    │ preferences │
├─────────┼──────────┼─────────┤  ├─────────┼──────────────┼─────────────┤
│    1    │   Alice  │ a@x.com │  │    1    │ {large blob} │ {settings}  │
│    2    │    Bob   │ b@x.com │  │    2    │ {large blob} │ {settings}  │
└─────────┴──────────┴─────────┘  └─────────┴──────────────┴─────────────┘
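
The split above can be sketched in a few lines. This is a minimal illustration, not a storage design: plain dictionaries stand in for the two stores, and the names `primary_store`, `secondary_store`, `save_user`, and `load_user` are hypothetical.

```python
# Hypothetical in-memory stand-ins for the two storage systems.
primary_store = {}    # user_id -> small, frequently read columns
secondary_store = {}  # user_id -> large, rarely read columns

HOT_COLUMNS = ("name", "email")
COLD_COLUMNS = ("profile", "preferences")

def save_user(user):
    """Split one logical row by column and write each part to its store."""
    uid = user["user_id"]
    primary_store[uid] = {c: user[c] for c in HOT_COLUMNS}
    secondary_store[uid] = {c: user[c] for c in COLD_COLUMNS}

def load_user(uid, include_cold=False):
    """Read the hot columns; fetch the cold columns only when asked."""
    row = {"user_id": uid, **primary_store[uid]}
    if include_cold:
        row.update(secondary_store[uid])
    return row

save_user({"user_id": 1, "name": "Alice", "email": "a@x.com",
           "profile": "{large blob}", "preferences": "{settings}"})
print(load_user(1))
```

Both parts share `user_id`, so a full row can always be reassembled, at the cost of a second lookup.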

Horizontal Partitioning (Sharding)

Splits tables by rows, distributing different rows across different systems:

Original Table: Orders (10M records)
┌──────────┬─────────────┬────────────┬────────┐
│ order_id │ customer_id │    date    │ amount │
├──────────┼─────────────┼────────────┼────────┤
│    1     │     101     │ 2024-01-01 │  $100  │
│    2     │     102     │ 2024-01-02 │  $200  │
│   ...    │     ...     │    ...     │  ...   │
└──────────┴─────────────┴────────────┴────────┘

After Horizontal Sharding (amount column omitted for width; every shard keeps the full schema):
Shard 1 (Orders 1-3,300,000):            Shard 2 (Orders 3,300,001-6,600,000):
┌──────────┬─────────────┬────────────┐  ┌──────────┬─────────────┬────────────┐
│ order_id │ customer_id │    date    │  │ order_id │ customer_id │    date    │
├──────────┼─────────────┼────────────┤  ├──────────┼─────────────┼────────────┤
│    1     │     101     │ 2024-01-01 │  │ 3300001  │     201     │ 2024-02-01 │
│    2     │     102     │ 2024-01-02 │  │ 3300002  │     202     │ 2024-02-02 │
└──────────┴─────────────┴────────────┘  └──────────┴─────────────┴────────────┘

2. Sharding Strategies Deep Dive

2.1 Range Sharding: Ordered Distribution

Range sharding divides data based on ordered ranges of a sharding key, making it intuitive and efficient for range queries.

Architecture and Implementation

Data Distribution by User ID:
┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐
│     Shard 1     │    │     Shard 2     │    │     Shard 3     │
│  Users 1-1000   │    │ Users 1001-2000 │    │ Users 2001-3000 │
│                 │    │                 │    │                 │
│  ID: 1   Alice  │    │  ID: 1001 Bob   │    │  ID: 2001 Carol │
│  ID: 2   Dave   │    │  ID: 1002 Eve   │    │  ID: 2002 Frank │
│  ID: ...  ...   │    │  ID: ...  ...   │    │  ID: ...  ...   │
└─────────────────┘    └─────────────────┘    └─────────────────┘

Routing Logic Implementation

class RangeShard:
    def __init__(self):
        self.shards = [
            {'name': 'shard1', 'min': 1, 'max': 1000, 'server': 'db1.example.com'},
            {'name': 'shard2', 'min': 1001, 'max': 2000, 'server': 'db2.example.com'},
            {'name': 'shard3', 'min': 2001, 'max': 3000, 'server': 'db3.example.com'}
        ]
    
    def get_shard(self, user_id):
        for shard in self.shards:
            if shard['min'] <= user_id <= shard['max']:
                return shard
        raise ValueError(f"No shard found for user_id: {user_id}")
    
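    def execute_query(self, query, user_id=None, user_id_range=None):
        if user_id is not None:
            # Single-shard query ("is not None" so that an ID of 0 is not skipped)
            shard = self.get_shard(user_id)
            return self.query_shard(shard, query)
        if user_id_range is not None:
            # Range query - may span multiple shards
            start_id, end_id = user_id_range
            # A shard is affected when its [min, max] interval overlaps the range
            affected_shards = [
                shard for shard in self.shards
                if shard['min'] <= end_id and shard['max'] >= start_id
            ]
            return self.query_multiple_shards(affected_shards, query)
        raise ValueError("Provide either user_id or user_id_range")

    def query_shard(self, shard, query):
        # Placeholder: send the query to shard['server'] with your database driver
        raise NotImplementedError

    def query_multiple_shards(self, shards, query):
        # Placeholder: query each shard in turn and concatenate the results
        results = []
        for shard in shards:
            results.extend(self.query_shard(shard, query))
        return results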

Real-World Use Cases

Time-Series Data: Perfect for logs, metrics, or events partitioned by date:

Shard 1: Jan 2024 data
Shard 2: Feb 2024 data  
Shard 3: Mar 2024 data
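
A date-based layout like the one above needs only a small routing function. The sketch below assumes one shard per calendar month and a hypothetical naming scheme (`events_YYYY_MM`); neither comes from a particular database.

```python
from datetime import datetime
from typing import List

def shard_for_timestamp(ts: datetime) -> str:
    """Map an event timestamp to its monthly shard name."""
    return f"events_{ts.year:04d}_{ts.month:02d}"

def shards_for_range(start: datetime, end: datetime) -> List[str]:
    """List every monthly shard that a [start, end] query must touch."""
    if start > end:
        raise ValueError("start must not be after end")
    shards = []
    year, month = start.year, start.month
    while (year, month) <= (end.year, end.month):
        shards.append(f"events_{year:04d}_{month:02d}")
        month += 1
        if month > 12:  # roll over into the next year
            year, month = year + 1, 1
    return shards

print(shard_for_timestamp(datetime(2024, 2, 15)))
print(shards_for_range(datetime(2023, 12, 20), datetime(2024, 2, 1)))
```

A query for a single day touches one shard, while a query spanning a year-end touches only the months it overlaps.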

Geographic Distribution: Customer data partitioned by region:

Shard 1: Customers A-H (East Coast)
Shard 2: Customers I-P (Central)
Shard 3: Customers Q-Z (West Coast)

Advantages and Limitations

Advantages:

  • Efficient Range Queries: Queries like “find all orders from January” hit only relevant shards
  • Simple Routing Logic: Straightforward to determine which shard contains specific data
  • Ordered Data Access: Natural ordering is preserved within and across shards

Limitations:

  • Hot Spots: Recent data (like current month’s orders) may create uneven load
  • Data Skew: Uneven distribution if ranges don’t align with actual data patterns
  • Manual Rebalancing: Adding new shards requires careful range redistribution
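
To make the rebalancing limitation concrete, the sketch below keeps range boundaries in a sorted list and splits a hot range by inserting a new boundary. The class name `RangeRoutingTable` and its methods are hypothetical, the last range is treated as open-ended, and the sketch covers only the routing change; the rows above the split point would still have to be copied to the new shard.

```python
import bisect

class RangeRoutingTable:
    """Sorted lower bounds -> shard names; each shard owns [bound, next bound)."""

    def __init__(self, lower_bounds, shard_names):
        if list(lower_bounds) != sorted(lower_bounds):
            raise ValueError("lower bounds must be sorted")
        if len(lower_bounds) != len(shard_names):
            raise ValueError("need exactly one shard name per lower bound")
        self.lower_bounds = list(lower_bounds)
        self.shard_names = list(shard_names)

    def lookup(self, key):
        # Rightmost range whose lower bound is <= key (binary search).
        i = bisect.bisect_right(self.lower_bounds, key) - 1
        if i < 0:
            raise KeyError(f"{key} is below the first range")
        return self.shard_names[i]

    def split(self, split_point, new_shard):
        """Route keys >= split_point (up to the next bound) to new_shard."""
        if split_point in self.lower_bounds:
            raise ValueError("split point is already a boundary")
        i = bisect.bisect_right(self.lower_bounds, split_point)
        if i == 0:
            raise ValueError("split point is below the first range")
        self.lower_bounds.insert(i, split_point)
        self.shard_names.insert(i, new_shard)

table = RangeRoutingTable([1, 1001, 2001], ["shard1", "shard2", "shard3"])
table.split(2501, "shard4")  # shard3 is hot: hand its upper half to shard4
print(table.lookup(2500), table.lookup(2501))
```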

Common Implementation Mistakes

Poor Range Selection: Choosing ranges without analyzing actual data distribution:

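import numpy as np

# Bad: Assuming uniform distribution
ranges = [(1, 1000), (1001, 2000), (2001, 3000)]

# Better: Derive boundaries from the actual key distribution so that each
# shard holds roughly the same number of rows
def calculate_optimal_ranges(data, num_shards):
    # num_shards + 1 boundaries at evenly spaced percentiles (0%, ..., 100%)
    boundaries = np.percentile(data, np.linspace(0, 100, num_shards + 1))
    # Each range is [lower, upper); the last range should also include its upper bound
    return [(boundaries[i], boundaries[i + 1]) for i in range(num_shards)]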

Static Range Boundaries: Not planning for growth and data evolution over time.

2.2 Hash Sharding: Uniform Distribution

Hash sharding uses a hash function to distribute data evenly across shards. This avoids the hot spots caused by skewed key ranges, but it complicates range queries.

Hash Function Selection and Implementation

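import hashlib
import mmh3  # MurmurHash3 (third-party package) - fast, good distribution

class HashShard:
    def __init__(self, num_shards, algorithm="md5"):
        self.num_shards = num_shards
        self.algorithm = algorithm
        self.shards = [f"shard_{i}" for i in range(num_shards)]
    
    def hash_function(self, key):
        # Avoid Python's built-in hash() for routing: by default, string hashes
        # are randomized per process, so one key could map to different shards.
        if self.algorithm == "md5":
            # MD5: stable across processes and evenly distributed, but slower.
            # Used here only for distribution; it is not safe for security purposes.
            return int(hashlib.md5(str(key).encode()).hexdigest(), 16) % self.num_shards
        if self.algorithm == "murmur3":
            # MurmurHash3: fast, good distribution. mmh3.hash() can return a
            # negative value; Python's % still yields a non-negative index.
            return mmh3.hash(str(key)) % self.num_shards
        raise ValueError(f"Unknown hash algorithm: {self.algorithm}")
    
    def get_shard(self, key):
        shard_index = self.hash_function(key)
        return self.shards[shard_index]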
    
    def distribute_data(self, data_items):
        distribution = {shard: [] for shard in self.shards}
        for item in data_items:
            shard = self.get_shard(item['key'])
            distribution[shard].append(item)
        return distribution

Data Flow Visualization

Input Keys → Hash Function → Shard Assignment

user_123 → hash(user_123) = 1547 → 1547 % 4 = 3 → Shard 3
user_456 → hash(user_456) = 2981 → 2981 % 4 = 1 → Shard 1  
user_789 → hash(user_789) = 4302 → 4302 % 4 = 2 → Shard 2

Shard assignment by remainder (hash % 4):
            ┌─────────────┐
            │   Shard 0   │ ← Keys hashing to 0
            └─────────────┘
                   │
    ┌─────────────┐         ┌─────────────┐
    │   Shard 3   │ ←─────→ │   Shard 1   │
    └─────────────┘         └─────────────┘
                   │
            ┌─────────────┐
            │   Shard 2   │ ← Keys hashing to 2
            └─────────────┘

Advantages and Trade-offs

Advantages:

  • Even Distribution: A good hash function spreads keys close to uniformly across shards
  • Fewer Hot Spots: Load is spread evenly across keys, although a single very popular key can still overload its shard
  • Scalable Writes: Write load is naturally distributed

Disadvantages:

  • Complex Range Queries: Finding all records in a range requires querying all shards
  • Resharding Complexity: Adding/removing shards requires rehashing most data
  • No Data Locality: Related records may end up on different shards
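
The resharding cost of modulo hashing can be measured directly. The sketch below is an illustration under stated assumptions: `stable_shard` and `moved_fraction` are hypothetical helpers, and MD5 is used only because it gives the same result in every process. When going from 4 to 5 shards, a key stays put only if its hash has the same remainder modulo 4 and modulo 5, which holds for 4 of every 20 hash values, so about 80% of keys are expected to move.

```python
import hashlib

def stable_shard(key: str, num_shards: int) -> int:
    """Process-independent shard index (MD5 used for distribution, not security)."""
    digest = hashlib.md5(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_shards

def moved_fraction(keys, old_n: int, new_n: int) -> float:
    """Fraction of keys whose shard index changes when the shard count changes."""
    moved = sum(1 for k in keys if stable_shard(k, old_n) != stable_shard(k, new_n))
    return moved / len(keys)

keys = [f"user_{i}" for i in range(10_000)]
print(f"4 -> 5 shards: {moved_fraction(keys, 4, 5):.1%} of keys change shard")
```

This is the problem that consistent hashing is designed to reduce.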

Range Query Handling with Hash Sharding

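from cachetools import LRUCache  # third-party package; any LRU cache works

class HashShardWithRangeSupport:
    def execute_range_query(self, start_key, end_key):
        # Must query all shards for range queries
        results = []
        # Bind values as parameters instead of formatting them into the SQL
        # string. Placeholder syntax (%s, ?, :name) depends on the driver;
        # "records" and "record_key" are illustrative names, and query_shard
        # is assumed to accept a parameter tuple.
        sql = "SELECT * FROM records WHERE record_key BETWEEN %s AND %s"
        for shard in self.shards:
            shard_results = self.query_shard(shard, sql, (start_key, end_key))
            results.extend(shard_results)
        
        # Sort and merge results from all shards
        return sorted(results, key=lambda x: x['record_key'])
    
    def optimize_range_queries(self):
        # Strategy 1: Maintain global index
        self.global_index = self.build_global_index()
        
        # Strategy 2: Use bloom filters to skip shards that cannot hold a key
        # (helps point lookups; a plain bloom filter cannot answer range membership)
        self.bloom_filters = self.build_bloom_filters()
        
        # Strategy 3: Cache frequent range query results
        self.range_cache = LRUCache(maxsize=1000)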

2.3 Consistent Hashing: Elastic Scaling

Consistent hashing solves the resharding problem by minimizing data movement when shards are added or removed.

The Ring Architecture

                    Shard A (hash: 100)
                           ●
                      ╭─────────╮
                  ╭───╯         ╰───╮
             ╭────╯                 ╰────╮
        ●───╯                           ╰───● Shard B (hash: 250)
  Shard D                                     
 (hash: 50)                                   
        ●───╮                           ╭───●
             ╰────╮                 ╭────╯
                  ╰───╮         ╭───╯    Shard C (hash: 300)
                      ╰─────────╯
                           ●
                    
Data Placement Rules:
- user_1 (hash: 75)  → Shard A (next clockwise)
- user_2 (hash: 150) → Shard B (next clockwise)
- user_3 (hash: 275) → Shard C (next clockwise)
- user_4 (hash: 25)  → Shard D (next clockwise)
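
The "next clockwise" rule reduces to a binary search over sorted ring positions. The sketch below is a minimal illustration with hypothetical node names and made-up positions on a 0-359 circle; real systems derive positions by hashing node names into a much larger space. Here a key that lands exactly on a node's position is assigned to that node, which is one of two common tie conventions.

```python
import bisect

# Hypothetical ring: (position, node) pairs sorted by position.
ring = [(40, "node_p"), (130, "node_q"), (220, "node_r"), (310, "node_s")]
positions = [pos for pos, _ in ring]

def owner(key_position: int) -> str:
    """Return the first node at or after key_position, wrapping past the end."""
    i = bisect.bisect_left(positions, key_position)
    return ring[i % len(ring)][1]  # i == len(ring) wraps to the first node

for p in (0, 100, 300, 311):
    print(p, "->", owner(p))
```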

Implementation with Virtual Nodes

import bisect
import hashlib

class ConsistentHashRing:
    def __init__(self, nodes=None, replicas=150):
        """
        replicas: Number of virtual nodes per physical node
        Higher values = better distribution, more memory usage
        """
        self.replicas = replicas
        self.ring = []  # Sorted list of hash values
        self.nodes = {}  # hash_value -> actual_node mapping
        
        if nodes:
            for node in nodes:
                self.add_node(node)
    
    def hash_fn(self, key):
        """Use MD5 for consistent hashing across different systems"""
        return int(hashlib.md5(key.encode('utf-8')).hexdigest(), 16)
    
    def add_node(self, node):
        """Add a node with virtual replicas for better load distribution"""
        for i in range(self.replicas):
            replica_key = f"{node}:replica:{i}"
            hash_val = self.hash_fn(replica_key)
            
            # Insert in sorted order
            bisect.insort(self.ring, hash_val)
            self.nodes[hash_val] = node
        
        print(f"Added node {node} with {self.replicas} virtual nodes")
    
    def remove_node(self, node):
        """Remove a node and all its virtual replicas"""
        removed_count = 0
        for i in range(self.replicas):
            replica_key = f"{node}:replica:{i}"
            hash_val = self.hash_fn(replica_key)
            
            try:
                index = self.ring.index(hash_val)
                self.ring.pop(index)
                del self.nodes[hash_val]
                removed_count += 1
            except ValueError:
                pass  # Hash not found, skip
        
        print(f"Removed node {node} ({removed_count} virtual nodes)")
    
    def get_node(self, key):
        """Find the node responsible for a given key"""
        if not self.ring:
            return None
        
        hash_val = self.hash_fn(key)
        
        # Find the first node clockwise from the key's hash
        index = bisect.bisect_right(self.ring, hash_val)
        if index == len(self.ring):
            index = 0  # Wrap around to the beginning
        
        return self.nodes[self.ring[index]]
    
    def get_nodes_for_key(self, key, num_replicas=3):
        """Get multiple nodes for replication"""
        if not self.ring or num_replicas <= 0:
            return []
        
        hash_val = self.hash_fn(key)
        nodes = []
        seen_physical_nodes = set()
        
        # Start from the key's position and go clockwise
        start_index = bisect.bisect_right(self.ring, hash_val)
        
        for i in range(len(self.ring)):
            index = (start_index + i) % len(self.ring)
            physical_node = self.nodes[self.ring[index]]
            
            if physical_node not in seen_physical_nodes:
                nodes.append(physical_node)
                seen_physical_nodes.add(physical_node)
                
                if len(nodes) >= num_replicas:
                    break
        
        return nodes
    
    def show_distribution(self, keys):
        """Analyze how keys are distributed across nodes"""
        distribution = {}
        for key in keys:
            node = self.get_node(key)
            distribution[node] = distribution.get(node, 0) + 1
        
        print("Key distribution:")
        for node, count in sorted(distribution.items()):
            percentage = (count / len(keys)) * 100
            print(f"  {node}: {count} keys ({percentage:.1f}%)")
        
        return distribution

# Example usage demonstrating elastic scaling
def demonstrate_consistent_hashing():
    # Initial setup with 3 nodes
    initial_nodes = ['shard_1', 'shard_2', 'shard_3']
    ring = ConsistentHashRing(initial_nodes)
    
    # Sample data keys
    test_keys = [f"user_{i}" for i in range(1000)]
    
    print("=== Initial Distribution ===")
    ring.show_distribution(test_keys)
    
    # Record each key's owner before the ring changes
    old_assignment = {key: ring.get_node(key) for key in test_keys}
    
    # Add a new node (simulating scaling out)
    print("\n=== Adding shard_4 ===")
    ring.add_node('shard_4')
    ring.show_distribution(test_keys)
    
    # Calculate data movement: keys whose owner changed
    moved_keys = sum(1 for key in test_keys
                     if ring.get_node(key) != old_assignment[key])
    
    # Growing from N to N + 1 nodes should move about 1/(N + 1) of the keys,
    # all of them to the new node
    expected_fraction = 1 / (len(initial_nodes) + 1)
    print(f"\nData movement: {moved_keys}/{len(test_keys)} keys moved ({moved_keys/len(test_keys)*100:.1f}%)")
    print(f"Expected on a balanced ring: {len(test_keys) * expected_fraction:.0f} keys ({expected_fraction*100:.1f}%)")

Virtual Nodes Deep Dive

Virtual nodes solve the uneven distribution problem in basic consistent hashing:

Without Virtual Nodes (Poor Distribution, illustrative figures):
Node A: 60% of data
Node B: 25% of data
Node C: 15% of data

With Virtual Nodes (150 per node, illustrative figures):
Node A: 33.2% of data
Node B: 33.5% of data
Node C: 33.3% of data

Benefits of Virtual Nodes:

  • More even load distribution
  • Smoother data movement during scaling
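  • Better failure handling: a failed node's keys are spread across many remaining nodes instead of landing on a single neighbor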

Trade-offs:

  • Increased memory usage for the ring structure
  • More complex implementation
  • Slight performance overhead
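
The effect of the virtual-node count can be checked empirically. The sketch below is a self-contained measurement, not a production ring: `key_shares` is a hypothetical helper, the `node:replica:i` key format is just one way to name virtual nodes, and the exact percentages depend on the hash function and key set.

```python
import bisect
import hashlib
from collections import Counter

def h(value: str) -> int:
    """Stable 128-bit hash of a string (MD5 used for distribution only)."""
    return int(hashlib.md5(value.encode("utf-8")).hexdigest(), 16)

def key_shares(nodes, vnodes_per_node, keys):
    """Return each node's fraction of keys for a given virtual-node count."""
    points = sorted((h(f"{n}:replica:{i}"), n)
                    for n in nodes for i in range(vnodes_per_node))
    hashes = [p for p, _ in points]
    counts = Counter()
    for key in keys:
        # First virtual node clockwise from the key, wrapping at the end.
        i = bisect.bisect_right(hashes, h(key)) % len(points)
        counts[points[i][1]] += 1
    return {n: counts[n] / len(keys) for n in nodes}

nodes = ["shard_1", "shard_2", "shard_3"]
keys = [f"user_{i}" for i in range(10_000)]
for v in (1, 10, 200):
    shares = key_shares(nodes, v, keys)
    print(f"{v:>3} virtual nodes:", {n: f"{s:.1%}" for n, s in shares.items()})
```

Shares typically move closer to an even split as the count grows, while the ring itself grows to nodes × virtual nodes entries.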

3. Advanced Sharding Challenges and Solutions

3.1 The Rebalancing Challenge

Rebalancing shards is one of the most complex operational challenges in distributed systems. It involves moving data while maintaining system availability and consistency.

Data Movement Strategies

Stop-the-World Migration:

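def stop_the_world_migration(source_shard, target_shard, data_range):
    # Simplest approach, but all reads and writes are blocked while the lock is held
    with system_lock():
        data = source_shard.extract_data(data_range)
        target_shard.load_data(data)
        # Repoint routing before deleting, so the source copy remains as a
        # fallback if the routing update fails
        update_routing_table(data_range, target_shard)
        source_shard.delete_data(data_range)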

Live Migration with Dual Writes:

class LiveMigration:
    def __init__(self, source_shard, target_shard, data_range):
        self.source = source_shard
        self.target = target_shard  
        self.range = data_range
        self.migration_state = "PREPARING"
    
    def start_migration(self):
        # Phase 1: Bulk copy existing data
        self.migration_state = "COPYING"
        self.bulk_copy_data()
        
        # Phase 2: Enable dual writes
        self.migration_state = "DUAL_WRITE"
        self.enable_dual_writes()
        
        # Phase 3: Sync any missed changes
        self.migration_state = "SYNCING"
        self.sync_delta_changes()
        
        # Phase 4: Switch reads to target
        self.migration_state = "SWITCHING"
        self.switch_reads_to_target()
        
        # Phase 5: Disable dual writes, cleanup source
        self.migration_state = "CLEANUP"
        self.cleanup_source()
        
        self.migration_state = "COMPLETE"
    
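    def write_data(self, key, data):
        if key not in self.range:
            # Keys outside the migrating range are unaffected
            self.source.write(key, data)
        elif self.migration_state in ("DUAL_WRITE", "SYNCING", "SWITCHING"):
            # Keep both copies current until dual writes are disabled in Phase 5
            self.source.write(key, data)
            self.target.write(key, data)
        elif self.migration_state in ("CLEANUP", "COMPLETE"):
            self.target.write(key, data)
        else:
            # PREPARING / COPYING: the source is still the only authority;
            # sync_delta_changes() later replays these writes on the target
            self.source.write(key, data)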
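
Before reads are switched to the target, it helps to confirm that both copies agree. The sketch below is one possible check, assuming each shard can be read as a key-to-value mapping and that values have a deterministic `repr`; `row_checksum` and `find_mismatches` are hypothetical helpers. Comparing checksums rather than full rows keeps the comparison cheap when the two sides are on different servers.

```python
import hashlib

def row_checksum(value) -> str:
    """Digest of a row's repr(); assumes repr() is deterministic for the value."""
    return hashlib.sha256(repr(value).encode("utf-8")).hexdigest()

def find_mismatches(source: dict, target: dict, keys):
    """Return keys that are missing or different on the target."""
    mismatched = []
    for key in keys:
        if key not in target:
            mismatched.append(key)  # never copied
        elif row_checksum(source[key]) != row_checksum(target[key]):
            mismatched.append(key)  # copied, but stale or corrupted
    return mismatched

source = {1: "alice", 2: "bob-v2", 3: "carol"}
target = {1: "alice", 2: "bob-v1"}
print(find_mismatches(source, target, source.keys()))
```

An empty result is a reasonable precondition for moving from the sync phase to the read switch.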

Zero-Downtime Migration Patterns

Read-After-Write Consistency:

class MigrationAwareRouter:
    def read_data(self, key):
        # During migration, check target first for latest data
        if self.is_migrating(key):
            target_data = self.target_shard.read(key)
            if target_data is not None:
                return target_data
        
        return self.source_shard.read(key)
    
    def write_data(self, key, data):
        if self.is_migrating(key):
            # Dual write during migration
            self.source_shard.write(key, data)
            self.target_shard.write(key, data)
        else:
            self.get_primary_shard(key).write(key, data)

3.2 Cross-Shard Transactions and Joins

Sharding breaks traditional ACID transactions when data spans multiple shards.

Distributed Transaction Patterns

Two-Phase Commit (2PC):

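class TransactionAbortException(Exception):
    """Raised when a participant votes to abort during the prepare phase."""

class TwoPhaseCommitCoordinator:
    def execute_transaction(self, operations):
        participating_shards = self.get_affected_shards(operations)
        transaction_id = self.generate_transaction_id()
        
        # Phase 1: Prepare - every shard must vote yes
        try:
            for shard in participating_shards:
                if not shard.prepare(transaction_id, operations):
                    raise TransactionAbortException(f"Shard {shard} failed to prepare")
        except Exception:
            # Abort on every participant: a shard whose prepare failed midway
            # may still hold locks (abort is assumed to be idempotent)
            for shard in participating_shards:
                shard.abort(transaction_id)
            raise
        
        # Phase 2: Commit - the decision is now final. A failed commit must be
        # retried, because aborting the rest would leave the shards inconsistent.
        # (A production coordinator also logs the decision durably first.)
        for shard in participating_shards:
            self.commit_with_retry(shard, transaction_id)
    
    def commit_with_retry(self, shard, transaction_id, max_attempts=5):
        for attempt in range(1, max_attempts + 1):
            try:
                shard.commit(transaction_id)
                return
            except Exception:
                if attempt == max_attempts:
                    raise  # Needs a recovery process or operator intervention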

Saga Pattern for Long-Running Transactions:

import logging

class SagaOrchestrator:
    def __init__(self):
        self.steps = []
        self.compensations = []
    
    def add_step(self, action, compensation):
        # action and compensation are zero-argument callables
        self.steps.append(action)
        self.compensations.append(compensation)
    
    def execute(self):
        completed_steps = []
        try:
            for i, step in enumerate(self.steps):
                step()
                completed_steps.append(i)
                
        except Exception:
            # Compensate in reverse order
            for i in reversed(completed_steps):
                try:
                    self.compensations[i]()
                except Exception as comp_error:
                    logging.error(f"Compensation failed: {comp_error}")
            raise

# Example: Transfer money between users on different shards
def transfer_money_saga(from_user_id, to_user_id, amount):
    saga = SagaOrchestrator()
    
    # Step 1: Debit from source account
    saga.add_step(
        action=lambda: debit_account(from_user_id, amount),
        compensation=lambda: credit_account(from_user_id, amount)
    )
    
    # Step 2: Credit to destination account  
    saga.add_step(
        action=lambda: credit_account(to_user_id, amount),
        compensation=lambda: debit_account(to_user_id, amount)
    )
    
    # Step 3: Record transaction
    saga.add_step(
        action=lambda: record_transaction(from_user_id, to_user_id, amount),
        compensation=lambda: delete_transaction_record(from_user_id, to_user_id, amount)
    )
    
    saga.execute()
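
The compensation flow is easiest to see when a step fails. The sketch below is a self-contained demonstration with an in-memory `balances` dictionary standing in for accounts on different shards; `run_saga`, `debit`, `credit`, and `failing_record_step` are hypothetical names, and the third step fails on purpose.

```python
balances = {"alice": 100, "bob": 20}  # hypothetical accounts on two shards

def debit(user, amount):
    if balances[user] < amount:
        raise ValueError(f"insufficient funds for {user}")
    balances[user] -= amount

def credit(user, amount):
    balances[user] += amount

def failing_record_step():
    raise RuntimeError("ledger shard unavailable")  # simulated failure

def run_saga(steps):
    """steps: list of (action, compensation) callables. Returns True on success."""
    done = []
    try:
        for action, compensation in steps:
            action()
            done.append(compensation)
    except Exception:
        for compensation in reversed(done):
            compensation()  # undo completed steps, newest first
        return False
    return True

ok = run_saga([
    (lambda: debit("alice", 30), lambda: credit("alice", 30)),
    (lambda: credit("bob", 30), lambda: debit("bob", 30)),
    (failing_record_step, lambda: None),
])
print(ok, balances)
```

Between the failure and the end of compensation, other readers can observe the intermediate balances, which is the isolation a saga gives up compared with 2PC.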

Cross-Shard Query Optimization

Scatter-Gather Pattern:

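import logging
from concurrent.futures import ThreadPoolExecutor

class CrossShardQueryEngine:
    def __init__(self, max_workers=8):
        # Shared pool used to query shards in parallel
        self.thread_pool = ThreadPoolExecutor(max_workers=max_workers)
    
    def execute_cross_shard_query(self, query):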
        # Parse query to identify affected shards
        affected_shards = self.analyze_query_shards(query)
        
        # Execute query on all relevant shards in parallel
        futures = []
        for shard in affected_shards:
            future = self.thread_pool.submit(shard.execute_query, query)
            futures.append((shard, future))
        
        # Gather and merge results
        all_results = []
        for shard, future in futures:
            try:
                results = future.result(timeout=30)  # 30 second timeout
                all_results.extend(results)
            except Exception as e:
                logging.error(f"Query failed on shard {shard}: {e}")
                # Decide whether to fail fast or continue with partial results
        
        # Apply global ordering, limits, etc.
        return self.merge_and_process_results(all_results, query)
    
    def optimize_join_queries(self, query):
        """Optimize joins across shards"""
        if self.is_co_located_join(query):
            # Data is on same shard, execute locally
            return self.execute_local_join(query)
        else:
            # Implement broadcast join or shuffle join
            return self.execute_distributed_join(query)
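
The "global ordering, limits" step of scatter-gather can be done as a k-way merge when each shard returns its rows already sorted. The sketch below is a minimal illustration: the `SHARDS` dictionary and `query_shard` are hypothetical stand-ins for real per-shard `ORDER BY ... LIMIT n` queries.

```python
import heapq
from concurrent.futures import ThreadPoolExecutor
from itertools import islice

# Hypothetical shards: each holds rows already sorted by "amount".
SHARDS = {
    "shard_0": [{"id": 1, "amount": 10}, {"id": 4, "amount": 70}],
    "shard_1": [{"id": 2, "amount": 25}, {"id": 5, "amount": 40}],
    "shard_2": [{"id": 3, "amount": 5}],
}

def query_shard(name, limit):
    """Stand-in for a per-shard 'ORDER BY amount LIMIT n' query."""
    return SHARDS[name][:limit]

def top_n_by_amount(limit):
    # Scatter: every shard returns at most `limit` rows, in sorted order.
    with ThreadPoolExecutor(max_workers=len(SHARDS)) as pool:
        partials = list(pool.map(lambda name: query_shard(name, limit), SHARDS))
    # Gather: k-way merge of the sorted lists, then apply the global limit.
    merged = heapq.merge(*partials, key=lambda row: row["amount"])
    return list(islice(merged, limit))

print(top_n_by_amount(3))
```

Pushing the limit down to each shard bounds the data transferred to `limit` rows per shard, and the merge never needs to sort the combined result.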

3.3 Monitoring and Observability

Effective sharding requires comprehensive monitoring to detect imbalances, performance issues, and failures.

Key Metrics to Track

class ShardingMetrics:
    def __init__(self):
        self.metrics = {
            'shard_sizes': {},           # Data volume per shard
            'query_latencies': {},       # Response times per shard  
            'write_throughput': {},      # Writes per second per shard
            'read_throughput': {},       # Reads per second per shard
            'error_rates': {},           # Error percentage per shard
            'replication_lag': {},       # If using replication
            'hot_keys': {},              # Most accessed keys
            'cross_shard_queries': 0,    # Expensive operations
        }
    
    def detect_hot_shards(self, threshold=2.0):
        """Detect shards with disproportionate load"""
        loads = self.metrics['read_throughput']
        if not loads:
            return []  # No data collected yet
        avg_load = sum(loads.values()) / len(loads)
        if avg_load == 0:
            return []  # No traffic, so no shard can be hot
        hot_shards = []
        
        for shard, load in loads.items():
            if load > avg_load * threshold:
                hot_shards.append((shard, load, load/avg_load))
        
        return hot_shards
    
    def suggest_rebalancing(self):
        """Analyze metrics and suggest rebalancing actions"""
        suggestions = []
        
        # Check for size imbalances (skipped until sizes have been collected)
        sizes = list(self.metrics['shard_sizes'].values())
        if sizes and max(sizes) > min(sizes) * 3:  # 3x size difference
            suggestions.append("Consider range-based rebalancing due to size imbalance")
        
        # Check for hot spots
        hot_shards = self.detect_hot_shards()
        if hot_shards:
            suggestions.append(f"Hot shards detected: {[s[0] for s in hot_shards]}")
        
        # Check cross-shard query frequency
        if self.metrics['cross_shard_queries'] > 1000:  # Per hour
            suggestions.append("High cross-shard query rate - consider denormalizing data")
        
        return suggestions
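
The `hot_keys` metric can be derived from a sample of the access log. The sketch below is one simple approach, assuming the log is available as a list of key names; `find_hot_keys` and its thresholds are hypothetical and would need tuning for real traffic.

```python
from collections import Counter

def find_hot_keys(access_log, top_n=3, min_share=0.10):
    """Return (key, share) pairs for keys receiving at least min_share of traffic."""
    counts = Counter(access_log)
    total = sum(counts.values())
    if total == 0:
        return []
    return [(key, hits / total)
            for key, hits in counts.most_common(top_n)
            if hits / total >= min_share]

# Synthetic sample: 100 accesses, dominated by two keys.
log = ["user_7"] * 60 + ["user_2"] * 25 + [f"user_{i}" for i in range(100, 115)]
print(find_hot_keys(log))
```

A key that dominates traffic will overload its shard under any sharding strategy, so it usually calls for caching or splitting that key rather than rebalancing.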
