Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-05-02 13:36:28 +02:00)

Compare commits: b8f5437b75 ... 545dbd5426 (6 commits)

Commits (SHA1):
- 545dbd5426
- 592fd5daf8
- c9353aba6c
- 6bb22e0c15
- de7df86119
- 70ff909fdc
CITATION.cff (215 lines deleted)

@@ -1,215 +0,0 @@
cff-version: 1.2.0
message: "If you use this work, please cite it using the following metadata."
title: "DeepSeek-V3 Technical Report"
authors:
  - name: "DeepSeek-AI"
  - name: "Aixin Liu"
  - name: "Bei Feng"
  - name: "Bing Xue"
  - name: "Bingxuan Wang"
  - name: "Bochao Wu"
  - name: "Chengda Lu"
  - name: "Chenggang Zhao"
  - name: "Chengqi Deng"
  - name: "Chenyu Zhang"
  - name: "Chong Ruan"
  - name: "Damai Dai"
  - name: "Daya Guo"
  - name: "Dejian Yang"
  - name: "Deli Chen"
  - name: "Dongjie Ji"
  - name: "Erhang Li"
  - name: "Fangyun Lin"
  - name: "Fucong Dai"
  - name: "Fuli Luo"
  - name: "Guangbo Hao"
  - name: "Guanting Chen"
  - name: "Guowei Li"
  - name: "H. Zhang"
  - name: "Han Bao"
  - name: "Hanwei Xu"
  - name: "Haocheng Wang"
  - name: "Haowei Zhang"
  - name: "Honghui Ding"
  - name: "Huajian Xin"
  - name: "Huazuo Gao"
  - name: "Hui Li"
  - name: "Hui Qu"
  - name: "J. L. Cai"
  - name: "Jian Liang"
  - name: "Jianzhong Guo"
  - name: "Jiaqi Ni"
  - name: "Jiashi Li"
  - name: "Jiawei Wang"
  - name: "Jin Chen"
  - name: "Jingchang Chen"
  - name: "Jingyang Yuan"
  - name: "Junjie Qiu"
  - name: "Junlong Li"
  - name: "Junxiao Song"
  - name: "Kai Dong"
  - name: "Kai Hu"
  - name: "Kaige Gao"
  - name: "Kang Guan"
  - name: "Kexin Huang"
  - name: "Kuai Yu"
  - name: "Lean Wang"
  - name: "Lecong Zhang"
  - name: "Lei Xu"
  - name: "Leyi Xia"
  - name: "Liang Zhao"
  - name: "Litong Wang"
  - name: "Liyue Zhang"
  - name: "Meng Li"
  - name: "Miaojun Wang"
  - name: "Mingchuan Zhang"
  - name: "Minghua Zhang"
  - name: "Minghui Tang"
  - name: "Mingming Li"
  - name: "Ning Tian"
  - name: "Panpan Huang"
  - name: "Peiyi Wang"
  - name: "Peng Zhang"
  - name: "Qiancheng Wang"
  - name: "Qihao Zhu"
  - name: "Qinyu Chen"
  - name: "Qiushi Du"
  - name: "R. J. Chen"
  - name: "R. L. Jin"
  - name: "Ruiqi Ge"
  - name: "Ruisong Zhang"
  - name: "Ruizhe Pan"
  - name: "Runji Wang"
  - name: "Runxin Xu"
  - name: "Ruoyu Zhang"
  - name: "Ruyi Chen"
  - name: "S. S. Li"
  - name: "Shanghao Lu"
  - name: "Shangyan Zhou"
  - name: "Shanhuang Chen"
  - name: "Shaoqing Wu"
  - name: "Shengfeng Ye"
  - name: "Shirong Ma"
  - name: "Shiyu Wang"
  - name: "Shuang Zhou"
  - name: "Shuiping Yu"
  - name: "Shunfeng Zhou"
  - name: "Shuting Pan"
  - name: "T. Wang"
  - name: "Tao Yun"
  - name: "Tian Pei"
  - name: "Tianyu Sun"
  - name: "W. L. Xiao"
  - name: "Wangding Zeng"
  - name: "Wanjia Zhao"
  - name: "Wei An"
  - name: "Wen Liu"
  - name: "Wenfeng Liang"
  - name: "Wenjun Gao"
  - name: "Wenqin Yu"
  - name: "Wentao Zhang"
  - name: "X. Q. Li"
  - name: "Xiangyue Jin"
  - name: "Xianzu Wang"
  - name: "Xiao Bi"
  - name: "Xiaodong Liu"
  - name: "Xiaohan Wang"
  - name: "Xiaojin Shen"
  - name: "Xiaokang Chen"
  - name: "Xiaokang Zhang"
  - name: "Xiaosha Chen"
  - name: "Xiaotao Nie"
  - name: "Xiaowen Sun"
  - name: "Xiaoxiang Wang"
  - name: "Xin Cheng"
  - name: "Xin Liu"
  - name: "Xin Xie"
  - name: "Xingchao Liu"
  - name: "Xingkai Yu"
  - name: "Xinnan Song"
  - name: "Xinxia Shan"
  - name: "Xinyi Zhou"
  - name: "Xinyu Yang"
  - name: "Xinyuan Li"
  - name: "Xuecheng Su"
  - name: "Xuheng Lin"
  - name: "Y. K. Li"
  - name: "Y. Q. Wang"
  - name: "Y. X. Wei"
  - name: "Y. X. Zhu"
  - name: "Yang Zhang"
  - name: "Yanhong Xu"
  - name: "Yanping Huang"
  - name: "Yao Li"
  - name: "Yao Zhao"
  - name: "Yaofeng Sun"
  - name: "Yaohui Li"
  - name: "Yaohui Wang"
  - name: "Yi Yu"
  - name: "Yi Zheng"
  - name: "Yichao Zhang"
  - name: "Yifan Shi"
  - name: "Yiliang Xiong"
  - name: "Ying He"
  - name: "Ying Tang"
  - name: "Yishi Piao"
  - name: "Yisong Wang"
  - name: "Yixuan Tan"
  - name: "Yiyang Ma"
  - name: "Yiyuan Liu"
  - name: "Yongqiang Guo"
  - name: "Yu Wu"
  - name: "Yuan Ou"
  - name: "Yuchen Zhu"
  - name: "Yuduan Wang"
  - name: "Yue Gong"
  - name: "Yuheng Zou"
  - name: "Yujia He"
  - name: "Yukun Zha"
  - name: "Yunfan Xiong"
  - name: "Yunxian Ma"
  - name: "Yuting Yan"
  - name: "Yuxiang Luo"
  - name: "Yuxiang You"
  - name: "Yuxuan Liu"
  - name: "Yuyang Zhou"
  - name: "Z. F. Wu"
  - name: "Z. Z. Ren"
  - name: "Zehui Ren"
  - name: "Zhangli Sha"
  - name: "Zhe Fu"
  - name: "Zhean Xu"
  - name: "Zhen Huang"
  - name: "Zhen Zhang"
  - name: "Zhenda Xie"
  - name: "Zhengyan Zhang"
  - name: "Zhewen Hao"
  - name: "Zhibin Gou"
  - name: "Zhicheng Ma"
  - name: "Zhigang Yan"
  - name: "Zhihong Shao"
  - name: "Zhipeng Xu"
  - name: "Zhiyu Wu"
  - name: "Zhongyu Zhang"
  - name: "Zhuoshu Li"
  - name: "Zihui Gu"
  - name: "Zijia Zhu"
  - name: "Zijun Liu"
  - name: "Zilin Li"
  - name: "Ziwei Xie"
  - name: "Ziyang Song"
  - name: "Ziyi Gao"
  - name: "Zizheng Pan"
year: 2024
identifiers:
  - type: doi
    value: 10.48550/arXiv.2412.19437
  - type: arXiv
    value: 2412.19437
url: "https://arxiv.org/abs/2412.19437"
categories:
  - "cs.CL"
repository-code: "https://github.com/deepseek-ai/DeepSeek-V3"
license: "MIT"
abstract: >
  We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks.
README.md

@@ -343,7 +343,7 @@ This code repository is licensed under [the MIT License](LICENSE-CODE). The use
```
@misc{deepseekai2024deepseekv3technicalreport,
title={DeepSeek-V3 Technical Report},
author={DeepSeek-AI and Aixin Liu and Bei Feng and Bing Xue and Bingxuan Wang and Bochao Wu and Chengda Lu and Chenggang Zhao and Chengqi Deng and Chenyu Zhang and Chong Ruan and Damai Dai and Daya Guo and Dejian Yang and Deli Chen and Dongjie Ji and Erhang Li and Fangyun Lin and Fucong Dai and Fuli Luo and Guangbo Hao and Guanting Chen and Guowei Li and H. Zhang and Han Bao and Hanwei Xu and Haocheng Wang and Haowei Zhang and Honghui Ding and Huajian Xin and Huazuo Gao and Hui Li and Hui Qu and J. L. Cai and Jian Liang and Jianzhong Guo and Jiaqi Ni and Jiashi Li and Jiawei Wang and Jin Chen and Jingchang Chen and Jingyang Yuan and Junjie Qiu and Junlong Li and Junxiao Song and Kai Dong and Kai Hu and Kaige Gao and Kang Guan and Kexin Huang and Kuai Yu and Lean Wang and Lecong Zhang and Lei Xu and Leyi Xia and Liang Zhao and Litong Wang and Liyue Zhang and Meng Li and Miaojun Wang and Mingchuan Zhang and Minghua Zhang and Minghui Tang and Mingming Li and Ning Tian and Panpan Huang and Peiyi Wang and Peng Zhang and Qiancheng Wang and Qihao Zhu and Qinyu Chen and Qiushi Du and R. J. Chen and R. L. Jin and Ruiqi Ge and Ruisong Zhang and Ruizhe Pan and Runji Wang and Runxin Xu and Ruoyu Zhang and Ruyi Chen and S. S. Li and Shanghao Lu and Shangyan Zhou and Shanhuang Chen and Shaoqing Wu and Shengfeng Ye and Shengfeng Ye and Shirong Ma and Shiyu Wang and Shuang Zhou and Shuiping Yu and Shunfeng Zhou and Shuting Pan and T. Wang and Tao Yun and Tian Pei and Tianyu Sun and W. L. Xiao and Wangding Zeng and Wanjia Zhao and Wei An and Wen Liu and Wenfeng Liang and Wenjun Gao and Wenqin Yu and Wentao Zhang and X. Q. Li and Xiangyue Jin and Xianzu Wang and Xiao Bi and Xiaodong Liu and Xiaohan Wang and Xiaojin Shen and Xiaokang Chen and Xiaokang Zhang and Xiaosha Chen and Xiaotao Nie and Xiaowen Sun and Xiaoxiang Wang and Xin Cheng and Xin Liu and Xin Xie and Xingchao Liu and Xingkai Yu and Xinnan Song and Xinxia Shan and Xinyi Zhou and Xinyu Yang and Xinyuan Li and Xuecheng Su and Xuheng Lin and Y. K. Li and Y. Q. Wang and Y. X. Wei and Y. X. Zhu and Yang Zhang and Yanhong Xu and Yanhong Xu and Yanping Huang and Yao Li and Yao Zhao and Yaofeng Sun and Yaohui Li and Yaohui Wang and Yi Yu and Yi Zheng and Yichao Zhang and Yifan Shi and Yiliang Xiong and Ying He and Ying Tang and Yishi Piao and Yisong Wang and Yixuan Tan and Yiyang Ma and Yiyuan Liu and Yongqiang Guo and Yu Wu and Yuan Ou and Yuchen Zhu and Yuduan Wang and Yue Gong and Yuheng Zou and Yujia He and Yukun Zha and Yunfan Xiong and Yunxian Ma and Yuting Yan and Yuxiang Luo and Yuxiang You and Yuxuan Liu and Yuyang Zhou and Z. F. Wu and Z. Z. Ren and Zehui Ren and Zhangli Sha and Zhe Fu and Zhean Xu and Zhen Huang and Zhen Zhang and Zhenda Xie and Zhengyan Zhang and Zhewen Hao and Zhibin Gou and Zhicheng Ma and Zhigang Yan and Zhihong Shao and Zhipeng Xu and Zhiyu Wu and Zhongyu Zhang and Zhuoshu Li and Zihui Gu and Zijia Zhu and Zijun Liu and Zilin Li and Ziwei Xie and Ziyang Song and Ziyi Gao and Zizheng Pan},
author={DeepSeek-AI},
year={2024},
eprint={2412.19437},
archivePrefix={arXiv},
convert.py

@@ -2,13 +2,19 @@ import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm, trange

# Constants and type definitions
TensorMapping = Dict[str, Tuple[str, Optional[int]]]
StateDict = Dict[str, torch.Tensor]

mapping = {
# Define mapping as a constant at module level
TENSOR_MAPPING: TensorMapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
@@ -29,68 +35,144 @@ mapping = {
    "scale": ("scale", None),
}


def main(hf_ckpt_path, save_path, n_experts, mp):
def process_tensor_name(name: str) -> str:
    """
    Converts and saves model checkpoint files into a specified format.
    Process tensor name by removing prefixes and replacing common patterns.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.
        name: Original tensor name

    Returns:
        None
        Processed tensor name
    """
    if name.startswith("model."):
        name = name[len("model."):]

    replacements = {
        "self_attn": "attn",
        "mlp": "ffn",
        "weight_scale_inv": "scale",
        "e_score_correction_bias": "bias"
    }

    for old, new in replacements.items():
        name = name.replace(old, new)

    return name

def shard_tensor(param: torch.Tensor, mp_idx: int, mp_count: int, dim: int) -> torch.Tensor:
    """
    Shard a tensor along specified dimension for model parallelism.

    Args:
        param: Input tensor to shard
        mp_idx: Index of current model parallel rank
        mp_count: Total number of model parallel ranks
        dim: Dimension along which to shard

    Returns:
        Sharded tensor slice
    """
    if param.size(dim) % mp_count != 0:
        raise ValueError(f"Tensor size {param.size(dim)} not divisible by mp_count {mp_count}")

    shard_size = param.size(dim) // mp_count
    return param.narrow(dim, mp_idx * shard_size, shard_size).contiguous()
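As a quick sanity check of how the two helpers above compose, here is a minimal sketch. It assumes the refactored module is importable as `convert`, uses only the `embed_tokens` entry that is visible in the mapping above, and fabricates a toy tensor for the sharding step.

```python
import torch

from convert import TENSOR_MAPPING, process_tensor_name, shard_tensor  # assumed module name

# Name handling: strip the "model." prefix, then remap the second-to-last component.
name = process_tensor_name("model.embed_tokens.weight")   # -> "embed_tokens.weight"
key = name.split(".")[-2]                                  # -> "embed_tokens"
new_key, dim = TENSOR_MAPPING[key]                         # -> ("embed", 0)
print(name.replace(key, new_key), dim)                     # "embed.weight" 0

# Sharding: split a toy 8x4 "embedding" row-wise across 2 model-parallel ranks.
param = torch.arange(32, dtype=torch.float32).reshape(8, 4)
shard0 = shard_tensor(param, mp_idx=0, mp_count=2, dim=0)  # rows 0-3
shard1 = shard_tensor(param, mp_idx=1, mp_count=2, dim=0)  # rows 4-7
print(shard0.shape, shard1.shape)                          # torch.Size([4, 4]) twice
```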

def convert_checkpoint(
    hf_ckpt_path: Union[str, Path],
    save_path: Union[str, Path],
    n_experts: int,
    mp: int
) -> None:
    """
    Convert and save model checkpoint files into a specified format.

    Args:
        hf_ckpt_path: Path to input checkpoint directory
        save_path: Path to output directory for converted checkpoints
        n_experts: Total number of experts in model
        mp: Model parallelism factor

    Raises:
        ValueError: If n_experts is not divisible by mp
        FileNotFoundError: If input path doesn't exist or contain safetensors
    """
    if n_experts % mp != 0:
        raise ValueError(f"Number of experts ({n_experts}) must be divisible by model parallel size ({mp})")

    hf_ckpt_path = Path(hf_ckpt_path)
    save_path = Path(save_path)

    if not hf_ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint path {hf_ckpt_path} does not exist")

    safetensor_files = list(hf_ckpt_path.glob("*.safetensors"))
    if not safetensor_files:
        raise FileNotFoundError(f"No safetensor files found in {hf_ckpt_path}")

    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts = [{} for _ in range(mp)]
    state_dicts: List[StateDict] = [{} for _ in range(mp)]

    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
    # Process each checkpoint file
    for file_path in tqdm(safetensor_files, desc="Processing checkpoint files"):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                if "model.layers.61" in name:
                    continue

                param: torch.Tensor = f.get_tensor(name)
                if name.startswith("model."):
                    name = name[len("model."):]
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                name = process_tensor_name(name)

                key = name.split(".")[-2]
                assert key in mapping, f"Key {key} not found in mapping"
                new_key, dim = mapping[key]
                if key not in TENSOR_MAPPING:
                    raise ValueError(f"Unknown tensor key: {key}")

                new_key, dim = TENSOR_MAPPING[key]
                name = name.replace(key, new_key)

                # Distribute tensors across model parallel ranks
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name:
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                        if not (i * n_local_experts <= idx < (i + 1) * n_local_experts):
                            continue
                    elif dim is not None:
                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                        shard_size = param.size(dim) // mp
                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                        new_param = shard_tensor(param, i, mp, dim)
                    state_dicts[i][name] = new_param

    os.makedirs(save_path, exist_ok=True)
    # Save converted checkpoints
    save_path.mkdir(parents=True, exist_ok=True)

    for i in trange(mp):
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
    for i in trange(mp, desc="Saving converted checkpoints"):
        output_file = save_path / f"model{i}-mp{mp}.safetensors"
        save_file(state_dicts[i], str(output_file))

    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        new_file_path = os.path.join(save_path, os.path.basename(file_path))
        shutil.copyfile(file_path, new_file_path)
    # Copy tokenizer files
    for file_path in hf_ckpt_path.glob("*token*"):
        shutil.copyfile(file_path, save_path / file_path.name)

def main():
    """Parse command line arguments and run the conversion."""
    parser = ArgumentParser(description="Convert HuggingFace checkpoints to custom format")
    parser.add_argument("--hf-ckpt-path", type=str, required=True,
                        help="Path to input HuggingFace checkpoint directory")
    parser.add_argument("--save-path", type=str, required=True,
                        help="Path to output directory for converted checkpoints")
    parser.add_argument("--n-experts", type=int, required=True,
                        help="Total number of experts in the model")
    parser.add_argument("--model-parallel", type=int, required=True,
                        help="Model parallelism factor")

    args = parser.parse_args()

    try:
        convert_checkpoint(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
    except Exception as e:
        print(f"Error during conversion: {str(e)}")
        raise

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    args = parser.parse_args()
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
    main()
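For orientation, this is roughly how the refactored entry point is meant to be driven, either through the CLI wrapper above or directly from Python. A minimal sketch with placeholder paths; the expert count and model-parallel factor are illustrative values, not a recommendation.

```python
from convert import convert_checkpoint  # assumed module name

# Placeholder directories; the input must contain HuggingFace *.safetensors shards.
convert_checkpoint(
    hf_ckpt_path="/path/to/hf-checkpoint",
    save_path="/path/to/converted-checkpoint",
    n_experts=256,   # total experts; must be divisible by mp
    mp=16,           # model-parallel ranks; yields model{i}-mp16.safetensors files
)
```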
fp8_cast_bf16.py

@@ -2,6 +2,7 @@ import os
import json
from argparse import ArgumentParser
from glob import glob
from typing import Dict, Any
from tqdm import tqdm

import torch
@@ -9,98 +10,137 @@ from safetensors.torch import load_file, save_file

from kernel import weight_dequant

def main(fp8_path, bf16_path):
    """
    Converts FP8 weights to BF16 and saves the converted weights.

    This function reads FP8 weights from the specified directory, converts them to BF16,
    and saves the converted weights to another specified directory. It also updates the
    model index file to reflect the changes.
class WeightConverter:
    def __init__(self, fp8_path: str, bf16_path: str):
        """
        Initialize the weight converter with input and output paths.

        Args:
            fp8_path (str): The path to the directory containing the FP8 weights and model index file.
            bf16_path (str): The path to the directory where the converted BF16 weights will be saved.

        Raises:
            KeyError: If a required scale_inv tensor is missing for a weight.

        Notes:
            - The function assumes that the FP8 weights are stored in safetensor files.
            - The function caches loaded safetensor files to optimize memory usage.
            - The function updates the model index file to remove references to scale_inv tensors.
            fp8_path (str): Path to the directory containing FP8 weights
            bf16_path (str): Path to save the converted BF16 weights
        """
        torch.set_default_dtype(torch.bfloat16)
        os.makedirs(bf16_path, exist_ok=True)
        model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
        with open(model_index_file, "r") as f:
            model_index = json.load(f)
        weight_map = model_index["weight_map"]
        self.fp8_path = fp8_path
        self.bf16_path = bf16_path
        self.loaded_files: Dict[str, Dict[str, torch.Tensor]] = {}
        self.fp8_weight_names: list = []
        self.weight_map: Dict[str, str] = self._load_model_index()

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get tensor from the correct file
    def get_tensor(tensor_name):
    def _load_model_index(self) -> Dict[str, str]:
        """
        Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.

        Args:
            tensor_name (str): The name of the tensor to retrieve.
        Load the model index file.

        Returns:
            torch.Tensor: The retrieved tensor.
            Dict[str, str]: Weight mapping from the index file
        """
        model_index_file = os.path.join(self.fp8_path, "model.safetensors.index.json")
        with open(model_index_file, "r") as f:
            return json.load(f)["weight_map"]

    def _get_tensor(self, tensor_name: str) -> torch.Tensor:
        """
        Get a tensor from cache or load it from disk.

        Args:
            tensor_name (str): Name of the tensor to retrieve

        Returns:
            torch.Tensor: The requested tensor

        Raises:
            KeyError: If the tensor does not exist in the safetensor file.
            KeyError: If tensor doesn't exist in the safetensor file
        """
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cuda")
        return loaded_files[file_name][tensor_name]
        file_name = self.weight_map[tensor_name]
        if file_name not in self.loaded_files:
            file_path = os.path.join(self.fp8_path, file_name)
            self.loaded_files[file_name] = load_file(file_path, device="cuda")
        return self.loaded_files[file_name][tensor_name]

    def _manage_memory(self):
        """
        Keep only the 2 most recently used files in memory.
        """
        if len(self.loaded_files) > 2:
            oldest_file = next(iter(self.loaded_files))
            del self.loaded_files[oldest_file]
            torch.cuda.empty_cache()

    def _process_weight(self, weight_name: str, weight: torch.Tensor) -> torch.Tensor:
        """
        Process a single weight tensor.

        Args:
            weight_name (str): Name of the weight tensor
            weight (torch.Tensor): The weight tensor to process

        Returns:
            torch.Tensor: Processed weight tensor
        """
        if weight_name.endswith("_scale_inv"):
            return None

        if weight.element_size() == 1:  # FP8 weight
            scale_inv_name = f"{weight_name}_scale_inv"
            try:
                scale_inv = self._get_tensor(scale_inv_name)
                self.fp8_weight_names.append(weight_name)
                return weight_dequant(weight, scale_inv)
            except KeyError:
                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                return weight
        return weight
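The `element_size() == 1` test above is what separates FP8 tensors (which need their companion `<name>_scale_inv` tensor for dequantization) from weights already stored in a wider dtype. A small CPU-only illustration with toy tensors:

```python
import torch

fp8_weight = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)   # 1 byte per element
bf16_weight = torch.zeros(4, 4, dtype=torch.bfloat16)        # 2 bytes per element

print(fp8_weight.element_size())    # 1 -> dequantize using "<name>_scale_inv"
print(bf16_weight.element_size())   # 2 -> copied through unchanged
```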

    def _save_model_index(self):
        """
        Save the updated model index file.
        """
        new_model_index_file = os.path.join(self.bf16_path, "model.safetensors.index.json")
        for weight_name in self.fp8_weight_names:
            scale_inv_name = f"{weight_name}_scale_inv"
            if scale_inv_name in self.weight_map:
                self.weight_map.pop(scale_inv_name)

        with open(new_model_index_file, "w") as f:
            json.dump({"metadata": {}, "weight_map": self.weight_map}, f, indent=2)

    def convert(self):
        """
        Convert FP8 weights to BF16 format.
        """
        torch.set_default_dtype(torch.bfloat16)
        os.makedirs(self.bf16_path, exist_ok=True)

        safetensor_files = sorted(glob(os.path.join(self.fp8_path, "*.safetensors")))

    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()
        for safetensor_file in tqdm(safetensor_files):
            file_name = os.path.basename(safetensor_file)
            current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict
            self.loaded_files[file_name] = current_state_dict

            new_state_dict = {}
            for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                continue
            elif weight.element_size() == 1:  # FP8 weight
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the correct file
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)
                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight
            else:
                new_state_dict[weight_name] = weight
                processed_weight = self._process_weight(weight_name, weight)
                if processed_weight is not None:
                    new_state_dict[weight_name] = processed_weight

        new_safetensor_file = os.path.join(bf16_path, file_name)
            new_safetensor_file = os.path.join(self.bf16_path, file_name)
            save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently used files
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()
            self._manage_memory()

    # Update model index
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:
        scale_inv_name = f"{weight_name}_scale_inv"
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
        self._save_model_index()


def main(fp8_path: str, bf16_path: str):
    """
    Main function to convert FP8 weights to BF16.

    Args:
        fp8_path (str): Input directory containing FP8 weights
        bf16_path (str): Output directory for BF16 weights
    """
    converter = WeightConverter(fp8_path, bf16_path)
    converter.convert()


if __name__ == "__main__":
@@ -109,4 +149,3 @@ if __name__ == "__main__":
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
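A sketch of how the converted class is intended to be used outside the CLI. The directory paths are placeholders, the module name `fp8_cast_bf16` is assumed, and a CUDA device plus the `kernel.weight_dequant` Triton kernel are required at runtime.

```python
from fp8_cast_bf16 import WeightConverter  # assumed module name

converter = WeightConverter(
    fp8_path="/path/to/DeepSeek-V3-FP8",    # must contain model.safetensors.index.json
    bf16_path="/path/to/DeepSeek-V3-BF16",  # created on demand by convert()
)
converter.convert()  # dequantizes FP8 shards to BF16 and rewrites the index file
```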
generate.py

@@ -1,7 +1,8 @@
import os
import json
from argparse import ArgumentParser
from typing import List
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass

import torch
import torch.distributed as dist
@@ -11,13 +12,22 @@ from safetensors.torch import load_model
from model import Transformer, ModelArgs


def sample(logits, temperature: float = 1.0):
@dataclass
class GenerationConfig:
    max_new_tokens: int
    temperature: float
    eos_id: int


class TokenSampler:
    @staticmethod
    def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """
        Samples a token from the logits using temperature scaling.

        Args:
            logits (torch.Tensor): The logits tensor for token predictions.
            temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
            temperature (float): Temperature for scaling logits.

        Returns:
            torch.Tensor: The sampled token.
@@ -27,48 +37,94 @@ def sample(logits, temperature: float = 1.0):
        return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
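The in-place `exponential_` trick in the return statement above is an exponential-race (Gumbel-max style) sampler: dividing the temperature-scaled probabilities by i.i.d. Exp(1) noise and taking the argmax draws an index with probability proportional to `probs`. A small CPU-only check against `torch.multinomial`, using made-up probabilities:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
draws = 200_000

# Exponential-noise trick, as used in sample() above.
noise = torch.empty(draws, 3).exponential_(1)
trick = torch.bincount((probs / noise).argmax(dim=-1), minlength=3) / draws

# Direct categorical sampling for comparison.
direct = torch.bincount(torch.multinomial(probs, draws, replacement=True), minlength=3) / draws

print(trick)   # ~tensor([0.10, 0.20, 0.70])
print(direct)  # ~tensor([0.10, 0.20, 0.70])
```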

class TextGenerator:
    def __init__(self, model: Transformer, tokenizer: Any):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
    model: Transformer,
        self,
        prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
        config: GenerationConfig
    ) -> List[List[int]]:
        """
        Generates new tokens based on the given prompt tokens using the specified model.
        Generates new tokens based on the given prompt tokens.

        Args:
            model (Transformer): The transformer model used for token generation.
            prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
            max_new_tokens (int): The maximum number of new tokens to generate.
            eos_id (int): The end-of-sequence token ID.
            temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
            prompt_tokens: A list of lists containing the prompt tokens for each sequence.
            config: Generation configuration parameters.

        Returns:
            List[List[int]]: A list of lists containing the generated tokens for each sequence.
            List[List[int]]: Generated tokens for each sequence.
        """
        prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
        if max(prompt_lens) > self.model.max_seq_len:
            raise ValueError(f"Prompt length exceeds model maximum sequence length (max_seq_len={self.model.max_seq_len})")

        total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens))
        tokens = self._initialize_tokens(prompt_tokens, total_len)

        completion_tokens = self._generate_tokens(
            tokens, prompt_lens, total_len, config
        )
        return completion_tokens

    def _initialize_tokens(
        self, prompt_tokens: List[List[int]], total_len: int
    ) -> torch.Tensor:
        tokens = torch.full(
            (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
        )
        for i, t in enumerate(prompt_tokens):
            tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        return tokens

    def _generate_tokens(
        self,
        tokens: torch.Tensor,
        prompt_lens: List[int],
        total_len: int,
        config: GenerationConfig
    ) -> List[List[int]]:
        prev_pos = 0
    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
        finished = torch.tensor([False] * len(prompt_lens), device="cuda")
        prompt_mask = tokens != -1

        for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            next_token = self._get_next_token(logits, config.temperature)
            next_token = torch.where(
                prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )

            tokens[:, cur_pos] = next_token
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
            finished |= torch.logical_and(
                ~prompt_mask[:, cur_pos], next_token == config.eos_id
            )
            prev_pos = cur_pos

            if finished.all():
                break

        return self._process_completion_tokens(
            tokens, prompt_lens, config.max_new_tokens, config.eos_id
        )

    def _get_next_token(
        self, logits: torch.Tensor, temperature: float
    ) -> torch.Tensor:
        if temperature > 0:
            return TokenSampler.sample(logits, temperature)
        return logits.argmax(dim=-1)

    def _process_completion_tokens(
        self,
        tokens: torch.Tensor,
        prompt_lens: List[int],
        max_new_tokens: int,
        eos_id: int
    ) -> List[List[int]]:
        completion_tokens = []
        for i, toks in enumerate(tokens.tolist()):
            toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
@@ -78,6 +134,138 @@ def generate(
        return completion_tokens

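The buffer layout used by `_initialize_tokens` and `_generate_tokens` above is easiest to see on a toy example: prompts of different lengths share one tensor padded with -1, and `prompt_mask` decides, position by position, whether to keep the prompt token or accept the freshly sampled one. A CPU-only sketch with made-up token ids:

```python
import torch

prompt_tokens = [[5, 6, 7], [9]]          # two prompts, lengths 3 and 1
total_len = 5
tokens = torch.full((2, total_len), -1, dtype=torch.long)
for i, t in enumerate(prompt_tokens):
    tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)

prompt_mask = tokens != -1                # True where a prompt token lives
cur_pos = 1                               # decoding starts at min(prompt_lens)
next_token = torch.tensor([42, 43])       # pretend the model sampled these
tokens[:, cur_pos] = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)

print(tokens)
# tensor([[ 5,  6,  7, -1, -1],    <- prompt token 6 is kept
#         [ 9, 43, -1, -1, -1]])   <- sampled token 43 is written
```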

class DistributedEnvironment:
    def __init__(self):
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.rank = int(os.getenv("RANK", "0"))
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    def setup(self):
        if self.world_size > 1:
            dist.init_process_group("nccl")
        if self.rank != 0:
            global print
            print = lambda *_, **__: None
        torch.cuda.set_device(self.local_rank)

    def cleanup(self):
        if self.world_size > 1:
            dist.destroy_process_group()

    def broadcast_prompt(self, prompt: Optional[str] = None) -> str:
        if self.world_size == 1:
            return input(">>> ")
        elif self.rank == 0:
            prompt = input(">>> ")
            objects = [prompt]
            dist.broadcast_object_list(objects, 0)
            return prompt
        else:
            objects = [None]
            dist.broadcast_object_list(objects, 0)
            return objects[0]


class ChatSession:
    def __init__(
        self,
        generator: TextGenerator,
        config: GenerationConfig,
        dist_env: DistributedEnvironment
    ):
        self.generator = generator
        self.config = config
        self.dist_env = dist_env
        self.messages = []

    def run_interactive(self):
        while True:
            prompt = self.dist_env.broadcast_prompt()
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                self.messages.clear()
                continue

            completion = self._process_message(prompt)
            print(completion)
            self.messages.append({"role": "assistant", "content": completion})

    def run_batch(self, input_file: str):
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]

        if len(prompts) > self.generator.model.args.max_batch_size:
            raise ValueError(f"Number of prompts exceeds maximum batch size ({self.generator.model.args.max_batch_size})")

        completions = self._process_batch(prompts)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()

    def _process_message(self, prompt: str) -> str:
        self.messages.append({"role": "user", "content": prompt})
        prompt_tokens = self.generator.tokenizer.apply_chat_template(
            self.messages, add_generation_prompt=True
        )
        completion_tokens = self.generator.generate(
            [prompt_tokens], self.config
        )
        return self.generator.tokenizer.decode(
            completion_tokens[0], skip_special_tokens=True
        )

    def _process_batch(self, prompts: List[str]) -> List[str]:
        prompt_tokens = [
            self.generator.tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True
            )
            for prompt in prompts
        ]
        completion_tokens = self.generator.generate(
            prompt_tokens, self.config
        )
        return self.generator.tokenizer.batch_decode(
            completion_tokens, skip_special_tokens=True
        )


def initialize_model(
    ckpt_path: str, config_path: str, dist_env: DistributedEnvironment
) -> Tuple[Transformer, Any]:
    """Initialize the model and tokenizer."""
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)

    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(args)

    with torch.device("cuda"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    # Warmup
    tokenizer.decode(
        TextGenerator(model, tokenizer).generate(
            [tokenizer.encode("DeepSeek")],
            GenerationConfig(max_new_tokens=2, temperature=1.0, eos_id=-1)
        )[0]
    )

    load_model(
        model,
        os.path.join(
            ckpt_path,
            f"model{dist_env.rank}-mp{dist_env.world_size}.safetensors"
        )
    )
    return model, tokenizer


def main(
    ckpt_path: str,
    config: str,
@@ -86,94 +274,29 @@ def main(
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    """
    Main function to load the model and perform interactive or batch text generation.
    dist_env = DistributedEnvironment()
    dist_env.setup()

    Args:
        ckpt_path (str): Path to the model checkpoint directory.
        config (str): Path to the model configuration file.
        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
    """
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    rank = int(os.getenv("RANK", "0"))
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    if world_size > 1:
        dist.init_process_group("nccl")
    global print
    if rank != 0:
        print = lambda *_, **__: None
    torch.cuda.set_device(local_rank)
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)
    with open(config) as f:
        args = ModelArgs(**json.load(f))
    print(args)
    with torch.device("cuda"):
        model = Transformer(args)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
    model, tokenizer = initialize_model(ckpt_path, config, dist_env)
    generator = TextGenerator(model, tokenizer)
    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_token_id
    )

    session = ChatSession(generator, gen_config, dist_env)

    if interactive:
        messages = []
        while True:
            if world_size == 1:
                prompt = input(">>> ")
            elif rank == 0:
                prompt = input(">>> ")
                objects = [prompt]
                dist.broadcast_object_list(objects, 0)
        session.run_interactive()
            else:
                objects = [None]
                dist.broadcast_object_list(objects, 0)
                prompt = objects[0]
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                messages.clear()
                continue
            messages.append({"role": "user", "content": prompt})
            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})
    else:
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()
        session.run_batch(input_file)

    if world_size > 1:
        dist.destroy_process_group()
    dist_env.cleanup()


if __name__ == "__main__":
    """
    Command-line interface for distributed text generation.

    Arguments:
        --ckpt-path (str): Path to the model checkpoint directory.
        --config (str): Path to the model configuration file.
        --input-file (str, optional): File containing prompts for batch processing.
        --interactive (bool, optional): Enable interactive mode for generating text.
        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.

    Raises:
        AssertionError: If neither input-file nor interactive mode is specified.
    """
    parser = ArgumentParser()
    parser = ArgumentParser(description="Distributed text generation system")
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
@@ -181,5 +304,15 @@ if __name__ == "__main__":
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()
    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)

    if not args.input_file and not args.interactive:
        raise ValueError("Either input-file or interactive mode must be specified")

    main(
        args.ckpt_path,
        args.config,
        args.input_file,
        args.interactive,
        args.max_new_tokens,
        args.temperature
    )
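For completeness, a sketch of driving the rewritten entry point directly from Python in a single-process run (WORLD_SIZE defaults to 1, so no process group is created). The checkpoint and config paths are placeholders, and the positional argument order mirrors the argparse call above. Multi-GPU runs would normally be launched with `torchrun`, which sets the WORLD_SIZE/RANK/LOCAL_RANK variables that `DistributedEnvironment` reads.

```python
from generate import main  # assumed module name

# ckpt_path, config, input_file, interactive, max_new_tokens, temperature
main(
    "/path/to/converted-checkpoint",   # placeholder; expects model{rank}-mp{world_size}.safetensors
    "/path/to/config.json",            # placeholder ModelArgs config
    "",                                # no batch input file ...
    True,                              # ... so run interactively
    200,
    0.2,
)
```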
kernel.py

@@ -1,4 +1,5 @@
from typing import Tuple
from dataclasses import dataclass

import torch
import triton
@@ -6,19 +7,29 @@ import triton.language as tl
from triton import Config


@dataclass
class BlockConfig:
    """Configuration for block sizes in tensor operations."""
    size: int = 128
    size_m: int = 64
    size_n: int = 64
    size_k: int = 128


class QuantizationKernels:
    """Collection of Triton kernels for quantization operations."""

    @staticmethod
    @triton.jit
    def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
        """
        Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
        Quantizes activation values using block-wise scaling.

        Args:
            x_ptr (triton.Pointer): Pointer to the input tensor.
            y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
            s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
            BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.

        Returns:
            None
            x_ptr: Input tensor pointer
            y_ptr: Output quantized tensor pointer
            s_ptr: Output scaling factors pointer
            BLOCK_SIZE: Size of processing block
        """
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -29,44 +40,19 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
        tl.store(y_ptr + offs, y)
        tl.store(s_ptr + pid, s)


def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn`.
            - A tensor of scaling factors with dtype `torch.float32`.
    """
    assert x.is_contiguous(), 'Input tensor must be contiguous'
    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
    return y, s
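A minimal usage sketch for the host-side wrapper above (the original module-level `act_quant` and the `TensorOps.act_quant` introduced later in this diff behave the same way). It needs a CUDA device with Triton and a PyTorch build that has `torch.float8_e4m3fn`; the input shape is arbitrary as long as the last dimension is a multiple of `block_size`.

```python
import torch
from kernel import act_quant  # assumed module name

x = torch.randn(4, 1024, dtype=torch.bfloat16, device="cuda").contiguous()
y, s = act_quant(x, block_size=128)

print(y.dtype, tuple(y.shape))   # torch.float8_e4m3fn (4, 1024)
print(s.dtype, tuple(s.shape))   # torch.float32 (4, 8) -- one scale per 128-wide block
```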

    @staticmethod
    @triton.jit
    def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
        """
        Dequantizes weights using the provided scaling factors and stores the result.
        Dequantizes weights using block-wise scaling.

        Args:
            x_ptr (tl.pointer): Pointer to the quantized weights.
            s_ptr (tl.pointer): Pointer to the scaling factors.
            y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
            M (int): Number of rows in the weight matrix.
            N (int): Number of columns in the weight matrix.
            BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

        Returns:
            None
            x_ptr: Quantized weights pointer
            s_ptr: Scaling factors pointer
            y_ptr: Output dequantized tensor pointer
            M: Number of rows
            N: Number of columns
            BLOCK_SIZE: Size of processing block
        """
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
@@ -81,84 +67,80 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
        tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.
class MatrixMultKernels:
    """Collection of Triton kernels for matrix multiplication operations."""

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M, N).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y


fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
    @staticmethod
    def get_configs():
        """Generate configurations for FP8 GEMM autotuning."""
        return [
            Config({
                'BLOCK_SIZE_M': block_m,
                'BLOCK_SIZE_N': block_n,
                'BLOCK_SIZE_K': 128
            }, num_stages=num_stages, num_warps=8)
            for block_m in [16, 32, 64]
            for block_n in [32, 64, 128]
            for num_stages in [3, 4, 5, 6]
        ]

@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
    @staticmethod
    @triton.autotune(configs=get_configs(), key=['N', 'K'])
    @triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
    def fp8_gemm_kernel(
        a_ptr, b_ptr, c_ptr,
        a_s_ptr, b_s_ptr,
        M, N: tl.constexpr, K: tl.constexpr,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr):
        BLOCK_SIZE_K: tl.constexpr
    ):
        """
        Performs a matrix multiplication operation on FP8 matrices with scaling factors.
        Performs FP8 matrix multiplication with scaling factors.

        Args:
            a_ptr (tl.tensor): Pointer to the first input matrix A.
            b_ptr (tl.tensor): Pointer to the second input matrix B.
            c_ptr (tl.tensor): Pointer to the output matrix C.
            a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
            b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
            M (int): Number of rows in matrix A and C.
            N (tl.constexpr): Number of columns in matrix B and C.
            K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
            BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
            BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
            BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.

        Returns:
            None
            a_ptr: First input matrix pointer
            b_ptr: Second input matrix pointer
            c_ptr: Output matrix pointer
            a_s_ptr: First matrix scaling factors pointer
            b_s_ptr: Second matrix scaling factors pointer
            M: First matrix rows
            N: Second matrix columns
            K: Inner dimension
            BLOCK_SIZE_M/N/K: Block sizes for tiling
        """
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        k = tl.cdiv(K, BLOCK_SIZE_K)

        # Calculate offsets
        offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)

        # Initialize pointers
        a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
        b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
        a_s_ptrs = a_s_ptr + offs_m * k
        b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        # Main computation loop
        for i in range(k):
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
            a_s = tl.load(a_s_ptrs)
            b_s = tl.load(b_s_ptrs)
            accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

            # Update pointers
            a_ptrs += BLOCK_SIZE_K
            b_ptrs += BLOCK_SIZE_K
            a_s_ptrs += 1
            b_s_ptrs += 1

        # Store results
        c = accumulator.to(c_ptr.dtype.element_ty)
        offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -167,25 +149,86 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
        tl.store(c_ptrs, c, mask=mask)


def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
class TensorOps:
    """High-level interface for tensor operations."""

    @staticmethod
    def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
        """
    Perform a matrix multiplication using FP8 precision.
        Quantize activations using block-wise scaling.

        Args:
        a (torch.Tensor): The first input matrix, must be contiguous.
        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
        b (torch.Tensor): The second input matrix, must be contiguous.
        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
            x: Input tensor
            block_size: Block size for quantization

        Returns:
        torch.Tensor: The result of the matrix multiplication.
            Tuple of quantized tensor and scaling factors
        """
    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
        assert x.is_contiguous()
        assert x.size(-1) % block_size == 0

        y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
        s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

        grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
        QuantizationKernels.act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)

        return y, s

    @staticmethod
    def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
        """
        Dequantize weights using block-wise scaling.

        Args:
            x: Quantized weight tensor
            s: Scaling factors tensor
            block_size: Block size for dequantization

        Returns:
            Dequantized tensor
        """
        assert x.is_contiguous() and s.is_contiguous()
        assert x.dim() == 2 and s.dim() == 2

        M, N = x.size()
        y = torch.empty_like(x, dtype=torch.get_default_dtype())

        grid = lambda meta: (
            triton.cdiv(M, meta['BLOCK_SIZE']),
            triton.cdiv(N, meta['BLOCK_SIZE'])
        )
        QuantizationKernels.weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)

        return y

    @staticmethod
    def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
        """
        Perform FP8 matrix multiplication.

        Args:
            a: First input matrix
            a_s: First matrix scaling factors
            b: Second input matrix
            b_s: Second matrix scaling factors

        Returns:
            Result matrix
        """
        assert a.is_contiguous() and b.is_contiguous()
        assert a_s.is_contiguous() and b_s.is_contiguous()

        K = a.size(-1)
        M = a.numel() // K
        N = b.size(0)

        c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)

        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']),
            triton.cdiv(N, META['BLOCK_SIZE_N'])
        )
        MatrixMultKernels.fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)

        return c