[TFF learning] Official tutorial jupyter operation record_Federated learning image classification task_1

chapter_1 Check whether the system environment and third-party libraries are installed

When Python programs run on web servers, in GUI applications, or in Jupyter notebooks, you may see
"RuntimeError: This event loop is already running". The cause is a nested event loop: the host environment already runs an asyncio loop, and asyncio does not allow another one to be started inside it.
The solution is the nest_asyncio module.
This module patches asyncio to allow nested use of asyncio.run and loop.run_until_complete.

import nest_asyncio
nest_asyncio.apply()
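
To see why a patch is needed at all, here is a minimal sketch that uses only the standard library (the names inner and outer are made up for this illustration). It shows that plain asyncio refuses to start a second event loop from inside a running one. The exact error text depends on which call is nested, so it can differ from the message Jupyter shows. Run it as a plain script, not inside a notebook.

```python
import asyncio


async def inner():
    return 42


async def outer():
    # outer() is already executing inside an event loop, so asking asyncio
    # to start another loop from here is rejected with a RuntimeError.
    coro = inner()
    try:
        asyncio.run(coro)
    except RuntimeError as exc:
        coro.close()  # the coroutine was never started; close it cleanly
        return str(exc)
    return "no error"


message = asyncio.run(outer())
print(message)
```

After nest_asyncio.apply() has patched the loop, this kind of nested call is allowed, which is what TFF relies on in a notebook.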

Import the necessary libraries.
The collections module provides several additional container types, such as Counter, deque, defaultdict, namedtuple, and OrderedDict.

import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff


If the first argument of tff.federated_computation() is a function, it returns a TFF computation built from that function. Calling the result runs the computation:

a=tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

So far, our environment has been configured.

If there is an error, you can refer to this blog to install TFF and TF, which is convenient and fast.

chapter_2 Prepare input data

Federated data is often non-i.i.d., which presents a unique set of challenges.

TFF ships a federated version of MNIST. It contains a version of the original NIST dataset that has been reprocessed using Leaf so that the data is keyed by the original writer of the digits.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

The data loaded here is only used to simulate client data. We can view the number of client ids in the training set:

len(emnist_train.client_ids)

Each element of the training set and test set consists of a label (int32) and a 28x28 image (float32):

OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)),
             ('pixels',
              TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)),
             ('pixels',
              TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])

Create the dataset of a single client from its client id and view the number of samples contained in that client's sub-dataset

example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

View the first element

example_element = next(iter(example_dataset))
label = example_element['label'].numpy()
img = example_element['pixels'].numpy()
import matplotlib.pyplot as plt

2.1 Exploring Heterogeneity in Federated Data

Federated data is usually non-IID: users often have different data distributions depending on their usage patterns. Some clients have few training examples on the device and suffer from a lack of local data, while others have more than enough. Let's use the available EMNIST data to explore this notion of data heterogeneity, which is typical of federated systems. Note that this in-depth analysis of client data is only possible because this is a simulated environment in which all the data is available locally. In a real production federated environment, you would not be able to inspect any single client's data.

First, let's grab a sample of one client's data to get a feel for the examples on a simulated device. Because the dataset is keyed by unique writer, one client's data represents one person's handwriting for a sample of the digits 0 to 9, simulating that user's unique "usage pattern."

figure = plt.figure(figsize=(20,8))
for example in example_dataset.take(40):
    # Select the first 40 samples to display
    label = example['label'].numpy()
    # print(label)

Now let's visualize the number of samples per MNIST digit label on each client. In a federated learning setting, the number of samples on each client can vary widely, depending on user behavior.

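f = plt.figure(figsize=(12,7))
f.suptitle('Label Count for a Sample of Clients')
# Select the first 6 users
client_num = 6
for i in range(client_num):
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[i])
    # Collect the occurrences of each label 0-9 in a dictionary variable
    plot_data = collections.defaultdict(list)
    for example in client_dataset:
        label = example['label'].numpy()
        plot_data[label].append(label)
    plt.subplot(2, 3, i+1)
    plt.title(f'Client {i}')
    for j in range(10):
        plt.hist(plot_data[j], density=False, bins=list(range(11)))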
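
The counting step behind such a plot can be sketched without TFF or matplotlib. The toy client ids and label lists below are hypothetical stand-ins for real client datasets; the point is only to show how per-client label counts expose heterogeneity.

```python
from collections import Counter

# Hypothetical toy data: client id -> digit labels stored on that client.
client_labels = {
    "client_a": [0, 0, 1, 7, 7, 7],
    "client_b": [3, 3, 3, 3, 9],
}

# One Counter per client: how many samples of each digit the client holds.
label_counts = {cid: Counter(labels) for cid, labels in client_labels.items()}

for cid, counts in label_counts.items():
    # Report all ten digits, including those the client never wrote.
    print(cid, [counts.get(digit, 0) for digit in range(10)])
```

Clients differ both in how many samples they hold and in which digits are represented at all, which is the non-IID behaviour described above.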

Now let's visualize the average image per client for each MNIST label. This code will generate the average of each pixel value for all user samples for a label. We'll see that one client's average image of a number will look different from another client's average image of the same number due to each person's unique handwriting style. We can think about how each round of local training will push the model in a different direction on each client because we are learning from that user's own unique data in that local round. Later in the tutorial, we'll see how to get each update of the model from all clients and aggregate them into our new global model, which learns from each client's own unique data.

Each client has a different average image, which means each client pushes the model in its own direction during local training

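for i in range(6):
    client_dataset = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[i])
    # Group the images of this client by label
    plot_data = collections.defaultdict(list)
    for example in client_dataset:
        label = example['label'].numpy()
        img = example['pixels'].numpy()
        plot_data[label].append(img)
    f = plt.figure(i,figsize=(12,5))
    f.suptitle(f"Client{i} Mean Image Label")
    for j in range(10):
        # Pixel-wise mean over all images of digit j
        mean_img = np.mean(plot_data[j],0)
        plt.subplot(2, 5, j+1)
        plt.imshow(mean_img.reshape((28, 28)))
        plt.axis('off')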

User data can be noisy and labels are unreliable. For example, looking at the data for client #2 above, we can see that for label 2, there may be some mislabeled examples, creating a noisier average image.

2.2 Preprocessing input data

Since the data is already a tf.data.Dataset, preprocessing can be done using dataset transformations. Here we flatten each 28x28 image into an array of 784 elements, shuffle the individual examples, organize them into batches, and rename the features from pixels and label to x and y for use with Keras. We also repeat the dataset in order to run several epochs.

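SHUFFLE_BUFFER=100 #Buffer size when shuffling
PREFETCH_BUFFER=10 #Number of elements to prefetch
# NUM_EPOCH and BATCH_SIZE must also be defined; their values are not shown in this record

def preprocess(dataset):
    # dataset is a tf.data.Dataset
    def batch_format_fn(element):
        # Turn the elements into x,y, flattening each image into 784 elements
        return collections.OrderedDict(
            x=tf.reshape(element['pixels'],[-1,784]),
            y=tf.reshape(element['label'],[-1,1]))
    return dataset.repeat(NUM_EPOCH) \
        .shuffle(SHUFFLE_BUFFER,seed=1) \
        .batch(BATCH_SIZE) \
        .map(batch_format_fn) \
        .prefetch(PREFETCH_BUFFER)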
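
To make the repeat / shuffle / batch / flatten pipeline concrete, here is a rough pure-Python sketch of the same steps on plain lists. It is not how tf.data works internally: in particular it shuffles the whole list, whereas tf.data shuffles within a buffer. The function name preprocess_lists and the fake images are assumptions made for this illustration.

```python
import random


def preprocess_lists(examples, num_epochs, batch_size, seed=1):
    """Repeat, shuffle, batch and flatten a list of (image, label) pairs."""
    repeated = examples * num_epochs          # repeat the data for several epochs
    random.Random(seed).shuffle(repeated)     # full shuffle (tf.data uses a buffer)
    batches = []
    for start in range(0, len(repeated), batch_size):
        chunk = repeated[start:start + batch_size]
        batches.append({
            # flatten each 28x28 image into one row of 784 values
            "x": [[p for row in image for p in row] for image, _ in chunk],
            # each label becomes a one-element row, giving shape (batch, 1)
            "y": [[label] for _, label in chunk],
        })
    return batches


# Three fake 28x28 "images" whose pixels all equal their label.
examples = [([[float(i)] * 28 for _ in range(28)], i) for i in range(3)]
batches = preprocess_lists(examples, num_epochs=2, batch_size=4)
print(len(batches), len(batches[0]["x"]), len(batches[0]["x"][0]))
```

The last batch is smaller than batch_size when the repeated data does not divide evenly, which is also why the batch dimension appears as None in the dataset shapes.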

Let's take the first batch and check that the function works as expected.

preprocess_example_dataset = preprocess(example_dataset)
sample_batch = tf.nest.map_structure(lambda x:x.numpy(),
                                     next(iter(preprocess_example_dataset)))
OrderedDict([('x',
              array([[1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)),
             ('y',
              array([...
                     [1]], dtype=int32))])

We have almost all the building blocks needed to construct a federated dataset.

One way to provide federated data to TFF in a simulation is simply as a Python list, where each element of the list holds the data of a single user, either as a list or as a tf.data.Dataset. Since we already have an interface that provides tf.data.Dataset objects, we use those below.

def make_federated_data(client_data,client_idx):
    return [
        preprocess(
            client_data.create_tf_dataset_for_client(x)
            ) for x in client_idx
    ]

This function will return all processed client data

Now, how do we choose the client?

In a typical federated training scenario, we are dealing with a potentially very large number of user devices, only a fraction of which may be available for training at a given point in time. This is the case, for example, when the client devices are mobile phones that only participate in training when they are plugged in, on an unmetered network, and otherwise idle.

Of course, we are in a simulated environment and all data is available locally. Typically, when running a simulation, we simply sample a random subset of clients to participate in each round of training, generally a different subset in each round.

Here we sample a set of clients once and reuse the same set in each round to speed up convergence (deliberately overfitting the data for these few users).

As we found by studying the paper on federated averaging algorithms, it can take longer to achieve convergence in a system that randomly samples a subset of clients in each round. We leave it as an exercise for the reader, modifying this tutorial to simulate random sampling - it's easy to do (once you do, keep in mind that it may take a while for the model to converge).

NUM_CLIENTS = 10
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train,sample_clients)
print(f'Number of client datasets: {len(federated_train_data)}')
print(f'First dataset: {federated_train_data[0]}')

Number of client datasets: 10
First dataset: <PrefetchDataset shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

chapter_3 Building a model with Keras

from tensorflow.keras import models,layers,losses,metrics
def create_keras_model():
    # The network consists of a fully connected layer with input 784 and output 10.
    return models.Sequential([
        layers.InputLayer(input_shape=(784,)),
        layers.Dense(10),
        layers.Softmax(),
    ])

In order to use any model with TFF, it needs to be wrapped in an instance of the tff.learning.Model interface. Similar to Keras, this interface exposes methods for stamping the model's forward pass, metadata properties, and so on, but it also introduces additional elements, such as ways to control the process of computing federated metrics. We don't have to worry about this for now.

If you have a Keras model like the one we just defined above, you can have TFF wrap it by calling tff.learning.from_keras_model, passing the model and the specification of the input data as arguments, as shown below.

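def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model, input_spec=preprocess_example_dataset.element_spec,
        loss=losses.SparseCategoricalCrossentropy(),
        metrics=[metrics.SparseCategoricalAccuracy()])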

Explain the parameters of from_keras_model:

  • keras_model: model built by keras
  • input_spec: the type of input data, here select the data type after preprocessing
  • loss: the loss function; here, the cross-entropy loss for a multi-class problem
  • metrics: the evaluation metrics; here, the classification accuracy, in the range 0-1
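
To make the loss and metrics parameters less abstract, here is a small hand-computed sketch of a sparse categorical cross-entropy and a sparse categorical accuracy. It uses only the standard library and illustrates the definitions, not the Keras implementation; the function names and the example probabilities are made up.

```python
import math


def sparse_categorical_crossentropy(probs, label):
    """Negative log-probability that the model assigns to the true class."""
    return -math.log(probs[label])


def sparse_categorical_accuracy(batch_probs, labels):
    """Fraction of examples whose most probable class equals the integer label."""
    hits = 0
    for probs, label in zip(batch_probs, labels):
        predicted = max(range(len(probs)), key=probs.__getitem__)
        hits += int(predicted == label)
    return hits / len(labels)


# Hypothetical softmax outputs for three examples and three classes.
batch_probs = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1]]
labels = [0, 2, 1]

example_losses = [sparse_categorical_crossentropy(p, y)
                  for p, y in zip(batch_probs, labels)]
mean_loss = sum(example_losses) / len(example_losses)
accuracy = sparse_categorical_accuracy(batch_probs, labels)
print(mean_loss, accuracy)
```

"Sparse" here means that labels are plain integers such as 7 rather than one-hot vectors, which matches the int32 labels of the EMNIST data.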

The API for the above functions can be viewed on this website:

chapter_4 Train the model on federated data

Now that we have a model wrapped as tff.learning.Model to use with TFF, we can let TFF construct a federated averaging algorithm by calling a helper function

from tensorflow_federated import learning
from tensorflow import optimizers

4.1 Define a federated training process object.

As of 2022.09.08, this part of the official tutorial uses tff.learning.algorithms.build_weighted_fed_avg, which is only available in newer versions of TFF. This article uses TFF 0.19.0, which does not have this function, so it is replaced with tff.learning.build_federated_averaging_process.

Keep in mind that the argument must be a constructor (such as model_fn above) rather than an already-constructed instance, so that the model can be constructed in a context controlled by TFF (if you are curious about the reasons, we encourage you to read the follow-up tutorial on custom algorithms).

Here is an important note about the federated averaging algorithm: there are 2 optimizers, a client optimizer and a server optimizer. The client optimizer is only used to compute local model updates on each client. The server optimizer applies the averaged update to the global model on the server. In particular, this means that the choice of optimizer and learning rate may need to differ from what you would use to train the model on a standard i.i.d. dataset. We recommend starting with regular SGD, possibly with a smaller learning rate than usual. The learning rate used here was not carefully tuned, so feel free to experiment.

iterative_process = learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: optimizers.SGD(learning_rate=1)
)

The two APIs are similar. The API documentation of build_federated_averaging_process can no longer be found on the official website, so let's study the API of build_weighted_fed_avg instead:

  • What just happened?

TFF has built a pair of federated computations and packaged them into a tff.templates.IterativeProcess, where these computations are available as a pair of properties, initialize and next.

In short, federated computations are programs in TFF's internal language that can express various federated algorithms (you can find more about this in the Custom Algorithms tutorial). In this case, the two computations generated and packed into iterative_process implement federated averaging.

The goal of TFF is to define computations so that they can be executed in a real federated learning setting, but currently only local simulation is implemented. To execute a computation in the simulator, you simply call it like a Python function. This default interpreted environment is not designed for high performance, but it is sufficient for this tutorial; we hope to provide higher-performance simulation runtimes in future releases to facilitate larger-scale research.

Let's start with the initialize computation. As with all federated computations, you can think of it as a function. The computation takes no arguments and returns one result: the representation of the state of the federated averaging process on the server. While we don't want to delve into the details of TFF, it might be instructive to see what this state looks like. You can visualize it as follows.

( -> <

Although this type signature may seem a little cryptic at first glance, you can recognize that the server state consists of global_model_weights (the initial model parameters for MNIST that will be distributed to all devices), some empty parameters (such as the distributor, which manages server-to-client communication), and a finalizer component. The finalizer governs the logic the server uses to update its model at the end of a round, and contains an integer representing how many rounds of FedAvg have occurred.

4.2 Invoke the initialization computation to construct the server state.

Above, we constructed a federated computation object, the IterativeProcess. This object provides 2 methods: initialize and next.

First we use initialize to construct the initial server state.

state = iterative_process.initialize()
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/core/impl/compiler/tensorflow_computation_transformations.py:60: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`

Next, we introduce the second method of IterativeProcess, next. It represents a single round of Federated Averaging, which consists of pushing the server state (including the model parameters) to the clients, training on-device on their local data, collecting and averaging the model updates, and producing a new updated model on the server.
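
The round described above can be sketched in plain Python for a one-parameter model y = w * x. This is only an illustration of the idea under simplifying assumptions (a scalar weight, squared error, one local epoch, example-count weighting); it is not TFF's implementation, and the names client_update and fedavg_round as well as the toy data are hypothetical.

```python
def client_update(weight, data, lr, epochs=1):
    """Local SGD on one client for the model y = weight * x (squared error)."""
    w = weight
    for _ in range(epochs):
        for x, y in data:
            grad = 2.0 * (w * x - y) * x
            w -= lr * grad
    return w - weight, len(data)  # model delta and number of local examples


def fedavg_round(server_weight, federated_data, client_lr=0.02, server_lr=1.0):
    """Broadcast the weight, train locally, average the deltas, update the server."""
    results = [client_update(server_weight, data, client_lr)
               for data in federated_data]
    total = sum(n for _, n in results)
    mean_delta = sum(delta * n for delta, n in results) / total
    # With server_lr = 1 the server simply adopts the weighted average update.
    return server_weight + server_lr * mean_delta


# Hypothetical clients of different sizes whose data all follow y = 3x.
federated_data = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(0.5, 1.5)],
    [(1.5, 4.5), (1.0, 3.0), (2.0, 6.0)],
]

weight = 0.0
history = []
for _ in range(20):
    weight = fedavg_round(weight, federated_data)
    history.append(weight)
print(history[0], history[-1])
```

The two learning rates play the roles of the client and server optimizers: the client rate controls local training, and the server rate scales the averaged update before it is applied to the global model.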

Conceptually, you can think of next as having a functional type signature as shown below.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

How should we read this signature?

-> Left: the inputs, namely the federated learning server state SERVER_STATE and the federated learning data FEDERATED_DATA
-> Right: the outputs, namely the updated server state SERVER_STATE and the training metrics TRAINING_METRICS

In particular, next() should not be thought of as a function that runs on the server, but as a declarative functional representation of the entire decentralized computation. Some of the inputs are provided by the server (SERVER_STATE), but each participating device contributes its own local dataset.

4.3 Run a training round and visualize the results.

We can use the federated learning data already generated above for user samples.

state,metrics = iterative_process.next(state,federated_train_data)
print('training status:')
training status:

OrderedDict([('broadcast', ()),
             ('aggregation',
              OrderedDict([('mean_value', ()), ('mean_weight', ())])),
             ('train',
              OrderedDict([('sparse_categorical_accuracy', 0.56831276),
                           ('loss', 1.469168)])),
             ('stat', OrderedDict([('num_examples', 4860)]))])

4.4 Then let's train for 10 rounds:

NUM_ROUNDS = 10
state = iterative_process.initialize()
for round_num in range(0,NUM_ROUNDS):
    state,metrics = iterative_process.next(state,federated_train_data)
    print('round', round_num+1,'loss:',metrics['train']['loss'])

round 1 loss: 5.2703967
round 2 loss: 4.575706
round 3 loss: 4.115935
round 4 loss: 3.4195397
round 5 loss: 2.9929774
round 6 loss: 2.7245579
round 7 loss: 2.3100283
round 8 loss: 2.102916
round 9 loss: 1.8347658
round 10 loss: 1.6228148

It can be seen that during the training process, the loss is constantly decreasing.

4.5 [Extended] Randomly select clients in each round

Let's test drawing a random sample of NUM_CLIENTS clients from all clients:

import numpy as np

client_num_all = len(emnist_train.client_ids)
random_clients_idx = np.random.randint(client_num_all,size=NUM_CLIENTS)
client_ids_numpy = np.array(emnist_train.client_ids)
random_clients = client_ids_numpy[random_clients_idx]
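
Note that np.random.randint draws indices with replacement, so the same client can occasionally appear twice in one round. If distinct clients per round are wanted, one possible variant is to sample without replacement. The sketch below uses only the standard library; the helper sample_clients and the made-up client ids are assumptions for illustration, not part of the original notebook.

```python
import random


def sample_clients(client_ids, num_clients, rng):
    """Draw num_clients distinct ids; random.sample never repeats an element."""
    return rng.sample(client_ids, num_clients)


# Hypothetical client ids standing in for emnist_train.client_ids.
client_ids = [f"client_{i:04d}" for i in range(100)]
rng = random.Random(0)  # fixed seed so that the draws are reproducible

# A fresh set of 10 distinct clients for each of 3 rounds.
round_clients = [sample_clients(client_ids, 10, rng) for _ in range(3)]
for clients in round_clients:
    print(clients)
```

The resulting list of ids can be passed to make_federated_data in the same way as random_clients.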

It can be seen that random sampling works, so the code above is moved into the training loop: each round randomly selects a new batch of clients for training

state = iterative_process.initialize()
for round_num in range(0,NUM_ROUNDS):
    # Randomly draw NUM_CLIENTS clients for this round
    client_num_all = len(emnist_train.client_ids)
    random_clients_idx = np.random.randint(client_num_all,size=NUM_CLIENTS)
    client_ids_numpy = np.array(emnist_train.client_ids)
    random_clients = client_ids_numpy[random_clients_idx]
    federated_train_data = make_federated_data(emnist_train,random_clients)
    state,metrics = iterative_process.next(state,federated_train_data)
    print('round', round_num+1,'loss:',metrics['train']['loss'])

round 1 loss: 5.11803
round 2 loss: 4.7252603
round 3 loss: 4.249384
round 4 loss: 3.8406074
round 5 loss: 3.0722911
round 6 loss: 2.7794824
round 7 loss: 2.7126424
round 8 loss: 2.2197828
round 9 loss: 2.245559
round 10 loss: 2.3454423

It can be seen that, compared with the fixed client training in 4.4, when training with random clients in each round, the loss drop rate is significantly reduced, and more rounds are required to converge.

chapter_5 Display the training results on TensorBoard

Next, let's visualize metrics from these federated computations using Tensorboard. Let's first create the directory and corresponding summary writer to write metrics into.

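# logdir must point to the log directory created for this run (its path is not shown in this record)
summary_write = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

with summary_write.as_default():
    for round_num in range(NUM_ROUNDS):
        state,metrics = iterative_process.next(state,federated_train_data)
        for acc,value in metrics['train'].items():
            tf.summary.scalar(acc,value,step=round_num)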

!ls {logdir}
%tensorboard --logdir {logdir} --port=0

You can open the TensorBoard interface through VS Code. It shows that the loss keeps decreasing and the accuracy gradually increases.
The loss and accuracy here are the aggregated loss and accuracy reported on the server side

Tags: jupyter

Posted by flying_circus on Fri, 09 Sep 2022 22:08:17 +0530