Initial jax+dask example. #158

Open · wants to merge 7 commits into main
Conversation

@asmith26 commented Jul 11, 2020

This notebook example is a learning exercise from the SciPy 2020 Dask sprint, exploring how Dask might be used to parallelize jax/dm-haiku deep learning model training and prediction.

I've committed my notebook, which works end-to-end and demonstrates a neural network learning the sine function.
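For context, the training step being parallelized looks roughly like the sketch below. It is only an illustration consistent with the snippets further down, not the notebook's exact code: net_apply stands in for the dm-haiku transformed network and the learning rate is a placeholder.

import jax
import jax.numpy as jnp
from jax.experimental import optix  # optix has since been split out into the optax package

# Placeholder loss: mean squared error of a model net_apply(params, x),
# where net_apply stands in for the dm-haiku transformed network.
def loss_fn(params, x, y):
    preds = net_apply(params, x)
    return jnp.mean((preds - y) ** 2)

optimizer = optix.adam(1e-3)  # learning rate is a placeholder

@jax.jit
def update(params, opt_state, x, y):
    # One gradient step: differentiate the loss, then let the optimizer apply the update.
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optix.apply_updates(params, updates)
    return params, opt_state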


@asmith26 marked this pull request as ready for review July 22, 2020 17:26
" df_one_partition = ddf_one_partition.compute()\n",
" scaled_x = jnp.array(df_one_partition[[\"scaled_x\"]].values)\n",
" y = jnp.array(df_one_partition[[\"y\"]].values)\n",
" params, opt_state = update(params, opt_state, scaled_x, y)"
Member:
It might be worth taking a look at some of the functionality in dask-ml, which might do some of these things for you already if you're interested.

cc'ing @stsievert and @TomAugspurger
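For reference, a rough sketch of the kind of functionality dask-ml offers is below; the data and estimator are illustrative stand-ins, not part of this notebook. Its Incremental wrapper feeds the chunks of a Dask collection into an estimator's partial_fit, which is roughly the "train on one partition at a time" loop being built by hand here.

import dask.array as da
from dask_ml.wrappers import Incremental
from sklearn.linear_model import SGDRegressor

# Illustrative stand-in data and estimator; a jax/haiku model would need a
# scikit-learn-style partial_fit wrapper before it could be used this way.
X = da.random.random((10_000, 1), chunks=(1_000, 1))
y = 2 * X[:, 0]

model = Incremental(SGDRegressor())
model.fit(X, y)  # calls partial_fit once per chunk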

" futures = []\n",
" for ddf_one_partition in ddf_train.partitions:\n",
" # Compute the gradients in parallel\n",
" futures.append(client.submit(dask_compute_grads_one_partition_wrapper, ddf_one_partition, params))\n",
Member:

I recommend instead ...

from dask.distributed import futures_of
futures = futures_of(df.map_partitions(func, **params).persist())

Author:

Thanks, I've tried this, but .map_partitions() requires you to return either a Dask DataFrame or a Dask Series (I think?). My function returns a set of gradients, grads, which is a Python dictionary (with more Python dicts nested inside, i.e. a tree-like structure), so I don't think this will work in this case (please correct me if I am mistaken).

Member:

You can probably work around that with to_delayed() instead of map_partitions. I can take a closer look later.
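A rough sketch of what that to_delayed() route might look like, reusing ddf_train, params, and the gradient wrapper from the notebook (not tested; since each delayed partition arrives as a pandas DataFrame, the wrapper would need to drop its .compute() call):

import dask

# Each element of ddf_train.to_delayed() is a delayed pandas DataFrame, so the
# mapped function is free to return an arbitrary object such as the grads dict.
delayed_grads = [
    dask.delayed(dask_compute_grads_one_partition_wrapper)(part, params)
    for part in ddf_train.to_delayed()
]
futures = client.compute(delayed_grads)  # one future per partition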

" # Bring the gradients back to the client, and update the model with the optimizer on the client\n",
" grads = future.result()\n",
" updates, opt_state = optimizer.update(grads, opt_state)\n",
" params = optix.apply_updates(params, updates)"
Member:
This is also the kind of thing for which Actors are probably a decent fit.

Author @asmith26 commented Jul 22, 2020:
Yes, I've been trying to think about how to perform training with shared parameters (and optimizer state) among workers via Actors. I haven't quite got my head around how this might work yet.
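One rough shape this could take, just an untested sketch reusing the optimizer/optix names from the notebook, is an Actor that owns the parameters and optimizer state and applies incoming gradients in one place:

class ParameterServer:
    # Lives on a single worker and holds the shared params and optimizer state.
    def __init__(self, params, opt_state):
        self.params = params
        self.opt_state = opt_state

    def apply_grads(self, grads):
        # Apply one set of gradients and return the refreshed parameters.
        updates, self.opt_state = optimizer.update(grads, self.opt_state)
        self.params = optix.apply_updates(self.params, updates)
        return self.params

ps = client.submit(ParameterServer, params, opt_state, actor=True).result()

# After computing grads for a partition, hand them to the actor:
new_params = ps.apply_grads(grads).result()  # actor method calls return ActorFutures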

Member:
That example doesn't run, possibly due to a bad merge. I've put in a PR to correct that: dask/dask#6449

Base automatically changed from master to main January 27, 2021 16:07