DIF-Gaussian 代码讲解

这篇论文的标题是《Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction》,作者是Yiqun Lin, Hualiang Wang, Jixiang Chen和Xiaomeng Li,来自香港科技大学以及HKUST深圳-香港协同创新研究院。

这篇论文主要探讨了一种新的锥束计算机断层扫描(CBCT)重建框架,称为DIF-Gaussian,旨在通过使用更少的投影来减少辐射剂量,同时提高重建图像的质量。

给的代码只是个框架,强行复现花费时间而且以我水平容易误人子弟,就简单的对照论文理解一下,大家有兴趣可以一起讨论

项目地址:

GitHub - xmed-lab/DIF-Gaussian: MICCAI 2024: Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction

数据预处理地址
https://github.com/xmed-lab/C2RV-CBCT/tree/main/data

1、 下载代码和数据预处理方法,数据放到data中

2、发现代码是不完整的,因此边补充边写

train.py

使其与不同版本的DDP兼容

    if args.dist:
        args.local_rank = int(os.environ["LOCAL_RANK"]) # Make it compatible with different versions of DDP
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(args.local_rank)

加载cfg,项目只给出了一个default.yaml,复制一个改个名字

    cfg = load_config(args.cfg_path)
    if args.local_rank == 0:
        print(args)
        print(cfg)

        # save config
        save_dir = f'./logs/{args.name}'
        os.makedirs(save_dir, exist_ok=True)
        if os.path.exists(os.path.join(save_dir, 'config.yaml')):
            time_str = datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
            shutil.copyfile(
                os.path.join(save_dir, 'config.yaml'), 
                os.path.join(save_dir, f'config_{time_str}.yaml')
            )
        shutil.copyfile(args.cfg_path, os.path.join(save_dir, 'config.yaml'))

初始化训练数据集/加载器

    train_dst = CBCT_dataset_gs(
        dst_name=args.dst_name,
        cfg=cfg.dataset,
        split='train', 
        num_views=args.num_views, 
        npoint=args.num_points,
        out_res_scale=args.out_res_scale,
        random_views=args.random_views
    )

关键在于并没有数据,因此还得自己想办法

dataset:
  root_dir: ../../datasets
  gs_res: 12 # the resolution of GS points (12^3 points in total)

进去看看数据集如何构建

class CBCT_dataset_gs(CBCT_dataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        gs_res = self.cfg.gs_res
        points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res
        self.points_gs = points_gs.reshape(3, -1).transpose(1, 0) # ~[0, 1]

    def __getitem__(self, index):
        data_dict = super().__getitem__(index)

        # projections of GS points (initial center xyz)
        points_gs = deepcopy(self.points_gs)
        points_gs_proj = self.project_points(points_gs, data_dict['angles'])

        data_dict.update({
            'points_gs': points_gs,          # [K, 3]
            'points_gs_proj': points_gs_proj # [M, K, 2]
        })
        return data_dict

np.mgrid是NumPy库中的一个函数,它返回一个由给定尺寸的数组创建的多维网格。这段代码points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res创建了一个3D网格,并且将这个网格的每个点归一化到[0, 1]区间。

结果points_gs是一个4D数组,其形状为(gs_res, gs_res, gs_res, 3),其中最后一个维度包含每个网格点的x、y、z坐标。

 看getitem 

points_gs_proj = self.project_points(points_gs, data_dict['angles'])

points_gs 是一个3D网格的点,通常是用于表示3D空间中的一个体素化网格或者用于定义3D空间中的高斯分布的中心点。而 points_gs_proj 则是这些点在2D平面上的投影。

代码是不全的,后期再看看会不会更新

看LUNA16数据预处理的config 内有dataset的参数,其中的angle 为180

get返回一个3d 高斯网格,一个2d的投影

loader如下

    train_sampler = None
    if args.dist:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dst)
    train_loader = DataLoader(
        train_dst, 
        batch_size=args.batch_size, 
        sampler=train_sampler, 
        shuffle=(train_sampler is None),
        num_workers=0, # args.num_workers,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )

    # -- initialize evaluation dataset/loader
    eval_loader = DataLoader(
        CBCT_dataset_gs(
            dst_name=args.dst_name,
            cfg=cfg.dataset,
            split='eval',
            num_views=args.num_views,
            out_res_scale=0.5, # low-res for faster evaluation,
        ), 
        batch_size=1, 
        shuffle=False,
        pin_memory=True
    )

加载模型,模型放到后面看


    # -- initialize model
    model = DIF_Gaussian(cfg.model)
    if args.resume:
        print(f'resume model from epoch {args.resume}')
        ckpt = torch.load(
            os.path.join(f'./logs/{args.name}/ep_{args.resume}.pth'),
            map_location=torch.device('cpu')
        )
        model.load_state_dict(ckpt)
    
    model = model.cuda()
    if args.dist:
        model = nn.parallel.DistributedDataParallel(
            model, 
            find_unused_parameters=False,
            device_ids=[args.local_rank]
        )

优化器和优化器规划,损失只有一个MSE

    # -- initialize optimizer, lr scheduler, and loss function
    optimizer = torch.optim.SGD(
        model.parameters(), 
        lr=args.lr, 
        momentum=0.98, 
        weight_decay=args.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=1, 
        gamma=np.power(args.lr_decay, 1 / args.epoch)
    )
    loss_func = nn.MSELoss()

开始训练

    # -- training starts
    for epoch in range(start_epoch, args.epoch + 1):
        if args.dist:
            train_loader.sampler.set_epoch(epoch)

        loss_list = []
        model.train()
        optimizer.zero_grad()

一个epoch,外部看没有花里胡哨的损失,一个损失做到底

        for k, item in enumerate(train_loader):
            item = convert_cuda(item)

            pred = model(item)
            loss = loss_func(pred['points_pred'], item['points_gt'])
            loss_list.append(loss.item())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

评估和优化

        if args.local_rank == 0:
            if epoch % 10 == 0:
                loss = np.mean(loss_list)
                print('epoch: {}, loss: {:.4}'.format(epoch, loss))
            
            if epoch % 100 == 0 or (epoch >= (args.epoch - 100) and epoch % 10 == 0):
                if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
                    model_state = model.module.state_dict()
                else:
                    model_state = model.state_dict()
                torch.save(
                    model_state,
                    os.path.join(save_dir, f'ep_{epoch}.pth')
                )

            if epoch % 50 == 0 or (epoch >= (args.epoch - 100) and epoch % 20 == 0):
                metrics, _ = eval_one_epoch(
                    model, 
                    eval_loader, 
                    args.eval_npoint,
                    ignore_msg=True,
                )
                msg = f' --- epoch {epoch}'
                for dst_name in metrics.keys():
                    msg += f', {dst_name}'
                    met = metrics[dst_name]
                    for key, val in met.items():
                        msg += ', {}: {:.4}'.format(key, val)
                print(msg)
        
        if lr_scheduler is not None:
            lr_scheduler.step()

model .py

看看初始化定义了什么

class DIF_Gaussian(Recon_base):
    def __init__(self, cfg):
        super().__init__(cfg)

    def init(self):
        self.init_encoder()
        
        # gaussians-related modules
        mid_ch = self.image_encoder.out_ch
        ds_ch = self.image_encoder.ds_ch
        self.gs_feats_mlp = MLP_1d([ds_ch, ds_ch // 4, mid_ch], use_bn=True, last_bn=True, last_act=False)
        self.gs_params_mlp = MLP_1d([ds_ch, ds_ch // 4, 3 + 4 + 3], use_bn=True, last_bn=False, last_act=False) # 3d: offsets, 4d: rotation, 3d: scaling
        self.gs_act = nn.LeakyReLU(inplace=True)

        self.init_decoder(mid_ch * 2)
        self.registered_point_keys = ['points', 'points_proj']

初始化编码器:self.init_encoder()

定义高斯特征和参数mlp:self.gs_feats_mlp;self.gs_params_mlp,选用线性激活self.gs_act

初始化解码器 

虽然没写完全,但是不难想象编码器和解码器的都是unet里面的

看向里面的点forward ,获取点的预测值

1多视图像素对齐功能+最大池

2gaussian-based插值函数

3逐点地预测

class PointDecoder(nn.Module):
    def __init__(self, channels, residual=True, use_bn=True):
        super().__init__()

        self.residual = residual
        self.mlps = nn.ModuleList()

        for i in range(len(channels) - 1):
            modules = []
            if i == 0 or not self.residual:
                modules.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=1))
            else:
                modules.append(nn.Conv1d(channels[i] + channels[0], channels[i + 1], kernel_size=1))

            if i != len(channels) - 1:
                if use_bn:
                    modules.append(nn.BatchNorm1d(channels[i + 1]))
                modules.append(nn.LeakyReLU(inplace=True))

            self.mlps.append(nn.Sequential(*modules))

    def forward(self, x):
        x_ = x
        for i, m in enumerate(self.mlps):
            if i != 0 and self.residual:
                x_ = torch.cat([x_, x], dim=1)
            x_ = m(x_)
        return x_

query_view_feats:应该是对应这个公式

def query_view_feats(view_feats, points_proj, fusion='max'):
    # view_feats: [B, M, C, H, W]
    # points_proj: [B, M, N, 2]
    # output: [B, C, N, M]
    n_view = view_feats.shape[1]
    p_feats_list = []
    for i in range(n_view):
        feat = view_feats[:, i, ...] # B, C, W, H
        p = points_proj[:, i, ...] # B, N, 2
        p_feats = index_2d(feat, p) # B, C, N
        p_feats_list.append(p_feats)
    p_feats = torch.stack(p_feats_list, dim=-1) # B, C, N, M
    if fusion == 'max':
        p_feats = F.max_pool2d(p_feats, (1, p_feats.shape[-1]))
        p_feats = p_feats.squeeze(-1) # [B, C, K]
    elif fusion is not None:
        raise NotImplementedError
    return p_feats

插值如下

下面有一个点decoder

class PointDecoder(nn.Module):
    def __init__(self, channels, residual=True, use_bn=True):
        super().__init__()

        self.residual = residual
        self.mlps = nn.ModuleList()

        for i in range(len(channels) - 1):
            modules = []
            if i == 0 or not self.residual:
                modules.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=1))
            else:
                modules.append(nn.Conv1d(channels[i] + channels[0], channels[i + 1], kernel_size=1))

            if i != len(channels) - 1:
                if use_bn:
                    modules.append(nn.BatchNorm1d(channels[i + 1]))
                modules.append(nn.LeakyReLU(inplace=True))

            self.mlps.append(nn.Sequential(*modules))

    def forward(self, x):
        x_ = x
        for i, m in enumerate(self.mlps):
            if i != 0 and self.residual:
                x_ = torch.cat([x_, x], dim=1)
            x_ = m(x_)
        return x_

用了残差网络进行预测

相关推荐

  1. 基本代码讲解

    2024-07-09 17:26:03       32 阅读

最近更新

  1. 在 Ubuntu 22.04/20.04 安装 CVAT 和 SAM 指南

    2024-07-09 17:26:03       0 阅读
  2. C++多线程编程中的锁详解

    2024-07-09 17:26:03       0 阅读
  3. 生成对抗网络(GAN):目标检测的新前沿

    2024-07-09 17:26:03       0 阅读
  4. 机器学习浅讲

    2024-07-09 17:26:03       0 阅读
  5. 动态内存规划

    2024-07-09 17:26:03       0 阅读
  6. js之深入对象和内置构造函数

    2024-07-09 17:26:03       0 阅读
  7. k8s安装powerjob

    2024-07-09 17:26:03       0 阅读

热门阅读

  1. mybatis用注解替换xml,不再写.xml了

    2024-07-09 17:26:03       8 阅读
  2. Docker

    Docker

    2024-07-09 17:26:03      7 阅读
  3. 服务器安装多个Tomcat

    2024-07-09 17:26:03       7 阅读
  4. 玩转springboot之springboot定制化tomcat

    2024-07-09 17:26:03       10 阅读
  5. Word使用中的一些烦人的小问题

    2024-07-09 17:26:03       7 阅读
  6. Redis 中的跳跃表是什么

    2024-07-09 17:26:03       9 阅读
  7. 大语言模型系列-Transformer介绍

    2024-07-09 17:26:03       6 阅读
  8. FCA-FineReport认证试题及答案

    2024-07-09 17:26:03       8 阅读
  9. Windows 中修改 MySQL 密码

    2024-07-09 17:26:03       5 阅读