Source code for torchhd.memory

#
# MIT License
#
# Copyright (c) 2023 Mike Heddes, Igor Nunes, Pere Vergés, Denis Kleyko, and Danny Abraham
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import math
from typing import Optional
import torch
from torch import Tensor
import torch.nn as nn
from scipy.stats import binom

import torchhd.functional as functional
from torchhd.tensors.base import VSATensor

# Public API of this module: the names exported by ``from torchhd.memory import *``.
# Every entry corresponds to a class or function defined below in this file.
__all__ = [
    "SparseDistributed",
    "hopfield",
    "Hopfield",
    "modern_hopfield",
    "attention",
]


class SparseDistributed(nn.Module):
    r"""`Sparse Distributed Memory <https://redwood.berkeley.edu/wp-content/uploads/2020/08/KanervaP_SDMrelated_models1993.pdf>`_

    The Sparse Distributed Memory (SDM) is specified by its (typically random) keys and their values.
    A read or write only touches the memory locations whose key has a dot-product
    similarity with the presented key of at least ``threshold``.

    Args:
        memory_size (int): The number of memory key-value pairs.
        key_dim (int): The dimensionality of the key vectors.
        value_dim (int): The dimensionality of the value vectors.
        p (float, optional): The expected fraction of memory addresses that will contain any value. Default: ``0.000368``.
        kappa (int, optional): The maximum count for each memory cell, values are clipped between [-kappa, kappa]. Default: no clipping.
        dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None`` depends on VSATensor.
        device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.

    Shapes:
        - Keys: :math:`(n, a)`
        - Values: :math:`(n, c)`

    Examples::

        >>> keys = torchhd.random(6, 512)
        >>> sdm = torchhd.memory.SparseDistributed(100000, 512, 512)
        >>> # use as associative memory
        >>> sdm.write(keys, keys)
        >>> read = sdm.read(keys).sign()
        >>> torchhd.cosine_similarity(read, keys)
        tensor([[ 1.0000,  0.0156, -0.0039, -0.0742,  0.0000, -0.0195],
                [ 0.0156,  1.0000, -0.0352, -0.0586,  0.0000, -0.0039],
                [-0.0039, -0.0352,  1.0000,  0.0156,  0.0820, -0.0234],
                [-0.0742, -0.0586,  0.0156,  1.0000, -0.0039,  0.0000],
                [ 0.0000,  0.0000,  0.0820, -0.0039,  1.0000,  0.0195],
                [-0.0195, -0.0039, -0.0234,  0.0000,  0.0195,  1.0000]])

    """

    memory_size: int
    key_dim: int
    value_dim: int
    keys: VSATensor
    values: VSATensor
    threshold: int
    kappa: Optional[int]

    def __init__(
        self,
        memory_size: int,
        key_dim: int,
        value_dim: int,
        p: float = 0.000368,
        kappa: Optional[int] = None,
        dtype=None,
        device=None,
        requires_grad=False,
    ) -> None:
        super().__init__()

        self.memory_size = memory_size
        self.key_dim = key_dim
        self.value_dim = value_dim

        # Activation radius: the p-quantile of a Binomial(key_dim, 0.5) distribution.
        # NOTE(review): the conversion below presumably relies on bipolar (+1/-1) keys,
        # for which dot product = key_dim - 2 * Hamming distance -- confirm against
        # what functional.random returns by default.
        radius = int(binom.ppf(p, key_dim, 0.5))
        self.threshold = key_dim - 2 * radius
        self.kappa = kappa

        keys = functional.random(memory_size, key_dim, dtype=dtype, device=device)
        self.keys = nn.Parameter(keys, requires_grad)

        # NOTE(review): presumably functional.empty yields the "nothing stored yet"
        # hypervectors that write() accumulates into -- confirm.
        values = functional.empty(memory_size, value_dim, device=device, dtype=dtype)
        self.values = nn.Parameter(values, requires_grad)

    def read(self, query: Tensor) -> VSATensor:
        r"""Read value from Sparse Distributed Memory whose key is most similar to the query.

        The result for each query is the sum of the values of all memory
        locations whose key similarity with the query reaches ``threshold``.

        Args:
            query (Tensor): The query vector for the memory lookup.

        Shapes:
            - Query: :math:`(*, a)`
            - Result: :math:`(*, c)`

        """
        # first dims from query, last dim from value
        out_shape = tuple(query.shape[:-1]) + (self.value_dim,)

        # Collapse every leading (batch) dimension into one so that the
        # activation mask below is always 2-D; the two-way unpacking of its
        # nonzero indices depends on that.
        query = query.reshape(-1, query.size(-1))

        similarity = query @ self.keys.T
        is_active = similarity >= self.threshold

        # Sparse matrix-vector multiplication.
        to_indices, from_indices = is_active.nonzero().T

        read = torch.zeros(
            (query.size(0), self.value_dim), dtype=query.dtype, device=query.device
        )
        read.index_add_(0, to_indices, self.values[from_indices])
        return read.view(out_shape).as_subclass(functional.MAPTensor)

    @torch.no_grad()
    def write(self, keys: Tensor, values: Tensor) -> None:
        r"""Write values to Sparse Distributed Memory at the given keys.

        Each value is added to all memory locations whose key similarity with
        the corresponding written key reaches ``threshold``. When ``kappa`` is
        set, the stored values are afterwards clipped to [-kappa, kappa].

        Args:
            keys (Tensor): The key (address) vectors for the write to memory.
            values (Tensor): The value vectors written to memory.

        Shapes:
            - Keys: :math:`(*, a)`
            - Values: :math:`(*, c)`

        """
        # Collapse every leading (batch) dimension so that keys and values are
        # both 2-D with one row per written pair.
        keys = keys.reshape(-1, keys.size(-1))
        values = values.reshape(-1, values.size(-1))

        similarity = keys @ self.keys.T
        is_active = similarity >= self.threshold

        # Sparse outer product and addition.
        from_indices, to_indices = is_active.nonzero().T

        self.values.index_add_(0, to_indices, values[from_indices])

        if self.kappa is not None:
            self.values.clamp_(-self.kappa, self.kappa)
def hopfield(query: Tensor, memory: Tensor, kappa: Optional[int] = None) -> Tensor:
    r"""`Classical Hopfield network <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC346238/>`_

    Builds the weight matrix as the sum of outer products of the memory items
    (with a zeroed diagonal) and multiplies the query with it.

    Args:
        query (Tensor): The query vector for the memory lookup.
        memory (Tensor): The items of memory for the memory lookup. A single item may be given as a 1-D tensor.
        kappa (int, optional): The maximum magnitude for each weight, weights are clipped between [-kappa, kappa]. Default: no clipping.

    Shapes:
        - Query: :math:`(*, d)`
        - Memory: :math:`(n, d)`
        - Result: :math:`(*, d)`

    Examples::

        >>> items = torchhd.random(6, 512)
        >>> read = torchhd.memory.hopfield(items, items).sign()
        >>> torchhd.cosine_similarity(read, items)
        tensor([[ 1.0000,  0.0156, -0.0039, -0.0742,  0.0000, -0.0195],
                [ 0.0156,  1.0000, -0.0352, -0.0586,  0.0000, -0.0039],
                [-0.0039, -0.0352,  1.0000,  0.0156,  0.0820, -0.0234],
                [-0.0742, -0.0586,  0.0156,  1.0000, -0.0039,  0.0000],
                [ 0.0000,  0.0000,  0.0820, -0.0039,  1.0000,  0.0195],
                [-0.0195, -0.0039, -0.0234,  0.0000,  0.0195,  1.0000]])

    """
    # A single memory item is treated as a memory with one row, so that the
    # product below is always a (d, d) matrix.
    if memory.dim() == 1:
        memory = memory.unsqueeze(0)

    # Sum of the outer products of all memory items.
    product = memory.T @ memory
    # Remove the self-connections.
    torch.diagonal(product).zero_()

    if kappa is not None:
        product.clamp_(-kappa, kappa)

    return query @ product
class Hopfield(nn.Module):
    r"""`Classical Hopfield network <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC346238/>`_

    Stateful variant that accumulates the outer products of all written items
    in a :math:`(d, d)` weight matrix with a zero diagonal.

    Args:
        vector_dim (int): The dimensionality of the vectors in the memory.
        kappa (int, optional): The maximum count for each memory cell, values are clipped between [-kappa, kappa]. Default: no clipping.
        dtype (``torch.dtype``, optional): the desired data type of returned tensor. Default: if ``None`` depends on VSATensor.
        device (``torch.device``, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see torch.set_default_tensor_type()). ``device`` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
        requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.

    Shapes:
        - Memory: :math:`(d, d)`

    Examples::

        >>> items = torchhd.random(6, 512)
        >>> hopfield = torchhd.memory.Hopfield(512)
        >>> hopfield.write(items)
        >>> read = hopfield.read(items).sign()
        >>> torchhd.cosine_similarity(read, items)
        tensor([[ 1.0000,  0.0156, -0.0039, -0.0742,  0.0000, -0.0195],
                [ 0.0156,  1.0000, -0.0352, -0.0586,  0.0000, -0.0039],
                [-0.0039, -0.0352,  1.0000,  0.0156,  0.0820, -0.0234],
                [-0.0742, -0.0586,  0.0156,  1.0000, -0.0039,  0.0000],
                [ 0.0000,  0.0000,  0.0820, -0.0039,  1.0000,  0.0195],
                [-0.0195, -0.0039, -0.0234,  0.0000,  0.0195,  1.0000]])

    """

    vector_dim: int
    memory: Tensor
    kappa: Optional[int]

    def __init__(
        self,
        vector_dim: int,
        kappa: Optional[int] = None,
        dtype=None,
        device=None,
        requires_grad=False,
    ) -> None:
        super().__init__()

        self.vector_dim = vector_dim
        self.kappa = kappa

        # The memory starts out empty: an all-zero weight matrix.
        memory = torch.zeros(
            self.vector_dim, self.vector_dim, device=device, dtype=dtype
        )
        self.memory = nn.Parameter(memory, requires_grad)

    def read(self, query: Tensor) -> Tensor:
        r"""Read value from Hopfield network at key most similar to the query.

        Args:
            query (Tensor): The query vector for the memory lookup.

        Shapes:
            - Query: :math:`(*, d)`
            - Result: :math:`(*, d)`

        """
        return query @ self.memory.T

    @torch.no_grad()
    def write(self, items: Tensor) -> None:
        r"""Write items to Hopfield Memory.

        Adds the outer product of every item with itself to the weight matrix,
        zeroes the diagonal and, when ``kappa`` is set, clips the weights to
        [-kappa, kappa].

        Args:
            items (Tensor): The item vectors to write to memory.

        Shapes:
            - Items: :math:`(*, d)`

        """
        # Collapse every leading (batch) dimension so that ``items`` is a 2-D
        # matrix with one item per row; a single 1-D item becomes one row.
        items = items.reshape(-1, items.size(-1))

        # Add the outer product to memory
        self.memory.add_(items.T @ items)
        # Remove the self-connections.
        torch.diagonal(self.memory).zero_()

        if self.kappa is not None:
            self.memory.clamp_(-self.kappa, self.kappa)
def modern_hopfield(query: Tensor, memory: Tensor) -> Tensor:
    r"""`Modern Hopfield network <https://arxiv.org/abs/1606.01164>`_

    Also known as Dense Associative Memory.

    For every element of the query, the log-sum-exp similarity to the memory
    items is evaluated once with that element set to +1 and once with it set
    to -1; the difference of the two is returned for that element.

    Args:
        query (Tensor): The query vector for the memory lookup.
        memory (Tensor): The items of memory for the memory lookup.

    Shapes:
        - Query: :math:`(*, d)`
        - Memory: :math:`(n, d)`
        - Result: :math:`(*, d)`

    Examples::

        >>> items = torchhd.random(6, 512)
        >>> read = torchhd.memory.modern_hopfield(items, items).sign()
        >>> torchhd.cosine_similarity(read, items)
        tensor([[ 1.0000,  0.0469, -0.0117,  0.0039, -0.0313, -0.0078],
                [ 0.0469,  1.0000, -0.0352, -0.0039, -0.0391, -0.0078],
                [-0.0117, -0.0352,  1.0000,  0.0547,  0.0742, -0.0352],
                [ 0.0039, -0.0039,  0.0547,  1.0000,  0.0273,  0.0117],
                [-0.0313, -0.0391,  0.0742,  0.0273,  1.0000, -0.0547],
                [-0.0078, -0.0078, -0.0352,  0.0117, -0.0547,  1.0000]])

    """
    d = query.size(-1)

    # (*, d) -> (*, d, d): row i holds a private copy of the query in which
    # only element i is going to be overwritten.
    repeat = [1] * (query.dim() - 1) + [d, 1]
    pos_query = query.unsqueeze(-2).repeat(*repeat)
    neg_query = pos_query.clone()

    # Force element i of row i to +1 and -1 respectively.
    torch.diagonal(pos_query, dim1=-2, dim2=-1).fill_(1)
    torch.diagonal(neg_query, dim1=-2, dim2=-1).fill_(-1)

    pos_energy = torch.logsumexp(pos_query @ memory.T, dim=-1)
    neg_energy = torch.logsumexp(neg_query @ memory.T, dim=-1)

    return pos_energy - neg_energy
def attention(
    query: Tensor, keys: Tensor, values: Tensor, beta: Optional[float] = None
) -> Tensor:
    r"""`Attention mechanism <https://arxiv.org/abs/1706.03762>`_

    Returns, for every query, the average of the values weighted by the
    softmax of the scaled query-key dot products.

    Args:
        query (Tensor): The query vector to compare the similarity with the keys.
        keys (Tensor): The key vectors to compare with the query.
        values (Tensor): The value vectors containing retrievable values from memory.
        beta (float, optional): Temperature scalar for the attention weights before the softmax. Default: 1/sqrt(d)

    Shapes:
        - Query: :math:`(*, f)`
        - Keys: :math:`(n, f)`
        - Values: :math:`(n, g)`
        - Result: :math:`(*, g)`

    Examples::

        >>> items = torchhd.random(6, 512)
        >>> read = torchhd.memory.attention(items, items, items).sign()
        >>> torchhd.cosine_similarity(read, items)
        tensor([[ 1.0000,  0.0625,  0.0117, -0.0625, -0.0078, -0.0430],
                [ 0.0625,  1.0000, -0.0195,  0.0703,  0.0469,  0.0508],
                [ 0.0117, -0.0195,  1.0000,  0.0820,  0.0195,  0.0156],
                [-0.0625,  0.0703,  0.0820,  1.0000, -0.0547, -0.0195],
                [-0.0078,  0.0469,  0.0195, -0.0547,  1.0000, -0.0898],
                [-0.0430,  0.0508,  0.0156, -0.0195, -0.0898,  1.0000]])

    """
    if beta is None:
        # Default temperature: inverse square root of the query dimensionality.
        beta = 1 / math.sqrt(query.size(-1))

    logits = beta * torch.matmul(query, keys.T)
    weights = logits.softmax(dim=-1)
    return torch.matmul(weights, values)