Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-28 18:09:22 +00:00)

Compare commits: a878eada08...88d6547df2 (4 commits)
Commits:

- 88d6547df2
- 741b06ebca
- a5d2ad229e
- d29a967601
@@ -321,7 +321,7 @@ For comprehensive step-by-step instructions on running DeepSeek-V3 with LMDeploy
 
 ### 6.4 Inference with TRT-LLM (recommended)
 
-[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) now supports the DeepSeek-V3 model, offering precision options such as BF16 and INT4/INT8 weight-only. Support for FP8 is currently in progress and will be released soon. You can access the custom branch of TRTLLM specifically for DeepSeek-V3 support through the following link to experience the new features directly: https://github.com/NVIDIA/TensorRT-LLM/tree/deepseek/examples/deepseek_v3.
+[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) now supports the DeepSeek-V3 model, offering precision options such as BF16 and INT4/INT8 weight-only. Support for FP8 is currently in progress and will be released soon. You can access the custom branch of TRTLLM specifically for DeepSeek-V3 support through the following link to experience the new features directly: https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/deepseek_v3.
 
 ### 6.5 Inference with vLLM (recommended)
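The only change in this hunk is the examples URL: DeepSeek-V3 support moved from the dedicated `deepseek` branch of TensorRT-LLM to `main`. As a hedged illustration using only stock `git` (nothing here is from the repo itself), fetching the examples after this fix looks like:

```python
import subprocess

# Illustrative only: fetch the TensorRT-LLM examples referenced above.
# After this commit the DeepSeek-V3 example lives on "main", so a plain
# shallow clone of the default branch is enough; before it, you would
# have needed `--branch deepseek` (the pre-fix URL in this hunk).
subprocess.run(
    ["git", "clone", "--depth", "1", "--branch", "main",
     "https://github.com/NVIDIA/TensorRT-LLM.git"],
    check=True,
)
# The DeepSeek-V3 example then sits under TensorRT-LLM/examples/deepseek_v3.
```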
@@ -392,7 +392,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
 
 class MLA(nn.Module):
     """
-    Multi-Headed Attention Layer (MLA).
+    Multi-Head Latent Attention (MLA) Layer.
 
     Attributes:
         dim (int): Dimensionality of the input features.
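The rename fixes the docstring, not the code: the layer was never plain multi-head attention but DeepSeek's Multi-Head Latent Attention, which down-projects keys and values into a small shared latent before attention. Below is a minimal sketch of that idea; it is an assumption-laden illustration, not the repository's implementation (it omits the decoupled RoPE path and KV caching, and the names `wkv_a`, `wkv_b`, and `kv_lora_rank` merely mirror the repo's naming conventions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentAttentionSketch(nn.Module):
    """Minimal sketch of the low-rank KV compression behind MLA.

    NOT DeepSeek's implementation: no RoPE decoupling, no KV cache,
    no split rope/nope head dimensions. Purely for orientation.
    """

    def __init__(self, dim: int, n_heads: int, kv_lora_rank: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        # Down-project hidden states into a small shared latent ...
        self.wkv_a = nn.Linear(dim, kv_lora_rank, bias=False)
        # ... then up-project the latent into per-head keys and values.
        self.wkv_b = nn.Linear(kv_lora_rank, 2 * dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, dim = x.shape
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        latent = self.wkv_a(x)                      # (bsz, seqlen, kv_lora_rank)
        k, v = self.wkv_b(latent).chunk(2, dim=-1)  # each (bsz, seqlen, dim)
        k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
        # Standard causal attention over the reconstructed per-head K/V.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (bsz, heads, seq, hd)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(bsz, seqlen, dim)
        return self.wo(out)
```

The point of the latent: only `wkv_a`'s small output needs to be cached per token, rather than full-width keys and values, which is what makes MLA's KV cache cheap.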
@@ -442,7 +442,7 @@ class MLA(nn.Module):
 
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
         """
-        Forward pass for the Multi-Headed Attention Layer (MLA).
+        Forward pass for the Multi-Head Latent Attention (MLA) Layer.
 
         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
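For orientation, here is a hedged sketch of how a caller might assemble the four arguments in that signature during incremental decoding. The shapes and the rotary-frequency construction are assumptions inferred from the docstring and the `apply_rotary_emb` signature above, not code from the repository:

```python
import torch

# Hypothetical shapes; the docstring states x is (batch_size, seq_len, dim).
bsz, seqlen, dim = 2, 16, 2048
x = torch.randn(bsz, seqlen, dim)

# start_pos > 0 would mean appending to an existing KV-cache entry.
start_pos = 0

# freqs_cis: complex rotary factors for positions [start_pos, start_pos + seqlen),
# built the standard RoPE way (head_dim=64 is an assumed value here).
head_dim = 64
freqs = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(start_pos, start_pos + seqlen).float()
freqs_cis = torch.polar(torch.ones(seqlen, head_dim // 2), torch.outer(t, freqs))

# Causal mask over the new tokens; None is typical for one-token decode steps.
mask = torch.full((seqlen, seqlen), float("-inf")).triu_(1)

# out = mla(x, start_pos, freqs_cis, mask)   # mla: an initialized MLA module
```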