Compare commits

...

5 Commits

Author          | SHA1       | Message                                                                                      | Date
Nripesh Niketan | ebc90f82ca | Merge a5336884cfa952634f355549643dcb781eaec872 into f09f5fa321f5a421704136c0463b1eaca6557712 | 2025-02-18 19:07:12 +08:00
Nripesh Niketan | a5336884cf | fix: Update triton dependency to use the latest version from GitHub                          | 2025-02-03 21:50:02 +00:00
Nripesh Niketan | 73efe7c631 | Memory management update                                                                     | 2025-02-02 10:41:20 +00:00
Nripesh Niketan | b6e3910fd0 | Fix small error                                                                              | 2025-01-30 16:04:00 +00:00
Nripesh Niketan | e75ce46245 | feat: Enhance device compatibility and update PyTorch version                                | 2025-01-30 00:06:55 +00:00
4 changed files with 41 additions and 13 deletions

View File

@@ -3,6 +3,7 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import gc
 import torch
 from safetensors.torch import load_file, save_file
@@ -30,6 +31,12 @@ def main(fp8_path, bf16_path):
     - The function updates the model index file to remove references to scale_inv tensors.
     """
     torch.set_default_dtype(torch.bfloat16)
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
@@ -57,14 +64,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=default_device)
         return loaded_files[file_name][tensor_name]
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=default_device)
         loaded_files[file_name] = current_state_dict
         new_state_dict = {}
@@ -91,7 +98,12 @@ def main(fp8_path, bf16_path):
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.mps.is_available():
+                torch.mps.empty_cache()
+            else:
+                gc.collect()
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
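
The hunks above replace the hard-coded "cuda" device in the FP8-to-BF16 conversion script with a runtime choice and free cached accelerator memory on whichever backend is active. A minimal standalone sketch of that pattern, for reference only (the helper names pick_device and release_cached_memory are illustrative, not part of the diff; torch.backends.mps.is_available() is used here as the more widely supported spelling of the MPS check):

import gc
import torch

def pick_device() -> str:
    # Prefer CUDA, then Apple-silicon MPS, falling back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def release_cached_memory(device: str) -> None:
    # Release cached allocator memory on the active backend
    # (plain garbage collection on CPU, mirroring the diff).
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()
    else:
        gc.collect()

device = pick_device()
weights = torch.ones(1024, 1024, device=device)
del weights
release_cached_memory(device)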

View File

@@ -30,10 +30,11 @@ def sample(logits, temperature: float = 1.0):
 @torch.inference_mode()
 def generate(
     model: Transformer,
+    device: str,
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,11 +52,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
@@ -97,11 +98,20 @@ def main(
         max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
         temperature (float, optional): Temperature for sampling. Defaults to 1.0.
     """
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if world_size > 1:
-        dist.init_process_group("nccl")
+        if torch.cuda.is_available():
+            dist.init_process_group("nccl")
+        else:
+            dist.init_process_group("gloo")
     global print
     if rank != 0:
         print = lambda *_, **__: None
@@ -112,10 +122,10 @@ def main(
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
-    with torch.device("cuda"):
+    with torch.device(default_device):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
+    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0], default_device)
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
     if interactive:
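
Note that generate now declares device as its second positional parameter. A brief usage sketch consistent with that signature, for reference only (it assumes model, tokenizer, and default_device have been prepared as in main() above; the warm-up arguments mirror the call in the diff):

# assumes: model, tokenizer, and default_device prepared as in main() above
prompt_tokens = [tokenizer.encode("DeepSeek")]
completion_tokens = generate(model, default_device, prompt_tokens,
                             max_new_tokens=2, eos_id=-1, temperature=1.0)
print(tokenizer.decode(completion_tokens[0]))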

View File

@@ -796,7 +796,13 @@ class Transformer(nn.Module):
 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
-    torch.set_default_device("cuda")
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
+    torch.set_default_device(default_device)
     torch.manual_seed(0)
     args = ModelArgs()
     x = torch.randint(0, args.vocab_size, (2, 128))
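
For context, torch.set_default_device changes where factory functions allocate, so the torch.randint call above lands on the selected backend without an explicit device= argument. A small, self-contained illustration (the shape and bounds are arbitrary; torch.backends.mps.is_available() is used as the more widely supported spelling of the MPS check):

import torch

default_device = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu")
torch.set_default_device(default_device)

x = torch.randint(0, 100, (2, 128))  # allocated on default_device; no device= needed
print(x.device)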

View File

@@ -1,4 +1,4 @@
-torch==2.4.1
-triton==3.0.0
+torch==2.6.0
+git+https://github.com/NripeshN/triton.git@main#subdirectory=python
 transformers==4.46.3
 safetensors==0.4.5
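
A note on the new Triton line: pip treats git+https://...#subdirectory=python as a VCS requirement, so pip install -r requirements.txt clones the fork at main and builds the package from its python/ subdirectory. Unlike the previously pinned triton==3.0.0 wheel, installing from source generally requires git plus a working C/C++ build toolchain on the machine.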