lit-llama代码解析

2024-09-04 05:52
文章标签 代码 解析 llama lit

本文主要是介绍lit-llama代码解析,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

https://github.com/Lightning-AI/lit-llama/blob/main/README.md

下载的时候会报错误,因为网不行,一种方法就是多次尝试,另一种方法是终端连上代理下载

pycharm连接hugging face等网站_hugging face怎么连接-CSDN博客

根据指引下载权重

下载完权重运行:python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/open-llama/7B --model_size 7B

转化为.pth文件 

跟着readme/howto教程量化或进行其他操作

warning

UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ..\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:455.)
  y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

https://github.com/comfyanonymous/ComfyUI/issues/3202

分析generate

# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.import sys
import time
import warnings
from pathlib import Path
from typing import Optionalimport lightning as L
import torch
print(torch.cuda.is_available())
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))from lit_llama import LLaMA, Tokenizer
from lit_llama.utils import lazy_load, llama_model_lookup, quantization@torch.no_grad()
def generate(model: LLaMA,idx: torch.Tensor,max_new_tokens: int,*,max_seq_length: Optional[int] = None,temperature: float = 1.0,top_k: Optional[int] = None,eos_id: Optional[int] = None,
) -> torch.Tensor:"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.The implementation of this function is modified from A. Karpathy's nanoGPT.Args:model: The model to use.idx: Tensor of shape (T) with indices of the prompt sequence.max_new_tokens: The number of new tokens to generate.max_seq_length: The maximum sequence length allowed.temperature: Scales the predicted logits by 1 / temperaturetop_k: If specified, only sample among the tokens with the k highest probabilitieseos_id: If specified, stop generating any more token once the <eos> token is triggered"""# create an empty tensor of the expected final shape and fill in the current tokensT = idx.size(0)T_new = T + max_new_tokensif max_seq_length is None:max_seq_length = min(T_new, model.config.block_size)device, dtype = idx.device, idx.dtype# create an empty tensor of the expected final shape and fill in the current tokensempty = torch.empty(T_new, dtype=dtype, device=device)empty[:T] = idxidx = emptyinput_pos = torch.arange(0, T, device=device)if idx.device.type == "xla":import torch_xla.core.xla_model as xmxm.mark_step()# generate max_new_tokens tokensfor _ in range(max_new_tokens):x = idx.index_select(0, input_pos).view(1, -1)# forwardlogits = model(x, max_seq_length, input_pos)logits = logits[0, -1] / temperature# optionally crop the logits to only the top k optionsif top_k is not None:v, _ = torch.topk(logits, min(top_k, logits.size(-1)))logits = torch.where(logits < v[[-1]], -float("Inf"), logits)probs = torch.nn.functional.softmax(logits, dim=-1)idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)# advanceinput_pos = input_pos[-1:] + 1if idx.device.type == "xla":xm.mark_step()# concatenate the new generationidx = idx.index_copy(0, input_pos, idx_next)# if <eos> token is triggered, return the output (stop generation)if idx_next == eos_id:return idx[:input_pos]  # include the EOS tokenreturn idxdef main(prompt: str = "Hello, my name is",*,num_samples: int = 1,max_new_tokens: int = 50,top_k: int = 200,temperature: float = 0.8,checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),quantize: Optional[str] = None,
) -> None:"""Generates text samples based on a pre-trained LLaMA model and tokenizer.Args:prompt: The prompt string to use for generating the samples.num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)max_new_tokens: The number of generation steps to take.(number of generate tokens )top_k: The number of top most probable tokens to consider in the sampling process.temperature: A value controlling the randomness of the sampling process. Higher values result in more randomsamples.checkpoint_path: The checkpoint path to load.tokenizer_path: The tokenizer path to load.quantize: Whether to quantize the model and using which method:``"llm.int8"``: LLM.int8() mode,``"gptq.int4"``: GPTQ 4-bit mode."""assert checkpoint_path.is_file(), checkpoint_pathassert tokenizer_path.is_file(), tokenizer_pathprecision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"fabric = L.Fabric(devices=1, precision=precision)print("Loading model ...", file=sys.stderr)t0 = time.time()with lazy_load(checkpoint_path) as checkpoint:name = llama_model_lookup(checkpoint)with fabric.init_module(empty_init=True), quantization(mode=quantize):model = LLaMA.from_name(name)model.load_state_dict(checkpoint)print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)model.eval()model = fabric.setup(model)tokenizer = Tokenizer(tokenizer_path)encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)prompt_length = encoded.size(0)L.seed_everything(1234)for i in range(num_samples):t0 = time.perf_counter()y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)t = time.perf_counter() - t0model.reset_cache()print(tokenizer.decode(y))tokens_generated = y.size(0) - prompt_lengthprint(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)if fabric.device.type == "cuda":print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)if __name__ == "__main__":from jsonargparse import CLItorch.set_float32_matmul_precision("high")warnings.filterwarnings(# Triggered internally at ../aten/src/ATen/EmptyTensor.cpp:31"ignore", message="ComplexHalf support is experimental and many operators don't support it yet")warnings.filterwarnings(# Triggered in bitsandbytes/autograd/_functions.py:298"ignore", message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",)CLI(main)

main()

"""Generates text samples based on a pre-trained LLaMA model and tokenizer.Args:prompt: The prompt string to use for generating the samples.num_samples: The number of text samples to generate.(Its effect is overridden by `max_new_tokens`, if also set.)max_new_tokens: The number of generation steps to take.(number of generate tokens )top_k: The number of top most probable tokens to consider in the sampling process.temperature: A value controlling the randomness of the sampling process. Higher values result in more random samples.checkpoint_path: The checkpoint path to load.tokenizer_path: The tokenizer path to load.quantize: Whether to quantize the model and using which method:``"llm.int8"``: LLM.int8() mode,``"gptq.int4"``: GPTQ 4-bit mode.
"""


https://zhuanlan.zhihu.com/p/657886517

Fabric()

r"""Fabric accelerates your PyTorch training or inference code with minimal changes required.Fabric 加速你的 PyTorch 训练或推理代码,所需的更改最小。- Automatic placement of models and data onto the device.- 自动将模型和数据放置到设备上。- Automatic support for mixed and double precision (smaller memory footprint).- 自动支持混合精度和双精度(较小的内存占用)。- Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies(data-parallel training, sharded training, etc.).- 在硬件(CPU、GPU、TPU)和分布式训练策略(数据并行训练、分片训练等)之间无缝切换。- Automated spawning of processes, no launch utilities required.- 自动生成进程,无需启动工具。- Multi-node support.- 支持多节点训练。Args:accelerator: The hardware to run on. Possible choices are:``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.accelerator: 运行的硬件。可能的选择有:``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``。strategy: Strategy for how to run across multiple devices. Possible choices are:``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.strategy: 跨多个设备运行的策略。可能的选择有:``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``。devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.The value applies per node.devices: 训练时使用的设备数量(``int``),或要训练的 GPU(``list`` 或 ``str``),或 ``"auto"``。该值适用于每个节点。num_nodes: Number of GPU nodes for distributed training.num_nodes: 用于分布式训练的 GPU 节点数量。precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),or bfloat16 precision AMP (``"bf16-mixed"``).precision: 双精度(``"64"``),全精度(``"32"``),半精度 AMP(``"16-mixed"``),或 bfloat16 精度 AMP(``"bf16-mixed"``)。plugins: One or several custom pluginsplugins: 一个或多个自定义插件callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods thatcan be invoked through :meth:`~lightning.fabric.fabric.Fabric.call` by the user.callbacks: 单个回调或回调列表。回调可以包含任何用户可以通过 :meth:`~lightning.fabric.fabric.Fabric.call` 调用的任意方法。loggers: A single logger or a list of loggers. See :meth:`~lightning.fabric.fabric.Fabric.log` for moreinformation.loggers: 单个日志记录器或日志记录器列表。有关更多信息,请参见 :meth:`~lightning.fabric.fabric.Fabric.log`。
"""

lazy_load()

定义了一个名为 lazy_load 的类,它用于延迟加载和管理一个 PyTorch 文件:lazy_load 类
__init__ 方法
python
def __init__(self, fn):self.zf = torch._C.PyTorchFileReader(str(fn))with BytesIO(self.zf.get_record("data.pkl")) as pkl:mup = LazyLoadingUnpickler(pkl, self)self.sd = mup.load()
self.zf = torch._C.PyTorchFileReader(str(fn)):创建一个 PyTorchFileReader 实例,用于读取指定文件 (fn) 的内容。这个文件是 PyTorch 保存的文件,通常是 .pt 或 .pth 文件。
str(fn) 确保文件路径被正确转换为字符串。
with BytesIO(self.zf.get_record("data.pkl")) as pkl::从 PyTorchFileReader 中提取名为 "data.pkl" 的记录,并用 BytesIO 创建一个内存中的字节流对象 pkl。
BytesIO 用于在内存中读写二进制数据。
mup = LazyLoadingUnpickler(pkl, self):创建一个 LazyLoadingUnpickler 实例 mup,它负责处理 pkl 中的数据。这里假设 LazyLoadingUnpickler 是自定义的类,用于延迟加载和解码 Pickle 数据。
self.sd = mup.load():调用 mup.load() 方法来加载数据,并将结果存储在 self.sd 属性中。这个过程可能会涉及到数据的反序列化。
__enter__ 方法
python
def __enter__(self):return self.sd
这个方法允许 lazy_load 实例在上下文管理器(with 语句)中使用。__enter__ 返回 self.sd,使得 with 语句块内部可以直接访问加载的数据。
__exit__ 方法
python
def __exit__(self, exc_type, exc_val, exc_tb):del self.zf  # I don't think there is a way to force closing...self.zf = None
这个方法用于处理退出上下文管理器时的清理工作。
del self.zf: 尝试删除 self.zf 对象。由于 self.zf 是一个 PyTorchFileReader 实例,删除对象的作用是释放相关资源。
self.zf = None: 另一种释放资源的方式,将 self.zf 设置为 None,以确保它不再被引用。
总结
这个类的设计用于懒加载 PyTorch 文件中的数据。它实现了上下文管理协议,使得数据可以在 with 语句块中方便地访问,并且在退出时尝试释放相关资源。

LazyLoadingUnpickler()

定义了一个 LazyLoadingUnpickler 类,继承自 pickle.Unpickler,用于处理 PyTorch 对象的延迟加载。以下是对每个部分的详细解释:__init__ 方法
python
def __init__(self, file, zipfile_context):super().__init__(file)self.zipfile_context = zipfile_context
file: 传入的文件对象(通常是一个字节流),用于反序列化。
zipfile_context: 额外的上下文信息,用于延迟加载的实现。这通常是一个包含 PyTorch 文件读取信息的对象。
super().__init__(file): 调用父类 pickle.Unpickler 的初始化方法,传入文件对象。
self.zipfile_context: 保存额外的上下文信息,用于稍后延迟加载。
find_class 方法
python
def find_class(self, module, name):res = super().find_class(module, name)if module == "torch._utils" and name == "_rebuild_tensor_v2":return functools.partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)elif module == "torch._tensor" and name == "_rebuild_from_type_v2":return functools.partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)elif module == "torch._utils" and name == "_rebuild_parameter":return functools.partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)return res
super().find_class(module, name): 调用父类的 find_class 方法,查找并返回指定模块和类名的类。
模块和类名检查:
当模块是 "torch._utils" 且类名是 "_rebuild_tensor_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_tensor_v2 方法,并传入 archiveinfo=self。
当模块是 "torch._tensor" 且类名是 "_rebuild_from_type_v2" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_from_type_v2 方法。
当模块是 "torch._utils" 且类名是 "_rebuild_parameter" 时,返回一个 functools.partial 对象,部分应用 NotYetLoadedTensor.rebuild_parameter 方法。
functools.partial: 允许创建一个新的函数,其中一些参数已经预先指定,这里是为了在实际调用时延迟具体的处理逻辑。
返回值: 如果模块和类名不匹配,返回父类的结果。
persistent_load 方法
python
def persistent_load(self, pid):name, cls, fn, device, size = pidwith warnings.catch_warnings():warnings.simplefilter("ignore")s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta")s.archiveinfo = pidreturn s
pid: 一个包含多个信息的元组 (name, cls, fn, device, size),用于标识持久化数据的加载信息。
warnings.catch_warnings(): 捕获并管理警告信息。
warnings.simplefilter("ignore"): 忽略警告信息,以便在加载过程中不会产生干扰。
torch.storage.TypedStorage(dtype=cls().dtype, device="meta"): 创建一个 TypedStorage 对象,指定数据类型和设备。device="meta" 表示数据存储在元数据设备中,实际上并没有分配真实的存储空间。
s.archiveinfo = pid: 将持久化标识信息存储到 TypedStorage 对象中。
返回值: 返回创建的 TypedStorage 对象。
总结
LazyLoadingUnpickler 主要用于在反序列化 PyTorch 对象时实现延迟加载。这种方法使得在加载大数据文件时可以更高效地管理内存和计算资源。find_class 方法用于动态创建用于延迟加载的对象,而 persistent_load 方法则用于处理持久化存储数据的加载。

llama_model_lookup() 

init_module() 

def init_module(self, empty_init: Optional[bool] = None) -> ContextManager:"""Instantiate the model and its parameters under this context manager to reduce peak memory usage.
在这个上下文管理器下实例化模型及其参数,以减少峰值内存使用。The parameters get created on the device and with the right data type right away without wasting memory being allocated unnecessarily.
参数会直接在设备上创建,并且使用正确的数据类型,从而避免了不必要的内存分配浪费。Args:
参数:empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。Set this to ``True`` if you are loading a checkpoint into a large model.
如果你正在将检查点加载到大型模型中,将其设置为``True``。"""self._validate_launched()return self._strategy.module_init_context(empty_init=empty_init)
module_init_context()  
 def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:"""A context manager wrapping the model instantiation.
一个包装模型实例化的上下文管理器。Here, the strategy can control how the parameters of the model get created (device, dtype) and or apply other patches to the model.
在这里,策略可以控制模型参数的创建方式(设备、数据类型)或对模型应用其他修补。Args:
参数:empty_init: Whether to initialize the model with empty weights (uninitialized memory).
empty_init: 是否使用空权重(未初始化的内存)来初始化模型。If ``None``, the strategy will decide. Some strategies may not support all options.
如果``None``,则策略将决定。一些策略可能不支持所有选项。"""precision_module_ctx = self.precision.module_init_context()stack = ExitStack()stack.enter_context(self.root_device)stack.enter_context(_EmptyInit(enabled=bool(empty_init)))stack.enter_context(precision_module_ctx)return stack

quantization() 

@contextmanager
def quantization(mode: str = None):quantized_linear_cls = Noneif mode == 'llm.int8':from .quantization import Linear8bitLtquantized_linear_cls = Linear8bitLtelif mode == 'gptq.int4':from .quantization import ColBlockQuantizedLinearquantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)elif mode == 'gptq.int8':from .quantization import ColBlockQuantizedLinearquantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)elif mode is not None:raise ValueError(f"Unknown quantization mode: {mode}")enabled = mode is not Nonetorch_linear_cls = torch.nn.Linearif enabled:torch.nn.Linear = quantized_linear_clsyieldif enabled:torch.nn.Linear = torch_linear_cls

model 

setup() 

    def setup(self,module: nn.Module,*optimizers: Optimizer,move_to_device: bool = True,_reapply_compile: bool = True,) -> Any:  # no specific return because the way we want our API to look does not play well with mypyr"""Set up a model and its optimizers for accelerated training.
为加速训练设置模型及其优化器。Args:
参数:module: A :class:`torch.nn.Module` to set up
module: 要设置的 :class:`torch.nn.Module`*optimizers: The optimizer(s) to set up (no optimizers is also possible)
*optimizers: 要设置的优化器(也可以不设置优化器)move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
move_to_device: 如果设置为``True``(默认值),则将模型移动到正确的设备。设置为``False`` and alternatively use :meth:`to_device` manually.并可以手动使用 :meth:`to_device`。_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
_reapply_compile: 如果``True``(默认值),且模型之前已``torch.compile``,则corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the相应的 :class:`~torch._dynamo.OptimizedModule` 包装器将被移除,并在模型被策略设置好后重新应用same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,相同的设置(例如,模型被 DDP、FSDP 等包装之后)。如果编译 DDP/FSDP 造成问题,设置为``False``。Returns:
返回:The tuple containing wrapped module and the optimizers, in the same order they were passed in.
一个包含包装的模块和优化器的元组,顺序与传入时相同。"""

tokenizer

 

    def encode(self,string: str,bos: bool = True,eos: bool = False,max_length: int = -1,pad: bool = False,device: Optional[torch.device] = None) -> torch.Tensor:tokens = self.processor.encode(string)if bos:tokens = [self.bos_id] + tokensif eos:tokens = tokens + [self.eos_id]if max_length > 0:tokens = tokens[:max_length]if pad and len(tokens) < max_length:tokens += [self.pad_id] * (max_length - len(tokens))return torch.tensor(tokens, dtype=torch.int, device=device)def decode(self, tokens: torch.Tensor) -> str:return self.processor.decode(tokens.tolist())

 generate()

@torch.no_grad()
def generate(model: LLaMA,idx: torch.Tensor,max_new_tokens: int,*,max_seq_length: Optional[int] = None,temperature: float = 1.0,top_k: Optional[int] = None,eos_id: Optional[int] = None,
) -> torch.Tensor:"""Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
接收一个条件序列(提示)作为输入,并继续生成所请求的数量的标记。The implementation of this function is modified from A. Karpathy's nanoGPT.
此函数的实现改编自 A. Karpathy 的 nanoGPT。Args:
参数:model: The model to use.
model: 要使用的模型。idx: Tensor of shape (T) with indices of the prompt sequence.
idx: 形状为 (T) 的张量,其中包含提示序列的索引。max_new_tokens: The number of new tokens to generate.
max_new_tokens: 要生成的新分词数量。max_seq_length: The maximum sequence length allowed.
max_seq_length: 允许的最大序列长度。temperature: Scales the predicted logits by 1 / temperature
temperature: 通过 1 / temperature 对预测的 logits 进行缩放。top_k: If specified, only sample among the tokens with the k highest probabilities
top_k: 如果指定,只从概率最高的 k 个标记中进行采样。eos_id: If specified, stop generating any more token once the <eos> token is triggered
eos_id: 如果指定,一旦触发 <eos> 标记,停止生成更多标记。"""

 https://pytorch.ac.cn/xla/release/2.1/index.htmlXLA 设备上的 PyTorch

model

    def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:return build_rope_cache(seq_len=self.config.block_size,n_elem=self.config.n_embd // self.config.n_head,dtype=idx.dtype,device=idx.device,)

temperature

温度越低,结果的差距越大,会使概率分布更加尖锐,从而使得模型更倾向于选择最高概率的类别。

topk()  

def topk(input: Tensor, k: Union[_int, SymInt], dim: _int = -1, largest: _bool = True, sorted: _bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None) -> torch.return_types.topk: r"""topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)返回给定 input 张量在指定维度上最大的 k 个元素。如果没有给定 dim,则选择 input 张量的最后一个维度。如果 largest 设置为 False,则返回 k 个最小元素。函数返回一个命名元组 (values, indices),其中 values 和 indices 分别是输入张量在指定维度 dim 上最大的 k 个元素及其索引。布尔选项 sorted 如果为 True,则确保返回的 k 个元素按顺序排列。参数:input (Tensor): 输入张量。
k (int): "top-k" 中的 k 值。
dim (int, optional): 排序的维度。
largest (bool, optional): 控制是否返回最大还是最小元素。
sorted (bool, optional): 控制是否返回排序后的元素。
关键字参数:out (tuple, optional): 可选的输出元组 (Tensor, LongTensor),可以作为输出缓冲区使用。
示例:python
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))"""

torch.multinomial

def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor: r"""def multinomial(input: Tensor, num_samples: _int, replacement: _bool = False, *, generator: Optional[Generator] = None, out: Optional[Tensor] = None) -> Tensor:r"""multinomial(input, num_samples, replacement=False, *, generator=None, out=None) -> LongTensor返回一个张量,其中每一行包含 :attr:`num_samples` 个从对应行的多项分布中采样的索引。更严格地说,是从多元分布中采样,更多细节请参考 torch.distributions.multinomial.Multinomial。.. note:::attr:`input` 的行不需要和为 1(在这种情况下,我们使用值作为权重),但必须是非负的、有穷的,并且和不为零。索引按从左到右的顺序排列,依据每个索引被采样的顺序(第一个样本放在第一列)。如果 :attr:`input` 是一个向量,:attr:`out` 是一个大小为 :attr:`num_samples` 的向量。如果 :attr:`input` 是一个有 `m` 行的矩阵,则 :attr:`out` 是一个形状为:math:`(m \times \text{num\_samples})` 的矩阵。如果 `replacement` 为 ``True``,则样本是有放回的。如果不是,则样本是无放回的,这意味着一旦为某行绘制了一个样本索引,在该行中不能再次绘制相同的索引。.. note::当无放回采样时,:attr:`num_samples` 必须小于 :attr:`input` 中非零元素的数量(如果 `input` 是矩阵,则为每行的非零元素的最小数量)。Args:input (Tensor): 包含概率的输入张量num_samples (int): 要绘制的样本数量replacement (bool, optional): 是否允许重复抽样关键字参数:generator (:class:`torch.Generator`, optional): 用于采样的伪随机数生成器out (Tensor, optional): 输出张量。示例::>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # 创建一个权重张量>>> torch.multinomial(weights, 2)tensor([1, 2])>>> torch.multinomial(weights, 4) # 错误!RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320>>> torch.multinomial(weights, 4, replacement=True)tensor([ 2,  1,  1,  1])""""""

 model.reset_cache()

 

Pytorch清空显存缓冲区(torch.cuda.empty_cache)_pytorch 清空显存-CSDN博客 
Pytorch 如何在使用模型后清除GPU内存|极客教程

这篇关于lit-llama代码解析的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1135179

相关文章

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

Redis实现高效内存管理的示例代码

《Redis实现高效内存管理的示例代码》Redis内存管理是其核心功能之一,为了高效地利用内存,Redis采用了多种技术和策略,如优化的数据结构、内存分配策略、内存回收、数据压缩等,下面就来详细的介绍... 目录1. 内存分配策略jemalloc 的使用2. 数据压缩和编码ziplist示例代码3. 优化的

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

Python 基于http.server模块实现简单http服务的代码举例

《Python基于http.server模块实现简单http服务的代码举例》Pythonhttp.server模块通过继承BaseHTTPRequestHandler处理HTTP请求,使用Threa... 目录测试环境代码实现相关介绍模块简介类及相关函数简介参考链接测试环境win11专业版python

Python从Word文档中提取图片并生成PPT的操作代码

《Python从Word文档中提取图片并生成PPT的操作代码》在日常办公场景中,我们经常需要从Word文档中提取图片,并将这些图片整理到PowerPoint幻灯片中,手动完成这一任务既耗时又容易出错,... 目录引言背景与需求解决方案概述代码解析代码核心逻辑说明总结引言在日常办公场景中,我们经常需要从 W

深入解析C++ 中std::map内存管理

《深入解析C++中std::map内存管理》文章详解C++std::map内存管理,指出clear()仅删除元素可能不释放底层内存,建议用swap()与空map交换以彻底释放,针对指针类型需手动de... 目录1️、基本清空std::map2️、使用 swap 彻底释放内存3️、map 中存储指针类型的对象

Java Scanner类解析与实战教程

《JavaScanner类解析与实战教程》JavaScanner类(java.util包)是文本输入解析工具,支持基本类型和字符串读取,基于Readable接口与正则分隔符实现,适用于控制台、文件输... 目录一、核心设计与工作原理1.底层依赖2.解析机制A.核心逻辑基于分隔符(delimiter)和模式匹