IMPROVING THE LEARNING PERFORMANCE OF CLIENT'S LOCAL DISTRIBUTION IN CYCLIC FEDERATED LEARNING

Cyclic federated learning based on distribution information sharing and knowledge distillation (CFL_DS_KD) aims to address the challenge of non-iid data distributions and to reduce communication requirements. However, when client data is extremely heterogeneous and scarce, it becomes difficult for clients to fully learn the distribution of their local data using GANs, which degrades overall model performance. To overcome this limitation, we propose a transfer learning approach in which clients first pretrain their generators on a source domain and then fine-tune them on their local datasets. Our results on the classification of Alzheimer's disease demonstrate that this method effectively improves client distribution learning and enhances overall model performance.


INTRODUCTION
Deep learning has found widespread applications in intelligent healthcare (Miotto et al., 2017), including disease prediction, diagnosis, treatment, and prognosis. However, training effective deep learning models often requires large centralized datasets, which poses a significant challenge in areas where data privacy is crucial. Due to the sensitive nature of medical data, patient data from different hospitals cannot be exchanged or centrally stored. As a result, traditional deep learning models lack publicly shared medical datasets for training. Federated learning has emerged as a promising solution to these privacy concerns. By enabling distributed learning, federated learning allows multiple organizations to collaboratively train a global model while preserving data privacy (Yang et al., 2019). However, because datasets from different institutions are non-iid, local models trained on individual datasets may overfit, leading to poor generalization of the global model. Distribution sharing among clients is a promising approach to the non-iid problem. However, if a client's local data is scarce and extremely heterogeneous, its ability to learn the local distribution is compromised, resulting in poor quality of the shared distribution information. In this work, we propose to utilize transfer learning to improve the learning performance of the client's local distribution. Specifically, we first use a GAN model to learn the data distribution in the source domain and then fine-tune it in the target domain, enhancing the GAN's ability to learn the local data distribution. We then apply the improved GAN model in the cyclic federated learning method based on distribution information sharing and knowledge distillation (CFL_DS_KD) (Yu et al., 2022) for the classification of Alzheimer's disease.

RELATED WORKS
The non-iid challenge in federated learning
Federated learning (FL) (McMahan et al., 2017) involves training statistical models over remote devices or siloed data centers, such as mobile phones or hospitals, while keeping data localized. A major challenge in FL is that the data across clients are not independently and identically distributed (non-iid). Existing research addresses non-iid problems mainly at the algorithm and data levels. Algorithm-level solutions include objective function modification and solution mode optimization. Objective function modification involves adding regularization terms on the client side, trading off between optimizing local models and reducing the differences between local and global models to handle the non-iid data at each node. For example, FedProx (Li et al., 2020) corrects the client drift that occurs in FedAvg (McMahan et al., 2017) by restricting the Euclidean distance between local and global models as a proximal term. This prevents local updates from deviating excessively from the global model, which alleviates inconsistencies in the client-side data and improves the stability of global model convergence. FedCurv (Shoham et al., 2019) uses Fisher information from global models obtained during previous rounds of training to weight the distances, which can reduce excessive errors in the model parameters. SCAFFOLD (Karimireddy et al., 2020) improves on FedProx by adding a control variate on the client side. This control variate can take either the gradient norm of the global model on local datasets or the Euclidean distance between local and global models, preventing local models from deviating from the globally correct training direction. These methods can improve the performance of federated learning on non-iid datasets to some extent, but the degree of improvement is limited by the consistency of the client-side data sampling.
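The proximal-term idea behind FedProx can be sketched in a few lines of Python; the function name, the quadratic penalty form, and the default mu below are illustrative assumptions, not FedProx's reference implementation:

```python
import numpy as np

def fedprox_objective(local_loss, w_local, w_global, mu=0.01):
    """Local FedProx-style objective: the client's task loss plus a
    proximal term penalizing the squared Euclidean distance between
    the local weights and the current global weights."""
    proximal = (mu / 2.0) * np.sum((w_local - w_global) ** 2)
    return local_loss + proximal

w_global = np.array([1.0, -2.0, 0.5])
# When local and global weights coincide, the penalty vanishes and the
# objective reduces to the plain local loss.
base = fedprox_objective(0.3, w_global, w_global)
# As the local model drifts from the global model, the penalty grows.
drifted = fedprox_objective(0.3, w_global + 1.0, w_global, mu=0.1)
```

A larger mu pulls local updates closer to the global model, trading local fit for cross-client consistency.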
In solution mode optimization, good federated learning performance is mainly achieved by improving the server-side aggregation method. FedAvg determines client aggregation weights based on the size of clients' datasets. However, in non-iid scenarios, this aggregation method leads to a significant decrease in the performance of the global model. For this reason, many researchers have sought better aggregation methods.
In ABAvg (Xiao et al., 2021), the server tests temporary models on validation datasets to obtain the accuracy of the client-side models and then normalizes these accuracies before aggregating all parameters. FedMA (Wang et al., 2020) uses Bayesian nonparametric methods to match and average weights in a hierarchical manner. FedAvgM (Hsu et al., 2019) applies momentum when updating global models on the server. FedNova (Wang et al., 2020) normalizes local updates before averaging. However, these methods have had limited success in improving the performance of global models (Karimireddy et al., 2020), so some researchers have proposed approaches that sidestep the problem, such as personalized federated learning, multitask federated learning, and federated meta-learning, which can also improve the performance of federated learning on non-iid data to some extent.

Transfer learning for medical data
Transfer learning (TL) stems from cognitive research and builds on the idea that knowledge transferred across related tasks can improve performance on a new task. The formal definition of TL was given by Pan and Yang in terms of domains and tasks. A domain D consists of a feature space X and a marginal probability distribution P(X), where X = {x_1, …, x_n} ∈ X. Given a specific domain D = {X, P(X)}, a task is denoted T = {Y, f(·)}, where Y is a label space and f(·) is an objective predictive function. Given a source domain D_S with learning task T_S and a target domain D_T with learning task T_T, transfer learning aims to improve the learning of the target predictive function f_T(·) in D_T by using the knowledge in D_S and T_S (Pan & Yang, 2010).
There have been many studies applying transfer learning to medical image processing. Swati et al. use a pre-trained deep CNN model and propose a block-wise fine-tuning strategy based on transfer learning, evaluated on the T1-weighted contrast-enhanced magnetic resonance imaging (CE-MRI) benchmark dataset. Experimental results show that their proposed method outperforms state-of-the-art classification on the CE-MRI dataset. da Nóbrega et al. trained several CNNs (e.g., VGG16, MobileNet, ResNet50, DenseNet169) on the ImageNet dataset, converted them into feature extractors, and applied them to the LIDC/IDRI nodule images. Hassan et al. proposed an efficient and accurate approach for medical image modality classification developed using the transfer learning concept, with a pretrained ResNet50 deep learning model for optimized feature extraction followed by linear discriminant analysis classification (TLRN-LDA). Gessert et al. demonstrate that convolutional neural networks and transfer learning can be used to identify cancer tissue with confocal laser microscopy, and show that there is no generally optimal transfer learning strategy; model- and task-specific engineering is required.

METHOD
Learning client data distribution through transfer learning
To acquire knowledge about the distribution of hospital data, deep learning-based generator models are commonly employed. Generators are highly effective for data augmentation because they can learn the distribution information of data and generate data that aligns with the actual distribution. Generative adversarial networks (GANs) are a prevalent class of deep neural network generators known for their remarkable capabilities in image enhancement and image-to-image conversion. In our study, we use GANs as data generators to capture the data distribution information of local clients. However, given the limited availability and heterogeneity of local data, training GANs directly on local data may prevent them from fully capturing the underlying distribution. Therefore, we propose using transfer learning to enhance the learning of distribution information by local clients. Specifically, as shown in Fig. 1, we first let local generators learn distribution knowledge in the source domain and then fine-tune them using data from local clients. Pre-training serves to minimize intrinsic dimensionality and implicitly shapes the model's inductive bias. In classical supervised learning, models often possess a strong inductive bias, such as the local connectivity assumptions in convolutional neural networks (CNNs) and recurrent neural networks (RNNs). Pre-training provides an inductive bias for downstream tasks, which often have limited labeled samples, enabling model parameters pre-trained on hundreds of millions of samples to generalize well when fine-tuned with a small amount of data. The core idea of our method is to pre-train the GAN in the source domain to extract features and initialize the GAN network parameters, and then fine-tune it in the target domain. Transfer learning, in this context, aims to enhance the model's performance by identifying differences between datasets and leveraging transferable knowledge. Generative adversarial networks, designed to generate similar data by approximating the feature distribution of the target samples, typically require a sufficient number of target samples. When the target sample size is small, GANs often suffer mode collapse. Transfer learning can alleviate this problem and relax the stringent requirement of similarity between source and target domain data.
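A minimal sketch of this pretrain-then-fine-tune schedule in PyTorch follows; the tiny fully connected Generator, the checkpoint filename, and the learning rates are hypothetical placeholders (the actual model is the conditional WGAN-GP configured as in Yu et al., 2022), and the adversarial training loops are elided:

```python
import torch
import torch.nn as nn

# Hypothetical minimal generator; the real WGAN-GP architecture is not
# reproduced here.
class Generator(nn.Module):
    def __init__(self, z_dim=64, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256), nn.ReLU(),
            nn.Linear(256, out_dim), nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

# Stage 1: pretrain on the source domain (ADNI), then save the weights.
g = Generator()
# ... adversarial training on source-domain batches would happen here ...
torch.save(g.state_dict(), "generator_pretrained.pt")

# Stage 2: initialize the local client's generator from the pretrained
# weights and fine-tune on the small target-domain dataset with a
# reduced learning rate, so the source-domain knowledge is preserved.
g_local = Generator()
g_local.load_state_dict(torch.load("generator_pretrained.pt"))
opt = torch.optim.Adam(g_local.parameters(), lr=1e-5)
# ... fine-tuning on local-client batches would happen here ...
```

The checkpoint handoff is the entire transfer step: only the initialization changes, while the fine-tuning loop is an ordinary WGAN-GP training loop on the local data.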
In general, by leveraging transfer learning, GANs can effectively learn the distribution of local datasets, thereby enhancing the quality of the medical images generated by local GANs. These improved models can then be applied in the context of cyclic federated learning. Client c + 1 then employs knowledge distillation, using the received model M_c as a teacher model to guide the training of its own model M_{c+1} on the virtual dataset D′_{c+1}, as shown in Fig. 3. After knowledge distillation, the updated model of client c + 1 continues to train on the local dataset D_{c+1}.

Fig. 3. The process of knowledge distillation on virtual shared data generated by GAN.
The training goal of the cyclic federated learning method based on distribution information sharing and knowledge distillation is the minimization of the total loss function (1):

min L = L′_{c+1} + L_{c+1}    (1)

L′_{c+1}(M_{c+1}, M_c) = Σ_{x ∈ D′_{c+1}} ℓ(x; M_{c+1}, M_c)    (2)

ℓ = α · L_soft + (1 − α) · L_hard    (3)

M_{c+1}^{(t′)} = arg min L′_{c+1}(M_{c+1}^{(t−1)}, M_c)    (4)

M_{c+1}^{(t)} = arg min L_{c+1}(M_{c+1}^{(t′)})    (5)

In (1), L′_{c+1} represents the loss of client c + 1 during training on the virtual dataset D′_{c+1} using the model M_c of client c. As shown in equation (3), this loss includes both the soft loss of the knowledge distillation process and the hard loss of the student model. L_{c+1} represents the loss of client c + 1 during training on the local dataset. Equations (4) and (5) describe the optimization process: M_{c+1}^{(t−1)} is first optimized by training on the virtual dataset to obtain the updated model M_{c+1}^{(t′)}, which is then further updated on the local dataset to obtain the final model M_{c+1}^{(t)}. The updated model M_{c+1}^{(t)} is then transmitted to the next client.
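The mixture of soft and hard losses in equation (3) can be illustrated with a small NumPy sketch; the temperature T, weighting alpha, and T^2 rescaling follow the common knowledge-distillation formulation and are assumptions, not the paper's exact hyperparameters:

```python
import numpy as np

def softmax(z, T=1.0):
    z = np.asarray(z, dtype=float) / T
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def distillation_loss(student_logits, teacher_logits, hard_label,
                      T=2.0, alpha=0.5):
    """Soft loss: cross-entropy between the temperature-softened teacher
    and student distributions (scaled by T^2); hard loss: cross-entropy
    of the student against the true label."""
    p_teacher = softmax(teacher_logits, T)
    log_p_student = np.log(softmax(student_logits, T))
    soft = -np.sum(p_teacher * log_p_student) * T * T
    hard = -np.log(softmax(student_logits)[hard_label])
    return alpha * soft + (1.0 - alpha) * hard
```

A teacher that agrees with the student yields a smaller soft term than one that contradicts it, so the distillation term steers the student toward the teacher's distribution over the virtual data.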

Development environment and datasets
Our deep learning model was constructed using the popular deep learning framework PyTorch, version 1.6.0, along with Python, version 3.7.1. We adopted the identical network configuration as described in the referenced paper (L. Yu et al., 2022). Specifically, we employed a cyclic federated learning framework, utilizing a Kafka cluster as the medium for exchanging model parameters. The GAN model we used is a conditional Wasserstein Generative Adversarial Network with Gradient Penalty (WGAN-GP). We utilized two distinct medical datasets: the Alzheimer's disease dataset from the Kaggle contest (url: https://www.kaggle.com/datasets/tourist55/alzheimers-dataset-4-class-of-images) and the ADNI MRI dataset (url: https://adni.loni.usc.edu/data-samples/access-data/). The Alzheimer's disease dataset served as the target domain data, while the ADNI dataset was employed as the source domain data for transfer learning in pre-training the GANs. Specifically, the Alzheimer's disease dataset consists of four classes of MRI images in both the training and testing sets: mild demented, moderate demented, non-demented, and very mild demented. We aim to train a general deep learning model via federated learning for Alzheimer's disease classification tasks. The ADNI MRI image dataset we utilized comprises brain MRI scans from Alzheimer's disease (AD) patients, mild cognitive impairment (MCI) patients, and normal elderly individuals. These images provide detailed information about brain structure, morphology, and pathology. We employ this dataset to pre-train a WGAN-GP on the client side.

Evaluation
The performance of our algorithm is primarily evaluated by classification accuracy. Additionally, we use the maximum mean discrepancy (MMD) to quantify the distribution discrepancy between the generated virtual data and the target domain dataset. The squared MMD between two data distributions P and Q can be expressed as

MMD²(P, Q) = ‖ E_{x∼P}[φ(x)] − E_{y∼Q}[φ(y)] ‖²_H,

where φ(·) denotes the mapping to the reproducing kernel Hilbert space (RKHS).
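On finite samples, MMD² is typically estimated with a kernel in place of the explicit feature map φ(·). A minimal sketch follows; the RBF kernel and bandwidth sigma are assumptions, since the kernel choice is not specified here:

```python
import numpy as np

def rbf_kernel(a, b, sigma=1.0):
    """RBF (Gaussian) kernel matrix between the rows of a and of b."""
    d2 = (np.sum(a**2, axis=1)[:, None]
          + np.sum(b**2, axis=1)[None, :]
          - 2.0 * a @ b.T)  # pairwise squared Euclidean distances
    return np.exp(-d2 / (2.0 * sigma**2))

def squared_mmd(x, y, sigma=1.0):
    """Biased sample estimate of MMD^2 between the distributions that
    generated samples x and y, using an RBF kernel."""
    return (rbf_kernel(x, x, sigma).mean()
            + rbf_kernel(y, y, sigma).mean()
            - 2.0 * rbf_kernel(x, y, sigma).mean())
```

Identical sample sets give an estimate of zero, and the estimate grows as the two sample clouds separate, which is how the tables below rank generated data against the target domain.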

Results
In CFL_DS_KD, it is essential that GANs trained on the clients' local datasets adequately learn the local distribution knowledge and generate high-quality virtual datasets. Because the local client datasets are small and heterogeneous, it is challenging for the local GANs to fully capture the local distribution knowledge. Therefore, we employ transfer learning to let GANs first learn distribution knowledge from the source domain before fine-tuning them on the local datasets. Fig. 4 compares medical images generated by GAN models with and without transfer learning. The GAN using transfer learning clearly surpasses the original GAN in the clarity, contour, and texture of the generated data. Furthermore, we can measure the quality of the generated data by calculating the MMD between the generated data and the target domain data. In Table 1, the first row presents the client MMD values measured under different non-iid scenarios, the second row the MMD between the data generated by the GAN with transfer learning and the target domain data, and the third row the MMD between the data generated by the GAN without pre-training and the target domain data. The MMD between the distribution of the data generated by the GAN with transfer learning and the target domain distribution is smaller, indicating that its generated data is more similar to the target domain data and thus better reflects the client's data distribution.
The aim of improving the learning performance of the client's local distribution is to enable the clients in cyclic federated learning to share their respective real distributions. Therefore, we further evaluate the improved strategy from the algorithmic perspective to investigate the impact of transfer learning-based generators on algorithm performance under different client distribution disparities. As Fig. 5 and Table 2 show, when the MMD is below 0.5, indicating that the client data distributions are very similar, the GANs trained with transfer learning have a negative impact on algorithm performance. When the MMD is between 0.5 and 1.2, indicating some but not significant differences in client data distributions, both methods show similar performance, with a slight advantage for the transfer learning approach. However, when the MMD is greater than 1.2, indicating significant differences in client data distributions, transfer learning noticeably improves algorithm performance. Fig. 6 specifically demonstrates the impact of the transfer learning-based GAN and the non-transfer learning-based GAN on the performance of CFL_DS_KD under different MMD values. The transfer learning-based GAN outperforms the non-transfer learning-based GAN, and this effect becomes more prominent as the MMD increases. When the client distribution disparities are small, the transfer learning-based generator does not provide an advantage. This may be because, after the GAN learns knowledge from the source domain, fine-tuning on the target domain does not let the model adapt well to the target domain distribution; the model parameters remain biased towards the source domain, degrading performance. Furthermore, to further highlight the advantages of the transfer learning-based GAN, we compared it with other algorithms in the box plots of Fig. 7, where the MMD increases from the top left corner to the bottom right corner. As the MMD increases, indicating more inconsistent client data distributions, the performance of the cyclic federated averaging model (CFL_FedAvg) declines rapidly. The non-transfer learning-based GAN performs at an intermediate level, while the transfer learning-based GAN exhibits the best and most stable performance. Fig. 8 presents a performance comparison of different methods over different communication rounds. The transfer learning-based GAN achieves the greatest improvement in the algorithm, and its performance is on par with or even surpasses centralized learning methods.

CONCLUSION
In general, this work focuses on enhancing the learning performance of the client's local data distribution. To overcome the challenges posed by data scarcity and heterogeneous distributions among clients' datasets, we propose utilizing transfer learning to help GANs better capture the underlying data distributions of the clients. The adequately trained GANs are then applied within the framework of cyclic federated learning, which incorporates distribution information sharing and knowledge distillation. Through rigorous experimentation and evaluation, we provide evidence of the effectiveness of transfer learning in improving the performance of GANs in learning the client's data distribution, thereby enhancing the overall algorithmic performance.

Fig. 1. Improving the performance of GANs in learning client distributions through transfer learning.
Once we have acquired a well-trained GAN model that effectively captures the data distribution of the client through transfer learning, we can integrate it into the cyclic federated learning method, which relies on sharing distribution information and knowledge distillation. The specific steps are as follows. Let C represent the total number of clients participating in the federated learning task. Let D_c = {x_i | i = 1, 2, …, n_c} be the local dataset of client c (where c = 1, 2, …, C) and n_c = |D_c| be the number of samples in the local dataset. Initially, client c trains a generator G_c through transfer learning that reflects the distribution information P_c of the local dataset D_c. Thus, C clients are trained to obtain C generator models. Then, client c transmits its generator G_c to client c + 1, forming a ring-shaped communication link; when c = C, let c + 1 = 1. The generator G_c from client c can then generate n′_{c+1} virtually shared data points, i.e., D′_{c+1} = {x̂_j | x̂_j = G_c(z_j), j = 1, 2, …, n′_{c+1}}. The distribution information sharing process is schematically illustrated in Fig. 2. Additionally, client c transmits its locally pretrained model M_c to client c + 1.
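The wrap-around rule for the ring-shaped communication link can be sketched as follows; next_client is a hypothetical helper, not code from the paper:

```python
def next_client(c, num_clients):
    """Successor of client c in a 1-based ring: client C hands off to client 1."""
    return 1 if c == num_clients else c + 1

# One full cycle over C = 4 clients, starting at client 1: at each hop,
# a client passes its generator G_c and pretrained model M_c to its
# successor, which uses them to build D'_{c+1} and distill knowledge.
visits = []
c = 1
for _ in range(5):
    visits.append(c)
    c = next_client(c, 4)
# visits traverses 1 -> 2 -> 3 -> 4 and wraps back to 1
```

Because each client only ever talks to its successor, one cycle moves every generator and model exactly one position around the ring, which is what keeps the per-round communication cost low.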

Fig. 2. The process of client distribution sharing and virtual dataset generation in cyclic federated learning.
Fig. 4. (a) An image sample of the target domain dataset. (b) An image sample generated by the GAN generator without transfer learning. (c) An image sample generated by the GAN generator with transfer learning.

Fig. 5. The influence of the transfer learning-based GAN and the non-transfer learning-based GAN on the performance of CFL_DS_KD at different MMD levels.

Fig. 7. Comparison of accuracy box plots for different methods at different MMD levels.

Table 1. Distribution discrepancy between generated data and target domain data under different non-iid scenarios.

Table 2. The influence of the transfer learning-based GAN and the non-transfer learning-based GAN on the performance of CFL_DS_KD at different MMD levels.