معرفی شرکت ها


diffusionjax-0.0.7


Card image cap
تبلیغات ما

مشتریان به طور فزاینده ای آنلاین هستند. تبلیغات می تواند به آنها کمک کند تا کسب و کار شما را پیدا کنند.

مشاهده بیشتر
Card image cap
تبلیغات ما

مشتریان به طور فزاینده ای آنلاین هستند. تبلیغات می تواند به آنها کمک کند تا کسب و کار شما را پیدا کنند.

مشاهده بیشتر
Card image cap
تبلیغات ما

مشتریان به طور فزاینده ای آنلاین هستند. تبلیغات می تواند به آنها کمک کند تا کسب و کار شما را پیدا کنند.

مشاهده بیشتر
Card image cap
تبلیغات ما

مشتریان به طور فزاینده ای آنلاین هستند. تبلیغات می تواند به آنها کمک کند تا کسب و کار شما را پیدا کنند.

مشاهده بیشتر
Card image cap
تبلیغات ما

مشتریان به طور فزاینده ای آنلاین هستند. تبلیغات می تواند به آنها کمک کند تا کسب و کار شما را پیدا کنند.

مشاهده بیشتر

توضیحات

diffusionjax is a simple and accessible diffusion models package in JAX
ویژگی مقدار
سیستم عامل -
نام فایل diffusionjax-0.0.7
نام diffusionjax
نسخه کتابخانه 0.0.7
نگهدارنده []
ایمیل نگهدارنده []
نویسنده Jakiw Pidstrigach and Benjamin Boys
ایمیل نویسنده -
آدرس صفحه اصلی https://github.com/bb515/diffusionjax
آدرس اینترنتی https://pypi.org/project/diffusionjax/
مجوز MIT
diffusionjax ============ [![CI](https://github.com/bb515/diffusionjax/actions/workflows/CI.yml/badge.svg)](https://github.com/bb515/diffusionjax/actions/workflows/CI.yml) [![Coverage Status](https://coveralls.io/repos/github/bb515/diffusionjax/badge.svg?branch=master)](https://coveralls.io/github/bb515/diffusionjax?branch=master) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) diffusionjax is a simple, accessible introduction to diffusion models, also known as score-based generative models (SGMs). It is implemented in Python via the autodiff framework, [JAX](https://github.com/google/jax). In particular, diffusionjax uses the [Flax](https://github.com/google/flax) library for the neural network approximator of the score. Based off the [Jupyter notebook](https://jakiw.com/sgm_intro) by Jakiw Pidstrigach, a tutorial on the theoretical and implementation aspects of diffusion models. The development of diffusionjax has been supported by The Alan Turing Institute through the Theory and Methods Challenge Fortnights event "Accelerating generative models and nonconvex optimisation", which took place on 6-10 June 2022 and 5-9 Sep 2022 at The Alan Turing Institute headquarters. Contents: - [Installation](#installation) - [Examples](#examples) - [Introduction to diffusion models](#introduction-to-diffusion-models) - [Does haves](#does-haves) - [Doesn't haves](#doesn't-haves) - [References](#references) - [Acknowledgements](#acknowledgements) ## Installation The package requires Python 3.8+. First, it is recommended to [create a new python virtual environment](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands). Then, `pip install diffusionjax` or for developers, - Clone the repository `git clone git@github.com:bb515/diffusionjax.git` - Install using pip `pip install -e .` from the root directory of the repository (see the `setup.py` for the requirements that this command installs). ## Examples ### Introduction to diffusion models - Run the example by typing `python examples/example.py` on the command line from the root directory of the repository. ```python >>> num_epochs = 4000 >>> rng = random.PRNGKey(2023) >>> rng, step_rng = random.split(rng, 2) >>> num_samples = 8 >>> samples = sample_circle(num_samples) >>> N = samples.shape[1] >>> plot_samples(samples=samples, index=(0, 1), fname="samples", lims=((-3, 3), (-3, 3))) ``` ![Prediction](readme_samples.png) ```python >>> # Get variance preserving (VP) a.k.a. time-changed Ohrnstein Uhlenbeck (OU) sde model >>> sde = VP() >>> >>> def log_hat_pt(x, t): >>> """ >>> Empirical distribution score. >>> >>> Args: >>> x: One location in $\mathbb{R}^2$ >>> t: time >>> Returns: >>> The empirical log density, as described in the Jupyter notebook >>> .. math:: >>> \hat{p}_{t}(x) >>> """ >>> mean, std = sde.marginal_prob(samples, t) >>> potentials = jnp.sum(-(x - mean)**2 / (2 * std**2), axis=1) >>> return logsumexp(potentials, axis=0, b=1/num_samples) >>> >>> # Get a jax grad function, which can be batched with vmap >>> nabla_log_hat_pt = jit(vmap(grad(log_hat_pt), in_axes=(0, 0), out_axes=(0))) >>> >>> # Running the reverse SDE with the empirical drift >>> plot_score(score=nabla_log_hat_pt, t=0.01, area_min=-3, area_max=3, fname="empirical score") ``` ![Prediction](readme_empirical_score.png) ```python >>> sampler = get_sampler(EulerMaruyama(sde.reverse(nabla_log_hat_pt))) >>> q_samples = sampler(rng, n_samples=5000, shape=(N,)) >>> plot_heatmap(samples=q_samples[:, [0, 1]], area_min=-3, area_max=3, fname="heatmap empirical score") ``` ![Prediction](readme_heatmap_empirical_score.png) ```python >>> # What happens when I perturb the score with a constant? >>> perturbed_score = lambda x, t: nabla_log_hat_pt(x, t) + 1 >>> rng, step_rng = random.split(rng) >>> sampler = get_sampler(EulerMaruyama(sde.reverse(perturbed_score))) >>> q_samples = sampler(rng, n_samples=5000, shape=(N,)) >>> plot_heatmap(samples=q_samples[:, [0, 1]], area_min=-3, area_max=3, fname="heatmap bounded perturbation") ``` ![Prediction](readme_heatmap_bounded_perturbation.png) ```python >>> # Neural network training via score matching >>> batch_size=16 >>> score_model = MLP() >>> # Initialize parameters >>> params = score_model.init(step_rng, jnp.zeros((batch_size, N)), jnp.ones((batch_size,))) >>> # Initialize optimizer >>> opt_state = optimizer.init(params) >>> # Get loss function >>> solver = EulerMaruyama(sde) >>> loss = get_loss( >>> sde, solver, score_model, score_scaling=True, likelihood_weighting=False) >>> # Train with score matching >>> score_model, params, opt_state, mean_losses = retrain_nn( >>> update_step=update_step, >>> num_epochs=num_epochs, >>> step_rng=step_rng, >>> samples=samples, >>> score_model=score_model, >>> params=params, >>> opt_state=opt_state, >>> loss=loss, >>> batch_size=batch_size) >>> # Get trained score >>> trained_score = get_score(sde, score_model, params, score_scaling=True) >>> plot_score(score=trained_score, t=0.01, area_min=-3, area_max=3, fname="trained score") ``` ![Prediction](readme_trained_score.png) ```python >>> solver = EulerMaruyama(sde.reverse(trained_score)) >>> sampler = get_sampler(solver, stack_samples=False) >>> q_samples = sampler(rng, n_samples=1000, shape=(N,)) >>> plot_heatmap(samples=q_samples[:, [0, 1]], area_min=-3, area_max=3, fname="heatmap trained score") ``` ![Prediction](readme_heatmap_trained_score.png) ```python >>> inpainter = get_inpainter(solver, stack_samples=False) >>> data = jnp.array([-0.5, 0.0]) >>> mask = jnp.array([1, 0]) >>> data = jnp.tile(data, (64, 1)) >>> mask = jnp.tile(mask, (64, 1)) >>> q_samples = inpainter(rng, data, mask) >>> plot_heatmap(samples=q_samples[:, [0, 1]], area_min=-3, area_max=3, fname="heatmap inpainted") ``` ![Prediction](readme_heatmap_inpainted.png) ## Does haves - Training scores on (possibly, image) data and sampling from the generative model. Also inverse problems, such as inpainting. - Not many lines of code. - Easy to use, extendable. Get started with the example, provided. ## Doesn't haves - Geometry other than Euclidean space, such as Riemannian manifolds. - Diffusion in a latent space. - Augmented with critically-damped Langevin diffusion. ## References Algorithms in this package were ported from pre-existing code. In particular, the code was ported from the following papers and repositories: The [official implementation](https://github.com/yang-song/score_sde) for the paper [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.sohldickstein.com/), [Diederik P. Kingma](http://dpkingma.com/), [Abhishek Kumar](http://users.umiacs.umd.edu/~abhishek/), [Stefano Ermon](https://cs.stanford.edu/~ermon/), and [Ben Poole](https://cs.stanford.edu/~poole/) ## Acknowledgements Thank you to [nPlan](https://www.nplan.io/), who are supporting this project.


زبان مورد نیاز

مقدار نام
>=3.8 Python


نحوه نصب


نصب پکیج whl diffusionjax-0.0.7:

    pip install diffusionjax-0.0.7.whl


نصب پکیج tar.gz diffusionjax-0.0.7:

    pip install diffusionjax-0.0.7.tar.gz