<p align="center">
  <h2 align="center">PCoTTA: Continual Test-Time Adaptation for Multi-Task Point Cloud Understanding</h2>
</p>
  <p align="center">
    NeurIPS 2024 Submission
    <br>
    Paper ID #57
  </p>



<div  align="center">    
 <img src="./figures/overview.pdf" width = 1000  align=center />
</div>
Overview of **PCoTTA**. PCoTTA handles continually changing target domains by using each target's nearest source sample as a prompt for multi-task learning within a single unified model. We introduce Gaussian Splatted Feature Shifting (GSFS) to align unknown targets with the sources, improving transferability. Source prototypes from different domains and learnable prototypes together form a prototype bank. The Automatic Prototype Mixture (APM) pairs these prototypes according to their similarity to the target, preventing catastrophic forgetting. We project the prototypes as Gaussian distributions onto the feature plane, assigning larger weights to more relevant ones, and a graph attention module updates these weights dynamically to mitigate error accumulation. Finally, Contrastive Prototype Repulsion (CPR) keeps the learnable prototypes distinguishable across different targets, improving adaptability.
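The Gaussian weighting idea described above can be illustrated with a minimal pure-Python sketch. It assumes isotropic Gaussians of a shared width `sigma` centered at each prototype and a hypothetical step size `beta` for the shift; the function names are illustrative, the graph-attention update of the weights is not modeled, and this is not the repository's GSFS implementation.

```python
import math
from typing import List, Sequence

Vector = List[float]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def gaussian_weights(feature: Sequence[float],
                     prototypes: Sequence[Sequence[float]],
                     sigma: float = 1.0) -> Vector:
    """Evaluate an isotropic Gaussian centered at each prototype at `feature`
    and normalize, so closer (more relevant) prototypes get larger weights."""
    logits = [-sq_dist(feature, p) / (2.0 * sigma ** 2) for p in prototypes]
    peak = max(logits)  # subtract the max for numerical stability
    raw = [math.exp(z - peak) for z in logits]
    total = sum(raw)
    return [r / total for r in raw]


def gaussian_feature_shift(feature: Sequence[float],
                           prototypes: Sequence[Sequence[float]],
                           sigma: float = 1.0,
                           beta: float = 0.5) -> Vector:
    """Hypothetical shift rule: move `feature` a fraction `beta` of the way
    toward the Gaussian-weighted average of the prototypes."""
    weights = gaussian_weights(feature, prototypes, sigma)
    center = [sum(w * p[d] for w, p in zip(weights, prototypes))
              for d in range(len(feature))]
    return [f + beta * (c - f) for f, c in zip(feature, center)]


# Usage: a feature at x=1 is pulled mostly toward the nearby prototype at the origin.
protos = [[0.0, 0.0], [4.0, 0.0]]
print(gaussian_weights([1.0, 0.0], protos))  # first weight is much larger
print(gaussian_feature_shift([1.0, 0.0], protos))
```

Using a softmax over negative squared distances keeps the weights normalized and lets distant prototypes fade out smoothly instead of being cut off by a hard nearest-neighbour choice.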


# Highlights
1. We present PCoTTA, a pioneering, unified framework for Continual Test-Time Adaptation (CoTTA) in **multi-task point cloud understanding**, which enhances the model's transferability toward the **continually changing target domain**. We also introduce a multi-task setting and a new benchmark for PCoTTA that reflect practical, real-world conditions.
2. We devise three modules for PCoTTA, *i.e.*, **automatic prototype mixture (APM), Gaussian Splatted feature shifting (GSFS), and contrastive prototype repulsion (CPR)**. APM keeps the model from straying too far from the original source model, mitigating catastrophic forgetting; GSFS dynamically shifts each test sample toward the source domain, alleviating error accumulation; and CPR pulls the nearest learnable prototype close to the test feature while pushing it away from the other prototypes.
3. Extensive experiments and analysis demonstrate the effectiveness of PCoTTA, which surpasses state-of-the-art approaches by **a large margin**.
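One way to express the CPR objective in item 2 is an InfoNCE-style loss. In this minimal sketch (an assumption, not the paper's exact formula), the learnable prototype with the highest cosine similarity to the test feature acts as the anchor, the test feature is its positive, and the remaining learnable prototypes are its negatives; `tau` is a hypothetical temperature and the function names are illustrative.

```python
import math
from typing import Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def contrastive_prototype_repulsion(feature: Sequence[float],
                                    learnable_protos: Sequence[Sequence[float]],
                                    tau: float = 0.1) -> Tuple[float, int]:
    """Return (loss, index of the nearest learnable prototype).

    Anchor:    the learnable prototype most similar to `feature`.
    Positive:  `feature` (pulled toward the anchor).
    Negatives: all other learnable prototypes (pushed away from the anchor).
    """
    sims = [cosine(feature, p) for p in learnable_protos]
    k = max(range(len(sims)), key=sims.__getitem__)
    anchor = learnable_protos[k]
    pos = sims[k] / tau
    negs = [cosine(anchor, p) / tau
            for i, p in enumerate(learnable_protos) if i != k]
    logits = [pos] + negs
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - pos, k  # equals -log softmax(logits)[0]


# Well-separated prototypes give a near-zero loss; crowded ones a larger loss.
feat = [1.0, 0.0]
print(contrastive_prototype_repulsion(feat, [[1.0, 0.1], [0.0, 1.0], [-1.0, 0.0]]))
print(contrastive_prototype_repulsion(feat, [[1.0, 0.1], [1.0, 0.2], [0.9, 0.1]]))
```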

# Abstract


In this paper, we present PCoTTA, a pioneering framework for Continual Test-Time Adaptation (CoTTA) in multi-task point cloud understanding, which enhances the model's transferability toward the continually changing target domain. We introduce a practical and realistic multi-task setting for PCoTTA in which a single unified model handles multiple tasks throughout continual adaptation. PCoTTA involves three key components: automatic prototype mixture (APM), Gaussian Splatted feature shifting (GSFS), and contrastive prototype repulsion (CPR). First, APM automatically mixes the source prototypes with the learnable prototypes using a similarity balancing factor, avoiding catastrophic forgetting. Then, GSFS dynamically shifts each test sample toward the source domain, mitigating error accumulation in an online manner. In addition, CPR pulls the nearest learnable prototype close to the test feature and pushes it away from the other prototypes, keeping each prototype distinguishable during adaptation. Experimental comparisons establish a new benchmark and demonstrate PCoTTA's superiority in boosting the model's transferability toward the continually changing target domain.
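The similarity-balanced mixing performed by APM can be sketched as follows. Everything here is a hypothetical simplification rather than the repository's code: similarity is cosine similarity, the most similar source prototype is paired with the most similar learnable prototype, and the balancing factor `alpha` is a two-way softmax over their similarities with an assumed temperature `tau`.

```python
import math
from typing import List, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def most_similar(feature: Sequence[float],
                 protos: Sequence[Sequence[float]]) -> Tuple[int, float]:
    """Index and cosine similarity of the prototype closest to `feature`."""
    sims = [cosine(feature, p) for p in protos]
    idx = max(range(len(sims)), key=sims.__getitem__)
    return idx, sims[idx]


def automatic_prototype_mixture(feature: Sequence[float],
                                source_protos: Sequence[Sequence[float]],
                                learnable_protos: Sequence[Sequence[float]],
                                tau: float = 0.1) -> Tuple[List[float], float]:
    """Return (mixed_prototype, alpha); alpha is the weight on the source prototype."""
    s_idx, s_sim = most_similar(feature, source_protos)
    l_idx, l_sim = most_similar(feature, learnable_protos)
    # Balancing factor: two-way softmax over temperature-scaled similarities.
    a, b = s_sim / tau, l_sim / tau
    peak = max(a, b)
    ea, eb = math.exp(a - peak), math.exp(b - peak)
    alpha = ea / (ea + eb)
    src, lrn = source_protos[s_idx], learnable_protos[l_idx]
    mixed = [alpha * s_val + (1.0 - alpha) * l_val for s_val, l_val in zip(src, lrn)]
    return mixed, alpha


source = [[1.0, 0.0], [0.0, 1.0]]
learnable = [[0.6, 0.8]]
mixed, alpha = automatic_prototype_mixture([1.0, 0.0], source, learnable)
print(alpha)  # close to 1: the target matches a source prototype
```

When a target closely matches a source domain, `alpha` approaches 1 and the mixture leans on the source prototype, which is the intuition behind using a source-anchored mixture to limit forgetting.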


# Implementation

## 1. Requirements
Recommended versions:
```
PyTorch = 2.0.0;
python = 3.10;
CUDA = 12.1;
```

Other packages:
```
pip install -r requirements.txt
```

Install PyTorch3D:
```
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d
export CUB_HOME=/usr/local/cuda/include/
FORCE_CUDA=1 python setup.py install
```

From the repository root, install the Chamfer Distance extension:
```
cd ./extensions/chamfer_dist
python setup.py install
```
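For reference when sanity-checking the compiled extension on small inputs, the sketch below computes the common symmetric Chamfer distance (squared L2, averaged per set) in pure Python. The exact convention used by `extensions/chamfer_dist` (squared vs. unsquared distances, mean vs. sum, L1/L2 variants) is not stated here, so treat this as an assumed variant; it is O(|A|·|B|) and only suitable for tiny clouds.

```python
from typing import Sequence

Point = Sequence[float]


def squared_distance(p: Point, q: Point) -> float:
    """Squared Euclidean distance between two points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def chamfer_distance(cloud_a: Sequence[Point], cloud_b: Sequence[Point]) -> float:
    """Mean squared distance from each point in cloud_a to its nearest
    neighbour in cloud_b, plus the same term from cloud_b to cloud_a."""
    a_to_b = sum(min(squared_distance(p, q) for q in cloud_b) for p in cloud_a) / len(cloud_a)
    b_to_a = sum(min(squared_distance(q, p) for p in cloud_a) for q in cloud_b) / len(cloud_b)
    return a_to_b + b_to_a


P = [(0.0, 0.0, 0.0)]
Q = [(1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
print(chamfer_distance(P, Q))  # 1.0 + (1.0 + 4.0) / 2 = 3.5
```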

## 2. Pre-training
To pre-train PCoTTA on multiple source domains under the **multi-task** setting, run the following command:

```
python main.py --config cfgs/PCoTTA_MN_SO.yaml --exp_name exp/PCoTTA_MN_SO
```
By default, ModelNet (MN) and ScanObjectNN (SO) serve as the target domains and are not available during pre-training. To use different target domains, change `[targe_domain]` in the `.yaml` config file.


## 3. Continual Test-Time Adaptation

To evaluate **PCoTTA** on the continually changing target domains across the three tasks, run the following command:

```
python test_PCoTTA.py --config cfgs/PCoTTA_MN_SO.yaml --exp_name exp_test/PCoTTA_MN_SO --ckpts experiments/PCoTTA_MN_SO/ckpt-last.pth
```
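For readers new to the continual setting, the generic protocol can be sketched as below: target domains arrive in sequence, each batch is predicted before it is used for an unsupervised update, and the model state is never reset between domains. `continual_tta`, `center`, and the EMA-based `ema_adapt` are hypothetical toy stand-ins, not the logic of `test_PCoTTA.py`.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Sequence[float]


def continual_tta(state: float,
                  domain_stream: Iterable[Tuple[str, Iterable[Batch]]],
                  predict: Callable[[float, Batch], List[float]],
                  adapt: Callable[[float, Batch], float]):
    """Predict-then-adapt over a stream of target domains; `state` is carried
    across domain boundaries and never reset."""
    log = []
    for domain, batches in domain_stream:
        for batch in batches:
            log.append((domain, predict(state, batch)))  # predict first (online)
            state = adapt(state, batch)                  # then update without labels
    return state, log


# Toy stand-ins: the "model" is a running mean used to center inputs,
# updated as an exponential moving average of each test batch.
def center(state: float, batch: Batch) -> List[float]:
    return [x - state for x in batch]


def ema_adapt(state: float, batch: Batch, momentum: float = 0.5) -> float:
    return (1.0 - momentum) * state + momentum * (sum(batch) / len(batch))


stream = [("target_1", [[2.0, 2.0]]), ("target_2", [[4.0, 4.0]])]
final_state, log = continual_tta(0.0, stream, center, ema_adapt)
print(log)          # [('target_1', [2.0, 2.0]), ('target_2', [3.0, 3.0])]
print(final_state)  # 2.5
```

The second domain's predictions already depend on what was learned from the first, which is exactly why error accumulation and forgetting matter in this setting.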

# Visual Results
<div  align="center">    
 <img src="./figures/visual_results.pdf" width = 1000  align=center />
</div>