[Dy2St] Clean unused inputs and outputs for backward #66278
Conversation
Your PR has been submitted successfully. Thank you for your contribution to this open source project!
Phenomenon

CSWinTransformer OOMs under SOT+PIR on a V100 16G machine, while every other mode (SOT+PT, AST+PT, AST+PIR, and dynamic graph) peaks at around 13 GB of device memory. Under SOT+PIR, the OOM happens before the forward pass of the first step even finishes.

From the code side: the SOT code is essentially identical between PIR and PT, so it is very unlikely that SOT itself is holding extra Tensors. The AST layer, i.e. the partial_program -> run_program OP path, does differ between the two, so the problem most likely lies there.

Analysis

Model-level analysis

In terms of comparability, SOT+PT and SOT+PIR differ the least; as long as SOT+PIR adopts the same strategy as SOT+PT, the memory footprint should align. The analysis below therefore focuses on these two modes.

Subgraph analysis

Before that, let's look at how device memory can be analyzed.
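The original comment shows the allocator's Alloc/Free VLOG output at this point. As illustrative stand-ins, lines in the format that the regexes in the script further down expect look like this (the sizes and pointers are made up):

```
... Alloc 3538944 bytes, ptr = 0x7f3a2c000000
... Free 3538944 bytes, ptr = 0x7f3a2c000000
```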
Our logs contain memory allocation entries like the ones above, so we can use them to analyze per-subgraph memory consumption. First, collect the logs:

```bash
# Manually changed the VLOG(10) in the allocator code above to VLOG(7), because PT crashes under GLOG_v=10; not digging into that for now
# Alternatively, GLOG_vmodule=auto_growth_best_fit_allocator=10 collects only the memory info, which is fine while we only analyze memory; later we need the surrounding logs, so dump everything
# Set FLAGS_new_executor_sequential_run so that execution order is comparable
GLOG_v=7 FLAGS_new_executor_sequential_run=true SOT_LOG_LEVEL=0 ENABLE_FALL_BACK=True MIN_GRAPH_SIZE=0 FLAGS_enable_pir_api=1 python tools/train.py -c ppcls/configs/ImageNet/CSWinTransformer/CSWinTransformer_base_384.yaml \
-o Global.epochs=1 \
-o Global.save_interval=1 \
-o Global.eval_interval=1 \
-o Global.seed=1234 \
-o DataLoader.Train.dataset.image_root=/workspace/PaddleClas/dataset/ILSVRC2012/ \
-o DataLoader.Train.dataset.cls_label_path=/workspace/PaddleClas/dataset/ILSVRC2012/train_list.txt \
-o DataLoader.Train.sampler.batch_size=8 \
-o DataLoader.Eval.dataset.image_root=/workspace/PaddleClas/dataset/ILSVRC2012/ \
-o DataLoader.Eval.dataset.cls_label_path=/workspace/PaddleClas/dataset/ILSVRC2012/val_list.txt \
-o DataLoader.Eval.sampler.batch_size=8 \
-o DataLoader.Train.loader.num_workers=0 \
-o DataLoader.Train.sampler.shuffle=False \
-o Global.output_dir=output/ppcls/configs/ImageNet/CSWinTransformer/CSWinTransformer_base_384 \
-o Global.to_static=True >! oom-pir-alloc.log 2>&1
GLOG_v=7 FLAGS_new_executor_sequential_run=true SOT_LOG_LEVEL=0 ENABLE_FALL_BACK=True MIN_GRAPH_SIZE=0 FLAGS_enable_pir_api=0 python tools/train.py -c ppcls/configs/ImageNet/CSWinTransformer/CSWinTransformer_base_384.yaml \
-o Global.epochs=1 \
-o Global.save_interval=1 \
-o Global.eval_interval=1 \
-o Global.seed=1234 \
-o DataLoader.Train.dataset.image_root=/workspace/PaddleClas/dataset/ILSVRC2012/ \
-o DataLoader.Train.dataset.cls_label_path=/workspace/PaddleClas/dataset/ILSVRC2012/train_list.txt \
-o DataLoader.Train.sampler.batch_size=8 \
-o DataLoader.Eval.dataset.image_root=/workspace/PaddleClas/dataset/ILSVRC2012/ \
-o DataLoader.Eval.dataset.cls_label_path=/workspace/PaddleClas/dataset/ILSVRC2012/val_list.txt \
-o DataLoader.Eval.sampler.batch_size=8 \
-o DataLoader.Train.loader.num_workers=0 \
-o DataLoader.Train.sampler.shuffle=False \
-o Global.output_dir=output/ppcls/configs/ImageNet/CSWinTransformer/CSWinTransformer_base_384 \
-o Global.to_static=True >! oom-pt-alloc.log 2>&1
```

After dumping the logs, write a simple script to quickly analyze the memory behavior:

```python
from __future__ import annotations

from typing import Callable
import sys
import re

LOG_PATH_TEMPLATE = "oom-{}-alloc.log"
# LOG_PATH = "oom-pt-alloc.log"
# LOG_PATH = "oom-pir-alloc.log"


class MemoryInfoItem:
    def __init__(self, ptr: str, size: int):
        self.ptr = ptr
        self.size = size


def format_bytes(size: int) -> str:
    units = ["B", "KB", "MB", "GB", "TB"]
    unit_idx = 0
    sign = "" if size >= 0 else "-"
    size = abs(size)
    while size >= 1024 and unit_idx < len(units) - 1:
        size /= 1024
        unit_idx += 1
    return f"{sign}{size:.2f} {units[unit_idx]}"


class AllocMemoryInfoItem(MemoryInfoItem):
    LOG_MATCH_REGEX = re.compile(r".+Alloc (?P<size>\d+) bytes, ptr = (?P<ptr>0x\w+)")

    def __repr__(self):
        return f"AllocMemoryInfoItem(ptr={self.ptr}, size={self.size})"


class FreeMemoryInfoItem(MemoryInfoItem):
    LOG_MATCH_REGEX = re.compile(r".+Free (?P<size>\d+) bytes, ptr = (?P<ptr>0x\w+)")

    def __repr__(self):
        return f"FreeMemoryInfoItem(ptr={self.ptr}, size={self.size})"


def extract_memory_info(
    logs: list[str],
    start_fn: Callable[[str], bool],
    end_fn: Callable[[str], bool],
) -> list[MemoryInfoItem]:
    # Collect Alloc/Free events between the first line matching start_fn
    # and the first subsequent line matching end_fn.
    memory_info = []
    started = False
    for line in logs:
        if not started and not start_fn(line):
            continue
        started = True
        if end_fn(line):
            break
        if match_obj := AllocMemoryInfoItem.LOG_MATCH_REGEX.match(line):
            memory_info.append(AllocMemoryInfoItem(match_obj.group("ptr"), int(match_obj.group("size"))))
        elif match_obj := FreeMemoryInfoItem.LOG_MATCH_REGEX.match(line):
            memory_info.append(FreeMemoryInfoItem(match_obj.group("ptr"), int(match_obj.group("size"))))
    return memory_info


def read_log(log_path: str) -> list[str]:
    with open(log_path, "r") as f:
        lines = f.readlines()
    return lines


class MemoryAnalyzer:
    def __call__(self, memory_info: list[MemoryInfoItem]):
        ...


class RemainingMemoryAnalyzer(MemoryAnalyzer):
    def __init__(self):
        self.remaining_memory = 0
        self.max_memory = 0

    def __call__(self, memory_info: list[MemoryInfoItem]):
        remaining_memory = 0
        for item in memory_info:
            if isinstance(item, AllocMemoryInfoItem):
                remaining_memory += item.size
            elif isinstance(item, FreeMemoryInfoItem):
                remaining_memory -= item.size
            if remaining_memory > self.max_memory:
                self.max_memory = remaining_memory
        self.remaining_memory = remaining_memory

    def summary(self):
        print(f"Remaining memory: {format_bytes(self.remaining_memory)}")
        print(f"Max memory: {format_bytes(self.max_memory)}")


class AllocFreeStatAnalyzer(MemoryAnalyzer):
    def __init__(self):
        self.alloc_size = 0
        self.free_size = 0
        self.alloc_count = 0
        self.free_count = 0

    def __call__(self, memory_info: list[MemoryInfoItem]):
        alloc_size = 0
        free_size = 0
        alloc_count = 0
        free_count = 0
        for item in memory_info:
            if isinstance(item, AllocMemoryInfoItem):
                alloc_size += item.size
                alloc_count += 1
            elif isinstance(item, FreeMemoryInfoItem):
                free_size += item.size
                free_count += 1
        self.alloc_size = alloc_size
        self.free_size = free_size
        self.alloc_count = alloc_count
        self.free_count = free_count

    def summary(self):
        print(f"Allocated memory: {format_bytes(self.alloc_size)}")
        print(f"Allocated count: {self.alloc_count}")
        print(f"Freed memory: {format_bytes(self.free_size)}")
        print(f"Freed count: {self.free_count}")


def analyse_memory_info(memory_info: list[MemoryInfoItem]):
    # remaining memory
    remaining_memory_analyzer = RemainingMemoryAnalyzer()
    remaining_memory_analyzer(memory_info)
    remaining_memory_analyzer.summary()
    # alloc/free stats
    alloc_free_stat_analyzer = AllocFreeStatAnalyzer()
    alloc_free_stat_analyzer(memory_info)
    alloc_free_stat_analyzer.summary()


logs = read_log(LOG_PATH_TEMPLATE.format(sys.argv[1]))
memory_info = extract_memory_info(
    logs,
    lambda line: "START SOT_CALL_0" in line,  # or "EPOCH START" in line,
    lambda line: "EPOCH END" in line or "Traceback" in line or "START SOT_CALL_429" in line,
)
# print(memory_info)
analyse_memory_info(memory_info)
```

Of course, to make the logs comparable, we insert a few "anchor" markers into them; these are the START SOT_CALL_* / EPOCH END strings matched above. With this script, we can easily analyze each subgraph's memory profile; the results appear after the sketch below.
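The original comment elides how the anchors are emitted. A minimal sketch, assuming the training loop prints a unique marker around each SOT subgraph call and stderr is captured into the same log file (the helper below is hypothetical):

```python
# Hypothetical anchor emission: print a unique, greppable marker so that it
# interleaves with the GLOG output redirected to the same file.
import sys

_sot_call_counter = 0

def emit_sot_anchor() -> None:
    global _sot_call_counter
    print(f"START SOT_CALL_{_sot_call_counter}", file=sys.stderr, flush=True)
    _sot_call_counter += 1
```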
```
nyakku@localhost /workspace/PaddleClas develop* ⇣
paddle-py310 ❯ python memory-analyzer.py pir
Remaining memory: 13.08 GB
Max memory: 13.11 GB
Allocated memory: 21.67 GB
Allocated count: 3791
Freed memory: 8.59 GB
Freed count: 1942
nyakku@localhost /workspace/PaddleClas develop* ⇣
paddle-py310 ❯ python memory-analyzer.py pt
Remaining memory: 9.37 GB
Max memory: 9.41 GB
Allocated memory: 21.67 GB
Allocated count: 2939
Freed memory: 12.30 GB
Freed count: 1841
```

For example, this run analyzes the window between the START SOT_CALL_0 and START SOT_CALL_429 anchors configured in the script. A few things stand out:

- Both modes allocate exactly the same total (21.67 GB), so PIR is not allocating extra memory.
- PIR frees far less (8.59 GB vs 12.30 GB), leaving roughly 3.7 GB more resident, which matches the gap in remaining memory (13.08 GB vs 9.37 GB).
From this we can tell that the problem is mainly memory not being released during the SOT forward pass. Narrowing the range further to see which specific subgraph is at fault, it turns out almost every subgraph has the problem... So let's analyze a smaller subgraph instead; the window can be narrowed by re-pointing the anchors, as in the sketch below.
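A minimal sketch of narrowing the analysis window (the anchor indices here are hypothetical; the original comment elides which subgraph was chosen):

```python
# Re-point the start/end predicates at a single subgraph's anchors.
memory_info = extract_memory_info(
    logs,
    lambda line: "START SOT_CALL_42" in line,  # hypothetical subgraph start
    lambda line: "START SOT_CALL_43" in line,  # the next anchor ends the window
)
analyse_memory_info(memory_info)
```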
```
nyakku@localhost /workspace/PaddleClas develop* ⇣
paddle-py310 ❯ python memory-analyzer.py pir
Remaining memory: 6.75 MB
Max memory: 10.12 MB
Allocated memory: 10.12 MB
Allocated count: 3
Freed memory: 3.38 MB
Freed count: 1
nyakku@localhost /workspace/PaddleClas develop* ⇣
paddle-py310 ❯ python memory-analyzer.py pt
Remaining memory: 3.38 MB
Max memory: 10.12 MB
Allocated memory: 10.12 MB
Allocated count: 3
Freed memory: 6.75 MB
Freed count: 2
```

PT clearly does one extra Free. After trimming the case down with a script, the program is very simple; for example, the forward program under PIR is as follows:

*(program IR dump omitted)*
Comparing the two logs, it is easy to spot the Free that PIR is missing; pairing Alloc/Free events by pointer, as sketched below, pinpoints it.
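This pointer-pairing step is not shown in the original comment; a small extension of the analyzer script above would do it:

```python
# Pair Alloc/Free events by pointer; whatever is left when the window ends
# was allocated inside the window but never freed there.
def find_unfreed(memory_info: list[MemoryInfoItem]) -> dict[str, int]:
    live: dict[str, int] = {}
    for item in memory_info:
        if isinstance(item, AllocMemoryInfoItem):
            live[item.ptr] = item.size
        elif isinstance(item, FreeMemoryInfoItem):
            live.pop(item.ptr, None)
    return live  # ptr -> size still held at the end of the window

for ptr, size in find_unfreed(memory_info).items():
    print(f"unfreed: ptr={ptr}, size={format_bytes(size)}")
```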
The missing Free happens on scope exit. Digging in shows that this Tensor is an input, but not exactly the input held in user code: the user's input is strided, so when it is passed into dynamic-to-static it is first made contiguous, which allocates a new block of memory. That block is never held on the user side, so it should be freed as soon as it goes out of scope.

So why doesn't PIR release it? Checking the reference count of this holder shows that the problem is, once again, some Tensors being held in the scope during RunProgramAPI, which can be confirmed by printing the variables the scope holds. The cause is the set of variables PIR tells the executor to skip during eager GC: under PIR, forward inputs and outputs that the backward program never uses are still kept alive in the scope after the forward run finishes.

After the fix, it no longer OOMs, and script analysis shows the memory is fully aligned:

```
nyakku@localhost /workspace/PaddleClas develop* ⇣
paddle-py310 ❯ python memory-analyzer.py pir
Remaining memory: 11.88 GB
Max memory: 11.92 GB
Allocated memory: 21.67 GB
Allocated count: 3791
Freed memory: 9.79 GB
Freed count: 2201
nyakku@localhost /workspace/PaddleClas develop* ⇣
paddle-py310 ❯ python memory-analyzer.py pt
Remaining memory: 11.88 GB
Max memory: 11.92 GB
Allocated memory: 21.67 GB
Allocated count: 2939
Freed memory: 9.79 GB
Freed count: 1349
```

The peak memory is still a bit higher than under PT, though: it barely squeezes under 16 GB, and anything more would fail. However, the earlier source comparison already showed that PIR lacks a piece of logic that cleans forward outputs the backward pass does not need out of the scope. Commenting that logic out under PT reproduces the 16 GB problem (the logs above were captured with it commented out), which confirms this logic as the cause. After porting it to PIR, memory is completely normal: about 13 GB as well.
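Conceptually, the fix is a small set computation. A rough Python sketch (pseudocode only; the real logic lives in the pybind-level forward/backward split and in run_program_op_node.h, and the function below is made up for illustration):

```python
# Sketch: a forward value survives past the forward run only if the backward
# program actually needs it (or it is a user-visible output).
def compute_skip_eager_gc_vars(
    backward_inputs: set[str],   # values the backward program lists as inputs
    no_need_buffers: set[str],   # inputs whose buffers backward never reads
    forward_outputs: set[str],   # user-visible forward outputs
) -> set[str]:
    # Per the comment in run_program_op_node.h:
    #   skip_names = backward_inputs - no_need_buffers + outputs
    # This PR prunes forward inputs/outputs that backward never uses out of
    # backward_inputs beforehand, so they drop out of this skip set and can
    # be freed as soon as the forward run finishes.
    return (backward_inputs - no_need_buffers) | forward_outputs
```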
LGTM
Also noting a TODO about unifying name retrieval. Right now we have a separate set of name-retrieval logic in each of:

- pir.cc (the pybind-level forward/backward split logic, etc.)
- pir_partial_program.py
- run_program_op_node.h

They can be unified, though. For example, RunProgramOP currently traverses the entire Program here, which is very expensive; in fact the names only need to be fetched from the upstream/downstream OPs. This is a potential optimization point for squeezing out performance in the future; @gouzil will follow up on this later.
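A rough illustration of the suggested direction (Python pseudocode; the helper names below are hypothetical and do not match the real PIR API):

```python
# Instead of scanning every op in the whole Program to find which name a
# value was given, ask the value's producer or consumers directly.
def get_value_name_fast(value) -> str | None:
    producer = value.get_defining_op()      # hypothetical upstream accessor
    if producer is not None and producer.has_name_for(value):
        return producer.name_of(value)      # hypothetical
    for consumer_op in value.users():       # hypothetical downstream accessor
        if consumer_op.has_name_for(value):
            return consumer_op.name_of(value)
    return None  # the value was never bound to a name
```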
```cpp
// *backward_program);

// Step 3. get all eager gc vars (skip_names = backward_inputs -
// no_need_buffers)
```
Suggested change:

```diff
-// no_need_buffers)
+// no_need_buffers + outputs)
```
This comment will be fixed in a later PR.
PR Category
Execute Infrastructure
PR Types
Performance
Description
Fix the OOM under SOT+PIR, mainly by cleaning up the forward inputs and forward outputs that the backward pass does not use, so that they are no longer held in the scope after the forward run finishes. See the comment above for the detailed analysis.
PCard-66972