.. _guide-batching:

Batching
=========

In practice, we often need to merge a collection of small graphs into one large graph in which each original small
graph is kept as a separate component, with no edges added between them. This operation is called *batching* in graph
deep learning and is widely used to improve computational efficiency.
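
As a rough illustration of the idea (a plain-Python sketch, not the library's implementation), batching amounts to
shifting the node IDs of each small graph by the number of nodes that precede it. The helper ``merge_edge_lists`` and
the ``(num_nodes, edges)`` representation are hypothetical and used only for this example:

.. code-block:: python

    def merge_edge_lists(graphs):
        """Merge small graphs, each given as (num_nodes, edges), into one large graph.

        Returns the total node count, the merged edge list, and a per-node list
        recording which small graph each node came from.
        """
        merged_edges = []
        batch_index = []
        offset = 0
        for graph_id, (num_nodes, edges) in enumerate(graphs):
            # Shift node IDs so they do not collide with those of earlier graphs.
            merged_edges.extend((src + offset, tgt + offset) for src, tgt in edges)
            batch_index.extend([graph_id] * num_nodes)
            offset += num_nodes
        return offset, merged_edges, batch_index


    # Two small graphs: a 3-node path and a 2-node graph with a single edge.
    small_graphs = [(3, [(0, 1), (1, 2)]), (2, [(0, 1)])]
    num_nodes, edges, batch_index = merge_edge_lists(small_graphs)
    print(num_nodes)    # 5
    print(edges)        # [(0, 1), (1, 2), (3, 4)]
    print(batch_index)  # [0, 0, 0, 1, 1]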

``GraphData`` provides interfaces for batching and unbatching graphs during training and inference. The ``to_batch()``
function takes a list of ``GraphData`` instances and returns a single ``GraphData`` that holds the merged large graph.
Conversely, ``from_batch()`` decomposes a large graph produced by merging small graphs back into a list of
``GraphData`` instances.

The following code snippet shows an example:

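.. code-block:: python

    import torch

    # Import path assumed here; adjust it to wherever GraphData, to_batch and
    # from_batch are defined in your installation.
    from graph4nlp.pytorch.data.data import GraphData, from_batch, to_batch

    g_list = []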
    batched_edges = []
    graph_edges_list = []
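    # Build five 10-node ring graphs. Graph i tags its nodes and edges with the
    # value i, and batched_edges records the edges expected after batching,
    # i.e. with node IDs shifted by i * 10.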
    for i in range(5):
        g = GraphData()
        g.add_nodes(10)
        for j in range(10):
            g.add_edge(src=j, tgt=(j + 1) % 10)
            batched_edges.append((i * 10 + j, i * 10 + ((j + 1) % 10)))
        g.node_features['idx'] = torch.ones(10) * i
        g.edge_features['idx'] = torch.ones(10) * i
        graph_edges_list.append(g.get_all_edges())
        g_list.append(g)

    # Merge the five small graphs into a single batched graph
    batch = to_batch(g_list)

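    # Each node of the batched graph should map back to its source graph:
    # nodes 0-9 belong to graph 0, nodes 10-19 to graph 1, and so on.
    target_batch_idx = [i for i in range(5) for _ in range(10)]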

    # Expected behaviors
    assert batch.batch == target_batch_idx
    assert batch.get_node_num() == 50
    assert batch.get_all_edges() == batched_edges

    # Un-batching
    graph_list = from_batch(batch)

    for i, g in enumerate(graph_list):
        # Expected behaviors
        assert g.get_all_edges() == graph_edges_list[i]
        assert g.get_node_num() == 10
        assert torch.all(torch.eq(g.node_features['idx'], torch.ones(10) * i))
        assert torch.all(torch.eq(g.edge_features['idx'], torch.ones(10) * i))
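
To make the role of the per-node graph index (``batch.batch`` in the example above) more concrete, the following
plain-Python sketch shows how such an index could be used to split a merged edge list back into per-graph edge lists.
It is only an illustration, not how ``from_batch()`` is implemented: ``split_edge_list`` is a hypothetical helper, and
it assumes that the nodes of each small graph occupy a contiguous ID range and that every graph has at least one node.

.. code-block:: python

    def split_edge_list(merged_edges, batch_index):
        """Split a merged edge list into one edge list per original graph.

        batch_index[n] is the ID of the small graph that node n came from.
        Assumes graph IDs appear in increasing, contiguous order.
        """
        # First node ID of each small graph inside the large graph.
        offsets = []
        for node_id, graph_id in enumerate(batch_index):
            if graph_id == len(offsets):
                offsets.append(node_id)

        per_graph_edges = [[] for _ in offsets]
        for src, tgt in merged_edges:
            graph_id = batch_index[src]
            shift = offsets[graph_id]
            # Undo the node ID shift that was applied when batching.
            per_graph_edges[graph_id].append((src - shift, tgt - shift))
        return per_graph_edges


    merged = [(0, 1), (1, 2), (3, 4)]
    index = [0, 0, 0, 1, 1]
    print(split_edge_list(merged, index))  # [[(0, 1), (1, 2)], [(0, 1)]]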

