Make grad_fn serializable#2871
Conversation
|
For testing, you can add examples in the serde test helpers and add |
|
@sukhadj This has a few small conflicts with master. Could you resolve them so we can get the CI tests running? |
LaRiffle
left a comment
There was a problem hiding this comment.
Excellent starting point!
| setuptools>=41.0.0 | ||
| sphinx-markdown-builder>=0.4.1 | ||
| Sphinx>=2.1.1 No newline at end of file | ||
| Sphinx>=2.1.1 |
- Change Autograd.__init__() to assign grad_fn - Make simplify and detail of GradFunc lighter
|
Most of the CI test failures are due to a version mismatch between |
|
@karlhigley I will add tests for |
|
@karlhigley @LaRiffle Should we add tests for each gradient function? |
|
@sukhadj I don’t mean to hassle you about tests, I’m just excited about this one! 😅 I’d pick a gradient function that exercises as much of the code as possible and use that as your test example. |
|
@karlhigley I believe we currently have some problems mul and pow. So will go ahead with simple |
|
Hey @karlhigley @LaRiffle , the tests are failing due to inconsistent serialization (and one because of syntax error 😅 ). |
- Update additive sharining simplify to have consistent garbage_collect_data - Fix serde test for GradFunc - Remove attributes check temporarily
|
There seems to be one more problem here. |
|
It's interesting that the round-trip serialization test passes, but the simplification test fails. Makes me wonder if the simplification test assertion has the correct expected value. 🤔 |
Hey, so currently comparison function we employ just compares if class name matches but not the attributes. I'm working on comparison of _attributes. Not sure if we should just check equality of list 😃 |
- Revert AdditiveSharingTensor simplify back to master - Update gradFunc test with a dirty fix
|
Hey @LaRiffle @karlhigley need review on this |
karlhigley
left a comment
There was a problem hiding this comment.
Looks good on the whole once the test coverage issue gets resolved.
test/serde/serde_helpers.py
Outdated
| try: | ||
| # either its a tensor | ||
| assert detailed_attr.get().equal(t) | ||
| except AttributeError: |
There was a problem hiding this comment.
This except clause is a good idea but doesn't get exercised by the test suite, which causes a CI test failure due to insufficient coverage. Since it isn't used yet, probably okay to remove/comment it and leave a note that compare only works for tensors.
| for _, share in shares.items(): | ||
| share.garbage_collect_data = value | ||
|
|
||
| def get_garbage_collect_data(self): |
There was a problem hiding this comment.
This seems like a good addition, but doesn't look like it actually gets called anywhere yet. Did I miss it?
There was a problem hiding this comment.
It was used in how previously garbage_collect_data was handled.
|
|
||
| cls = getattr(gradients, cls) | ||
|
|
||
| return cls(*grad_fn_attrs) |
There was a problem hiding this comment.
Tiny
if cls == "GradFunc":
cls = GradFunc
else:
cls = getattr(gradients, cls)
return cls(*grad_fn_attrs)|
|
||
| self.grad_fn = None | ||
| if "grad_fn" in kwargs.keys(): | ||
| self.grad_fn = kwargs["grad_fn"] |
There was a problem hiding this comment.
I think you can do
self.grad_fn = kwargs.get("grad_fn")
|
🎉 |
Solves #2797
This PR makes grad_fn serializable.